2 files changed: +5 −3 lines changed

File 1 diff:
@@ -248,7 +248,7 @@ async def create_chat_completion(
248 248         sampling_params = request.to_sampling_params()
249 249         if request.enforced_str:
250 250             toks = self.tokenizer(request.enforced_str, add_special_tokens=False)
251     -           sampling_params.enforce_token_ids = toks.input_ids
    251 +           sampling_params.enforce_token_ids = toks.input_ids + [self.tokenizer.eos_token_id]
252 252         lora_request = self._maybe_get_lora(request)
253 253         decoding_config = await self.engine.get_decoding_config()
254 254         guided_decoding_backend = request.guided_decoding_backend \
File 2 diff:
@@ -417,7 +417,7 @@ def _enforced_sample(
417 417 ) -> SampleResultType:
418 418     results: SampleResultType = []
419 419     for next_token_id in enforced_token_ids:
420     -       results.append(([next_token_id, next_token_id], [0, 0]))
    420 +       results.append(([next_token_id], [0]))
421 421
422 422     return results
423 423
@@ -607,8 +607,10 @@ def _sample_with_torch(
607 607         enforced_token_ids = []
608 608         for seq_group in seq_groups:
609 609             sampling_params = seq_group.sampling_params
    610 +           first_seq_id = seq_group.seq_ids[0]
    611 +           output_token_ids = seq_group.seq_data[first_seq_id].output_token_ids
610 612             enforced_token_ids.append(
611     -               sampling_params.enforce_token_ids[len(seq_group.seq_data[seq_group.seq_ids[0]].output_token_ids)]
    613 +               sampling_params.enforce_token_ids[len(output_token_ids)]
612 614             )
613 615
614 616         if sampled_token_ids_tensor is not None:
0 commit comments