@@ -684,12 +684,18 @@ def greedy_until_multi_turn( # noqa: C901
684684 ** model_inputs , stopping_criteria = stopping_criteria , ** generation_config
685685 )
686686 model_outputs = model_outputs .sequences [0 , model_inputs ["input_ids" ].size (1 ) :]
687+
688+ # We manage stop tokens in an extra step in case they were incorrectly detected earlier
689+ # (which can happen for multitoken stop sequences)
690+ decoded_generation = self .tokenizer .decode (model_outputs ) # should we skip_special_tokens=True here?
691+ for term in stop_tokens :
692+ decoded_generation = decoded_generation .split (term )[0 ]
687693 model_generations = [model_outputs ]
688694
689695 input_tokens = [model_inputs ["input_ids" ]]
690696
691697 for i , multi_turn_context in enumerate (request .context [1 :]):
692- multi_turn_context = multi_turn_context .format (model_response = model_generations [ - 1 ] )
698+ multi_turn_context = multi_turn_context .format (model_response = decoded_generation )
693699
694700 model_inputs = self .tokenizer (
695701 multi_turn_context ,
@@ -731,9 +737,9 @@ def greedy_until_multi_turn( # noqa: C901
731737 )
732738 model_outputs = model_outputs .sequences [0 , model_inputs ["input_ids" ].size (1 ) :]
733739 model_generations .append (model_outputs )
734- decoded_generation = self .tokenizer .decode (model_outputs , skip_special_tokens = True )
735740 input_tokens .append (model_inputs ["input_ids" ])
736741
742+ decoded_generation = self .tokenizer .decode (model_outputs , skip_special_tokens = True )
737743 for term in stop_tokens :
738744 decoded_generation = decoded_generation .split (term )[0 ]
739745
0 commit comments