
Commit 83dfbc8

lengrongfu authored and epwalsh committed
[Bugfix] fix when skip tokenizer init (vllm-project#21922)
Signed-off-by: rongfu.leng <[email protected]>
1 parent cc2a1a5 commit 83dfbc8
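
This fix lets the V1 engine handle requests when skip_tokenizer_init=True, i.e. when no tokenizer or detokenizer is ever created. A minimal usage sketch of the path the commit covers, based on the new test below (the model name and token IDs are illustrative):

from vllm import LLM, SamplingParams

# With skip_tokenizer_init=True no tokenizer/detokenizer is loaded, so
# prompts must be given as token IDs and outputs carry token IDs only.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          skip_tokenizer_init=True,
          enforce_eager=True)

params = SamplingParams(max_tokens=8)

# A text prompt is rejected, since nothing can encode it:
#   llm.generate("abc", params)  # raises ValueError
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=params)
completion = outputs[0].outputs[0]
print(completion.token_ids)  # generated token IDs; completion.text stays ""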

2 files changed: +33 additions, -2 deletions


tests/v1/engine/test_llm_engine.py

Lines changed: 26 additions & 0 deletions
@@ -213,3 +213,29 @@ def find_metric(name) -> list[Metric]:
     assert len(num_accepted_tokens_per_pos) == 1
     assert isinstance(num_accepted_tokens_per_pos[0], Vector)
     assert len(num_accepted_tokens_per_pos[0].values) == 5
+
+
+@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
+def test_skip_tokenizer_initialization(model: str,
+                                       monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids

vllm/v1/engine/processor.py

Lines changed: 7 additions & 2 deletions
@@ -89,6 +89,10 @@ def _validate_sampling_params(
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
+        if self.tokenizer is None:
+            # When skip_tokenizer_init=True, we can't validate token IDs
+            # Skip validation and let the model handle invalid tokens
+            return
         tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
         vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
@@ -283,8 +287,9 @@ def process_inputs(
                     len(decoder_inputs["prompt_token_ids"]))
             sampling_params.update_from_generation_config(
                 self.generation_config_fields, eos_token_id)
-            sampling_params.update_from_tokenizer(
-                self.tokenizer.get_lora_tokenizer(lora_request))
+            if self.tokenizer is not None:
+                sampling_params.update_from_tokenizer(
+                    self.tokenizer.get_lora_tokenizer(lora_request))
         else:
             pooling_params = params.clone()
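
With the tokenizer absent, the processor can no longer check allowed_token_ids against the vocabulary, and update_from_tokenizer is skipped, so stop-token handling falls back to the generation-config eos_token_id. A short sketch of what the first hunk implies for callers, under the same illustrative assumptions as above:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          skip_tokenizer_init=True,
          enforce_eager=True)

# The processor does not know the vocab size here, so this list is passed
# through unvalidated; an out-of-range ID only fails later, inside the model.
params = SamplingParams(allowed_token_ids=[10, 20, 30], max_tokens=4)
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=params)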
