Skip to content

Commit b886406

Browse files
Jintao-Huang authored and
tastelikefeet committed
Fix stream 0415 (#702)
(cherry picked from commit 36763c0)
1 parent cdf4e51 commit b886406

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

swift/llm/utils/utils.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -602,6 +602,8 @@ def inference_stream(model: PreTrainedModel,
602602
print_idx = 0
603603
if not is_observation:
604604
history.append(None) # dummy
605+
# Avoid repeated words in the streamed sentence by keeping the number of leading spaces in the decoded response equal to the first observed count.
606+
first_num_space = -1
605607
for token in streamer:
606608
raw_generate_ids.append(token)
607609
generate_ids = template.get_generate_ids(
@@ -612,6 +614,13 @@ def inference_stream(model: PreTrainedModel,
612614
if isinstance(template.suffix[-1], list):
613615
generate_ids = generate_ids[:-len(template.suffix[-1])]
614616
response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
617+
cur_num_space = len(response) - len(response.lstrip(' '))
618+
if first_num_space == -1:
619+
first_num_space = cur_num_space
620+
if cur_num_space < first_num_space:
621+
response = ' ' * (first_num_space - cur_num_space) + response
622+
elif cur_num_space > first_num_space:
623+
response = response[cur_num_space - first_num_space:]
615624
if isinstance(template.suffix[-1], str):
616625
response = response[:-len(template.suffix[-1])]
617626
print_idx = _get_safe_print_idx(response, print_idx)
@@ -628,6 +637,12 @@ def inference_stream(model: PreTrainedModel,
628637
generate_ids[-len(template.suffix[-1]):] == template.suffix[-1]):
629638
generate_ids = generate_ids[:-len(template.suffix[-1])]
630639
response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
640+
if first_num_space > -1:
641+
cur_num_space = len(response) - len(response.lstrip(' '))
642+
if cur_num_space < first_num_space:
643+
response = ' ' * (first_num_space - cur_num_space) + response
644+
elif cur_num_space > first_num_space:
645+
response = response[cur_num_space - first_num_space:]
631646
if isinstance(
632647
template.suffix[-1], str
633648
) and response[-len(template.suffix[-1]):] == template.suffix[-1]:

0 commit comments

Comments
 (0)