
Commit f39375d

maxdebayser authored and jingyu committed
Support token_type_ids in V1 with less code changes (vllm-project#21985)
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: jingyu <[email protected]>
1 parent 68b4a38 commit f39375d

10 files changed: +235 -130 lines

tests/entrypoints/openai/test_rerank.py

Lines changed: 3 additions & 1 deletion
@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
                                              invocation_output["results"]):
         assert rerank_result.keys() == invocations_result.keys()
         assert rerank_result["relevance_score"] == pytest.approx(
-            invocations_result["relevance_score"], rel=0.01)
+            invocations_result["relevance_score"], rel=0.05)
+        # TODO: reset this tolerance to 0.01 once we find
+        # an alternative to flash_attn with bfloat16
 
 
 @pytest.mark.asyncio
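
For reference, a minimal illustration of what the relaxed tolerance accepts; the numbers below are made up for illustration only.

```python
import pytest

# rel=0.05 accepts a relative difference of up to 5% between the two scores,
# whereas the previous rel=0.01 only allowed 1%.
assert 0.96 == pytest.approx(1.0, rel=0.05)         # 4% apart: passes
assert not (0.90 == pytest.approx(1.0, rel=0.05))   # 10% apart: fails
```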

tests/entrypoints/openai/test_score.py

Lines changed: 3 additions & 1 deletion
@@ -220,7 +220,9 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
                                           invocation_output["data"]):
             assert score_data.keys() == invocation_data.keys()
             assert score_data["score"] == pytest.approx(
-                invocation_data["score"], rel=0.01)
+                invocation_data["score"], rel=0.05)
+            # TODO: reset this tolerance to 0.01 once we find
+            # an alternative to flash_attn with bfloat16
 
     def test_activation(self, server: RemoteOpenAIServer, model: dict[str,
                                                                       Any]):

tests/models/language/pooling/test_scoring.py

Lines changed: 9 additions & 0 deletions
@@ -23,6 +23,15 @@
     "The capital of Germany is Berlin.",
 ]
 
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 DTYPE = "half"
 
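
The `run_with_both_engines` fixture referenced above comes from vLLM's test conftest. A rough sketch, assuming it parametrizes each test over the V0 and V1 engines via the `VLLM_USE_V1` environment variable; the real fixture may differ in detail.

```python
import pytest

@pytest.fixture(params=[True, False], ids=["v1", "v0"])
def run_with_both_engines(request, monkeypatch):
    # Hypothetical sketch: toggle the V1 engine before the engine is created
    # so the same test body exercises both engine implementations.
    monkeypatch.setenv("VLLM_USE_V1", "1" if request.param else "0")
    yield
```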

vllm/entrypoints/llm.py

Lines changed: 26 additions & 28 deletions
@@ -28,11 +28,15 @@
                                          apply_mistral_chat_template,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam,
                                           _cosine_similarity,
                                           _validate_score_input_lens,
+                                          compress_token_type_ids,
                                           get_score_prompt)
+# yapf: enable
 from vllm.entrypoints.utils import (_validate_truncation_size,
                                     log_non_default_args)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
@@ -1329,6 +1333,7 @@ def _cross_encoding_score(
 
         model_config = self.llm_engine.model_config
         pooling_params.verify("score", model_config)
+        pooling_params_list = list[PoolingParams]()
 
         tokenization_kwargs: dict[str, Any] = {}
 
@@ -1339,38 +1344,31 @@ def _cross_encoding_score(
 
         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
 
-        if model_config.is_multimodal_model:
-            for q, d in input_pairs:
-                _, engine_prompt = get_score_prompt(
-                    model_config=model_config,
-                    data_1=q,
-                    data_2=d,
-                    tokenizer=tokenizer,
-                    tokenization_kwargs=tokenization_kwargs,
-                )
+        model_config = self.llm_engine.model_config
 
-                parsed_prompts.append(engine_prompt)
-        else:
-            for q, t in input_pairs:
-                if model_config.use_pad_token:
-                    # cross_encoder models defaults to using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q,  # type: ignore[arg-type]
-                        text_pair=t,  # type: ignore[arg-type]
-                        **tokenization_kwargs)
-                else:
-                    # `llm as reranker` models defaults to not using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q + t,  # type: ignore[operator]
-                        **tokenization_kwargs)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=prompt_inputs["input_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-                parsed_prompts.append(engine_prompt)
+        for q, d in input_pairs:
+            _, engine_prompt = get_score_prompt(
+                model_config=model_config,
+                data_1=q,
+                data_2=d,
+                tokenizer=tokenizer,
+                tokenization_kwargs=tokenization_kwargs,
+            )
+
+            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+                    "token_type_ids", None)):
+                params = pooling_params.clone()
+                compressed = compress_token_type_ids(token_type_ids)
+                params.extra_kwargs = {"compressed_token_type_ids": compressed}
+                pooling_params_list.append(params)
+            else:
+                pooling_params_list.append(pooling_params)
+
+            parsed_prompts.append(engine_prompt)
 
         self._validate_and_add_requests(
             prompts=parsed_prompts,
-            params=pooling_params,
+            params=pooling_params_list,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )
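
For context, a hedged usage sketch of the offline scoring path this change touches. The model name and constructor arguments are illustrative and may vary across vLLM versions.

```python
from vllm import LLM

# Cross-encoder scoring via the offline API; "task" and the attribute layout
# of the outputs may differ slightly between vLLM releases.
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

outputs = llm.score(
    "What is the capital of France?",
    ["The capital of France is Paris.",
     "The capital of Germany is Berlin."],
)
for output in outputs:
    print(output.outputs.score)
```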

vllm/entrypoints/openai/serving_score.py

Lines changed: 31 additions & 51 deletions
@@ -7,6 +7,7 @@
 
 from fastapi import Request
 
+from vllm import envs
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
@@ -17,11 +18,15 @@
                                               ScoreResponseData, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam,
                                           _cosine_similarity,
                                           _validate_score_input_lens,
+                                          compress_token_type_ids,
                                           get_score_prompt)
+# yapf: enable
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
@@ -158,6 +163,8 @@ def _preprocess_score(
             tokenizer=tokenizer,
             tokenization_kwargs=tokenization_kwargs,
         )
+        self._validate_input(request, engine_prompt["prompt_token_ids"],
+                             full_prompt)
         if request.mm_processor_kwargs is not None:
             engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
 
@@ -188,64 +195,27 @@ async def _cross_encoding_score(
 
         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
 
-        if self.model_config.is_multimodal_model:
+        preprocess_async = make_async(self._preprocess_score,
+                                      executor=self._tokenizer_executor)
 
-            preprocess_async = make_async(self._preprocess_score,
-                                          executor=self._tokenizer_executor)
+        preprocessed_prompts = await asyncio.gather(
+            *(preprocess_async(request=request,
+                               tokenizer=tokenizer,
+                               tokenization_kwargs=tokenization_kwargs,
+                               data_1=t1,
+                               data_2=t2) for t1, t2 in input_pairs))
 
-            preprocessed_prompts = await asyncio.gather(
-                *(preprocess_async(request=request,
-                                   tokenizer=tokenizer,
-                                   tokenization_kwargs=tokenization_kwargs,
-                                   data_1=t1,
-                                   data_2=t2) for t1, t2 in input_pairs))
-
-            for full_prompt, engine_prompt in preprocessed_prompts:
-                request_prompts.append(full_prompt)
-                engine_prompts.append(engine_prompt)
-
-        else:
-            tokenize_async = make_async(tokenizer.__call__,
-                                        executor=self._tokenizer_executor)
-            use_pad_token = self.model_config.use_pad_token
-
-            if use_pad_token:
-                # cross_encoder models defaults to using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1,  # type: ignore[arg-type]
-                        text_pair=t2,  # type: ignore[arg-type]
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-            else:
-                # `llm as reranker` models defaults to not using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1 +  # type: ignore[operator]
-                        t2,
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-
-            for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
-                sep_token = tokenizer.sep_token if (tokenizer.sep_token
-                                                    and use_pad_token) else ''
-                request_prompt = f"{t1}{sep_token}{t2}"
-
-                input_ids = prompt_inputs["input_ids"]
-                text_token_prompt = \
-                    self._validate_input(request, input_ids, request_prompt)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-
-                request_prompts.append(request_prompt)
-                engine_prompts.append(engine_prompt)
+        for full_prompt, engine_prompt in preprocessed_prompts:
+            request_prompts.append(full_prompt)
+            engine_prompts.append(engine_prompt)
 
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
 
-        pooling_params = request.to_pooling_params()
+        default_pooling_params = request.to_pooling_params()
 
         try:
-            pooling_params.verify("score", self.model_config)
+            default_pooling_params.verify("score", self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))
 
@@ -254,9 +224,19 @@ async def _cross_encoding_score(
 
             self._log_inputs(request_id_item,
                              request_prompts[i],
-                             params=pooling_params,
+                             params=default_pooling_params,
                              lora_request=lora_request)
 
+            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+                    "token_type_ids", None)):
+                pooling_params = default_pooling_params.clone()
+                compressed = compress_token_type_ids(token_type_ids)
+                pooling_params.extra_kwargs = {
+                    "compressed_token_type_ids": compressed
+                }
+            else:
+                pooling_params = (default_pooling_params)
+
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,
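
A hedged sketch of exercising the score endpoint this handler serves. The route and payload fields follow the score API exercised by the tests above; host, port, and model name are placeholders for an actual deployment.

```python
import requests

# POST a query/document pair to the score endpoint of a running vLLM server.
response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?",
        "text_2": ["The capital of France is Paris.",
                   "The capital of Germany is Berlin."],
    },
)
for item in response.json()["data"]:
    print(item["score"])
```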

vllm/entrypoints/score_utils.py

Lines changed: 37 additions & 3 deletions
@@ -184,15 +184,49 @@ def get_score_prompt(
         model_config,
         tokenizer,
     )
+    from vllm.model_executor.model_loader import get_model_cls
 
-    full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
-
-    prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    model = get_model_cls(model_config)
+    if supports_score_template(model):
+        full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
+        prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    elif model_config.use_pad_token:
+        # cross_encoder models defaults to using pad_token.
+        prompt_inputs = tokenizer(text=prompt_1,
+                                  text_pair=prompt_2,
+                                  **tokenization_kwargs)
+        full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
+    else:
+        # `llm as reranker` models defaults to not using pad_token.
+        full_prompt = prompt_1 + prompt_2
+        prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
 
     engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
 
+    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
+        engine_prompt["token_type_ids"] = token_type_ids
+
     post_process_tokens(model_config, engine_prompt)
 
     if mm_data is not None:
         engine_prompt["multi_modal_data"] = mm_data
     return full_prompt, engine_prompt
+
+
+def compress_token_type_ids(token_type_ids: list[int]) -> int:
+    """
+    Return position of the first 1 or the length of the list
+    if not found.
+    """
+    first_one = len(token_type_ids)
+    err_msg = "Token type ids are expected to be a sequence" \
+        " of zeros followed by a sequence of ones"
+    for i, type_id in enumerate(token_type_ids):
+        if type_id == 0 and first_one < i:
+            raise ValueError(err_msg)
+        elif type_id == 1 and first_one > i:
+            first_one = i
+        elif type_id > 1:
+            raise ValueError(err_msg)
+
+    return first_one
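
A quick usage sketch of the new helper, plus a hypothetical illustration of how the compressed boundary could be expanded back into per-token type ids on the consumer side; that consumer-side code is not part of this diff.

```python
from vllm.entrypoints.score_utils import compress_token_type_ids

# A text/text_pair tokenization yields zeros for the first segment and ones
# for the second, so the whole vector reduces to the index of the first 1.
assert compress_token_type_ids([0, 0, 0, 0, 1, 1, 1]) == 4
# All zeros (no second segment) compress to the full length.
assert compress_token_type_ids([0, 0, 0]) == 3

# Hypothetical consumer-side reconstruction from the compressed boundary:
def expand_token_type_ids(seq_len: int, first_one: int) -> list[int]:
    return [0 if i < first_one else 1 for i in range(seq_len)]

assert expand_token_type_ids(7, 4) == [0, 0, 0, 0, 1, 1, 1]
```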
