@@ -602,6 +602,8 @@ def inference_stream(model: PreTrainedModel,
602602 print_idx = 0
603603 if not is_observation :
604604 history .append (None ) # dummy
605+ # Normalize leading spaces of the incrementally decoded text to avoid repeated words in the streamed sentence.
606+ first_num_space = - 1
605607 for token in streamer :
606608 raw_generate_ids .append (token )
607609 generate_ids = template .get_generate_ids (
@@ -612,6 +614,13 @@ def inference_stream(model: PreTrainedModel,
612614 if isinstance (template .suffix [- 1 ], list ):
613615 generate_ids = generate_ids [:- len (template .suffix [- 1 ])]
614616 response = tokenizer .decode (generate_ids , ** tokenizer_kwargs )
617+ cur_num_space = len (response ) - len (response .lstrip (' ' ))
618+ if first_num_space == - 1 :
619+ first_num_space = cur_num_space
620+ if cur_num_space < first_num_space :
621+ response = ' ' * (first_num_space - cur_num_space ) + response
622+ elif cur_num_space > first_num_space :
623+ response = response [cur_num_space - first_num_space :]
615624 if isinstance (template .suffix [- 1 ], str ):
616625 response = response [:- len (template .suffix [- 1 ])]
617626 print_idx = _get_safe_print_idx (response , print_idx )
@@ -628,6 +637,12 @@ def inference_stream(model: PreTrainedModel,
628637 generate_ids [- len (template .suffix [- 1 ]):] == template .suffix [- 1 ]):
629638 generate_ids = generate_ids [:- len (template .suffix [- 1 ])]
630639 response = tokenizer .decode (generate_ids , ** tokenizer_kwargs )
640+ if first_num_space > - 1 :
641+ cur_num_space = len (response ) - len (response .lstrip (' ' ))
642+ if cur_num_space < first_num_space :
643+ response = ' ' * (first_num_space - cur_num_space ) + response
644+ elif cur_num_space > first_num_space :
645+ response = response [cur_num_space - first_num_space :]
631646 if isinstance (
632647 template .suffix [- 1 ], str
633648 ) and response [- len (template .suffix [- 1 ]):] == template .suffix [- 1 ]:
0 commit comments