
Commit ad0b325: "New version"
1 parent: 414c3ba

2 files changed: 5 additions, 3 deletions


vllm/entrypoints/openai/serving_chat.py
1 addition, 1 deletion

@@ -248,7 +248,7 @@ async def create_chat_completion(
         sampling_params = request.to_sampling_params()
         if request.enforced_str:
             toks = self.tokenizer(request.enforced_str, add_special_tokens=False)
-            sampling_params.enforce_token_ids = toks.input_ids
+            sampling_params.enforce_token_ids = toks.input_ids + [self.tokenizer.eos_token_id]
         lora_request = self._maybe_get_lora(request)
         decoding_config = await self.engine.get_decoding_config()
         guided_decoding_backend = request.guided_decoding_backend \
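Appending the tokenizer's EOS id means the enforced token list now ends with the model's end-of-sequence marker, so generation stops cleanly once the enforced string has been emitted instead of continuing past it. A minimal standalone sketch of how such a list is built (the model choice is illustrative; only enforced_str, add_special_tokens=False, and enforce_token_ids come from the diff):

# Hypothetical sketch; in vLLM the serving layer does this on the request object.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any model with an eos_token_id

enforced_str = "The answer is 42."
toks = tokenizer(enforced_str, add_special_tokens=False)

# Force exactly these ids to be sampled, then EOS so the sequence finishes.
enforce_token_ids = toks.input_ids + [tokenizer.eos_token_id]
print(enforce_token_ids)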

vllm/model_executor/layers/sampler.py
4 additions, 2 deletions

@@ -417,7 +417,7 @@ def _enforced_sample(
 ) -> SampleResultType:
     results: SampleResultType = []
     for next_token_id in enforced_token_ids:
-        results.append(([next_token_id, next_token_id], [0, 0]))
+        results.append(([next_token_id], [0]))

     return results
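Each entry of a SampleResultType is a (sampled token ids, parent sequence ids) pair for one sequence group; the old code appended the enforced token and its parent id twice, producing two samples per group per step, while the fix emits exactly one. A small self-contained sketch of the corrected behaviour (the type alias and bare function are simplified stand-ins, not vLLM's full sampler):

from typing import List, Tuple

# (next token ids, parent sequence ids) for each sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

def _enforced_sample(enforced_token_ids: List[int]) -> SampleResultType:
    results: SampleResultType = []
    for next_token_id in enforced_token_ids:
        # One enforced token per group, attributed to parent sequence 0.
        results.append(([next_token_id], [0]))
    return results

print(_enforced_sample([464, 3280, 318]))
# [([464], [0]), ([3280], [0]), ([318], [0])]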

@@ -607,8 +607,10 @@ def _sample_with_torch(
     enforced_token_ids = []
     for seq_group in seq_groups:
         sampling_params = seq_group.sampling_params
+        first_seq_id = seq_group.seq_ids[0]
+        output_token_ids = seq_group.seq_data[first_seq_id].output_token_ids
         enforced_token_ids.append(
-            sampling_params.enforce_token_ids[len(seq_group.seq_data[seq_group.seq_ids[0]].output_token_ids)]
+            sampling_params.enforce_token_ids[len(output_token_ids)]
         )

     if sampled_token_ids_tensor is not None:
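The second hunk only extracts the lookup into local variables for readability; the selection logic is unchanged: each decode step indexes enforce_token_ids by the number of tokens the sequence has already produced, so the enforced ids are emitted in order, one per step. An illustrative simulation of that indexing (the dataclasses are simplified stand-ins for vLLM's sequence-group structures; enforce_token_ids lives on sampling_params in the real code):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SeqData:                      # stand-in for per-sequence data
    output_token_ids: List[int] = field(default_factory=list)

@dataclass
class SeqGroup:                     # stand-in for a sequence group to sample
    seq_ids: List[int]
    seq_data: Dict[int, SeqData]
    enforce_token_ids: List[int]

group = SeqGroup(seq_ids=[0], seq_data={0: SeqData()},
                 enforce_token_ids=[464, 3280, 318, 50256])

# Three decode steps: the current output length picks the next enforced id.
for _ in range(3):
    first_seq_id = group.seq_ids[0]
    output_token_ids = group.seq_data[first_seq_id].output_token_ids
    output_token_ids.append(group.enforce_token_ids[len(output_token_ids)])

print(group.seq_data[0].output_token_ids)   # [464, 3280, 318]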
