Commit d9e00db

[Performance] V1 Classify Models E2E Performance Optimization (vllm-project#23541)
Signed-off-by: wang.yuqi <[email protected]>
1 parent ad39106 commit d9e00db

File tree

tests/entrypoints/llm/test_classify.py
tests/entrypoints/openai/test_classification.py
vllm/entrypoints/openai/api_server.py
vllm/model_executor/layers/pooler.py
vllm/v1/worker/gpu_model_runner.py

5 files changed: +80 −37 lines changed

tests/entrypoints/llm/test_classify.py
Lines changed: 6 additions & 0 deletions
@@ -62,3 +62,9 @@ def test_encode_api(llm: LLM):
     err_msg = "pooling_task must be one of.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompts, use_tqdm=False)
+
+
+def test_score_api(llm: LLM):
+    err_msg = "Score API is only enabled for num_labels == 1."
+    with pytest.raises(ValueError, match=err_msg):
+        llm.score("ping", "pong", use_tqdm=False)
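
The new test pins down the user-visible behavior: calling LLM.score() on a classification model whose head has num_labels != 1 now fails fast with a clear error. A minimal sketch of that interaction, outside the test fixture (the model name is a placeholder for any multi-label sequence-classification model, not something from this commit):

from vllm import LLM

# Placeholder: any sequence-classification model with num_labels != 1.
llm = LLM(model="some-multi-label-classifier")

try:
    llm.score("ping", "pong", use_tqdm=False)
except ValueError as err:
    print(err)  # Score API is only enabled for num_labels == 1.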

tests/entrypoints/openai/test_classification.py
Lines changed: 30 additions & 0 deletions
@@ -226,3 +226,33 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_score(server: RemoteOpenAIServer, model_name: str):
+    # score api is only enabled for num_labels == 1.
+    response = requests.post(
+        server.url_for("score"),
+        json={
+            "model": model_name,
+            "text_1": "ping",
+            "text_2": "pong",
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_rerank(server: RemoteOpenAIServer, model_name: str):
+    # rerank api is only enabled for num_labels == 1.
+    response = requests.post(
+        server.url_for("rerank"),
+        json={
+            "model": model_name,
+            "query": "ping",
+            "documents": ["pong"],
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
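
These tests probe the HTTP surface of the same restriction: against a classification model with num_labels != 1, the /score and /rerank routes now answer with a BadRequestError payload. A quick way to reproduce by hand (the base URL and model name are assumptions about a locally running server, not part of the commit):

import requests

# Assumed local deployment; substitute your own server URL and model name.
resp = requests.post(
    "http://localhost:8000/score",
    json={"model": "some-multi-label-classifier",
          "text_1": "ping", "text_2": "pong"},
)
print(resp.json()["error"]["type"])  # expected: BadRequestError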

vllm/entrypoints/openai/api_server.py
Lines changed: 1 addition & 5 deletions
@@ -1805,17 +1805,13 @@ async def init_app_state(
         request_logger=request_logger,
         log_error_stack=args.log_error_stack,
     ) if "classify" in supported_tasks else None
-
-    enable_serving_reranking = ("classify" in supported_tasks and getattr(
-        model_config.hf_config, "num_labels", 0) == 1)
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
         log_error_stack=args.log_error_stack,
-    ) if ("embed" in supported_tasks or enable_serving_reranking) else None
-
+    ) if ("embed" in supported_tasks or "score" in supported_tasks) else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
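
The server-side wiring becomes a plain task-membership check: ServingScores is created whenever the engine reports either the embed or the score task, and the num_labels == 1 reranking condition moves into the model runner (see vllm/v1/worker/gpu_model_runner.py below). A sketch of the new gate, assuming supported_tasks is the task list the engine reports:

def scores_enabled(supported_tasks: list[str]) -> bool:
    # Before: "embed" in tasks, or "classify" in tasks with num_labels == 1.
    # After: the runner already strips "score" when num_labels != 1, so the
    # server only needs to check membership.
    return "embed" in supported_tasks or "score" in supported_tasks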

vllm/model_executor/layers/pooler.py
Lines changed: 32 additions & 28 deletions
@@ -13,12 +13,15 @@
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, PoolerConfig
+from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
 from vllm.utils import current_stream, resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
 
+logger = init_logger(__name__)
+
 PoolingFn = Callable[
     [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
     Union[torch.Tensor, list[torch.Tensor]]]
@@ -183,7 +186,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         fn = resolve_obj_by_qualname(function_name)()
         return PoolerActivation.wraps(fn)
 
-    return PoolerScore()
+    return PoolerClassify()
 
 
 def build_output(
@@ -371,22 +374,29 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
 
 class PoolerClassify(PoolerActivation):
 
-    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        num_labels = pooled_data.shape[-1]
-        if num_labels < 2:
-            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
-
-        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
-
+    def __init__(self, *, static_num_labels: bool = True) -> None:
+        super().__init__()
 
-class PoolerScore(PoolerActivation):
+        if static_num_labels:
+            from vllm.config import get_current_vllm_config
+            vllm_config = get_current_vllm_config()
+            self.num_labels = getattr(vllm_config.model_config.hf_config,
+                                      "num_labels", 0)
+            if self.num_labels == 0:
+                logger.warning("num_labels should be > 0 for classification "
+                               "models, falling back to softmax. "
+                               "Please check if the configuration is correct.")
+        else:
+            self.num_labels = None
 
     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        num_labels = pooled_data.shape[-1]
+        num_labels = (self.num_labels if self.num_labels is not None else
+                      pooled_data.shape[-1])
+
         if num_labels < 2:
             return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
 
-        return pooled_data
+        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
 
 
 class LambdaPoolerActivation(PoolerActivation):
@@ -428,6 +438,10 @@ def __init__(self) -> None:
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
 
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, hidden_dimension]
+
         # Apply ST projector
         if self.projector is not None:
             projector = cast(nn.Module, self.projector)
@@ -437,17 +451,11 @@ def _proj(x: torch.Tensor) -> torch.Tensor:
                 y = projector(x.to(torch.float32))
                 return y.to(orig_dtype)
 
-            if isinstance(pooled_data, torch.Tensor):
-                pooled_data = _proj(pooled_data)
-            else:
-                pooled_data = [_proj(t) for t in pooled_data]
+            pooled_data = _proj(pooled_data)
+            # pooled_data shape: [batchsize, embedding_dimension]
 
         pooling_params = get_pooling_params(pooling_metadata)
 
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-        # pooled_data shape: [batchsize, embedding_dimension]
-
         # for matryoshka representation
         dimensions_list = [
             pooling_param.dimensions for pooling_param in pooling_params
@@ -477,13 +485,14 @@ def _proj(x: torch.Tensor) -> torch.Tensor:
             for vecs, f in zip(pooled_data, flags)
         ]
 
+        # pooled_data shape: [batchsize, embedding_dimension]
         return pooled_data
 
 
 class RewardPoolerHead(PoolerHead):
 
     def __init__(self) -> None:
-        super().__init__(activation=PoolerClassify())
+        super().__init__(activation=PoolerClassify(static_num_labels=False))
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
@@ -637,19 +646,13 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
-
         if isinstance(pooled_data, list):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_size]
 
         if self.classifier is not None:
-            # apply classifier once on the full batch if possible
-            if isinstance(pooled_data, torch.Tensor):
-                pooled_data = self.classifier(pooled_data)
-            elif len({data.shape for data in pooled_data}) <= 1:
-                pooled_data = self.classifier(torch.stack(pooled_data))
-            else:
-                pooled_data = [self.classifier(data) for data in pooled_data]
+            pooled_data = self.classifier(pooled_data)
+            # pooled_data shape: [batchsize, num_labels]
 
         pooling_params = get_pooling_params(pooling_metadata)
         flags = [p.activation for p in pooling_params]
@@ -662,6 +665,7 @@ def forward(
             for vecs, f in zip(pooled_data, flags)
         ]
 
+        # scores shape: [batchsize, num_labels]
         return build_output(scores)
 
 
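Two things change in the pooler. First, list-shaped pooled data is stacked into a single [batchsize, ...] tensor up front, so the projector and the classifier each run once over the whole batch instead of branching per input. Second, PoolerClassify reads num_labels from the model config at construction time (static_num_labels=True) instead of inferring it from the tensor shape on every forward pass; RewardPoolerHead opts out with static_num_labels=False, and the separate PoolerScore class is removed, with cross-encoders lacking an explicit activation falling back to PoolerClassify. The activation rule is easy to check in isolation; a standalone sketch (shapes are illustrative, helper name is hypothetical):

import torch
import torch.nn.functional as F

def classify_activation(pooled: torch.Tensor, num_labels: int) -> torch.Tensor:
    # Single-logit heads get a sigmoid, multi-label heads a softmax;
    # the math runs in float32 and is cast back to the input dtype.
    if num_labels < 2:
        return torch.sigmoid(pooled.float()).to(pooled.dtype)
    return F.softmax(pooled.float(), dim=-1).to(pooled.dtype)

logits = torch.randn(4, 3, dtype=torch.float16)  # [batchsize, num_labels]
probs = classify_activation(logits, num_labels=3)
assert torch.allclose(probs.float().sum(dim=-1), torch.ones(4), atol=1e-3)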
vllm/v1/worker/gpu_model_runner.py
Lines changed: 11 additions & 4 deletions
@@ -1248,10 +1248,17 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]:
                 and "encode" in supported_tasks):
             supported_tasks.remove("encode")
 
-            logger.info_once("Chunked prefill is not supported with "
-                             "encode task which using ALL pooling. "
-                             "Please turn off chunked prefill by "
-                             "`--no-enable-chunked-prefill` before using it.")
+            logger.debug_once("Chunked prefill is not supported with "
+                              "encode task which using ALL pooling. "
+                              "Please turn off chunked prefill by "
+                              "`--no-enable-chunked-prefill` before using it.")
+
+        if "score" in supported_tasks:
+            num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
+            if num_labels != 1:
+                supported_tasks.remove("score")
+                logger.debug_once(
+                    "Score API is only enabled for num_labels == 1.")
 
         return supported_tasks
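With this filter in place, whether the Score API is available is decided once when the runner reports its pooling tasks, rather than on each request, which is what lets api_server.py above reduce to a membership check. A hypothetical standalone version of the new gate (hf_config stands in for a transformers config object; the function name is not from the commit):

def drop_score_task(supported_tasks: list[str], hf_config) -> list[str]:
    # Score/rerank need exactly one output logit per input pair.
    if "score" in supported_tasks:
        if getattr(hf_config, "num_labels", 0) != 1:
            supported_tasks.remove("score")
    return supported_tasks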