Skip to content

Commit 8d4a925

Browse files
authored
fix bugs (#4031)
1 parent ad90779 commit 8d4a925

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

swift/llm/argument/base_args/template_args.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,15 @@ class TemplateArguments:
4848
def __post_init__(self):
4949
if self.template is None and hasattr(self, 'model_meta'):
5050
self.template = self.model_meta.template
51-
if self.system is not None and self.system.endswith('.txt'):
52-
assert os.path.isfile(self.system), f'self.system: {self.system}'
53-
with open(self.system, 'r') as f:
54-
self.system = f.read()
51+
if self.system is not None:
52+
if self.system.endswith('.txt'):
53+
assert os.path.isfile(self.system), f'self.system: {self.system}'
54+
with open(self.system, 'r') as f:
55+
self.system = f.read()
56+
else:
57+
self.system = self.system.replace('\\n', '\n')
58+
if self.response_prefix is not None:
59+
self.response_prefix = self.response_prefix.replace('\\n', '\n')
5560
if self.truncation_strategy is None:
5661
self.truncation_strategy = 'delete'
5762

swift/llm/infer/infer_engine/infer_engine.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ def _post_init(self):
3535
self.config = self.model_info.config
3636
if getattr(self, 'default_template', None) is None:
3737
ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None))
38+
logger.info('Create the default_template for the infer_engine')
3839
if ckpt_dir:
3940
from swift.llm import BaseArguments
4041
args = BaseArguments.from_pretrained(ckpt_dir)

swift/llm/infer/infer_engine/utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,10 @@ def _align_blank_suffix(self, response: str) -> str:
6767

6868
def _get_response(self, response: str, is_finished: bool, token_len: int) -> str:
6969
# After the symbol for a new line, we flush the cache.
70-
if response.endswith('\n') or is_finished:
70+
if self.first_token:
71+
printable_text = response
72+
self.first_token = False
73+
elif response.endswith('\n') or is_finished:
7174
printable_text = response[self.print_idx:]
7275
self.cache_idx += token_len
7376
self.first_num_space = -1
@@ -85,9 +88,10 @@ def _get_response(self, response: str, is_finished: bool, token_len: int) -> str
8588

8689
def get_printable_text(self, raw_tokens: List[int], is_finished: bool) -> str:
8790
raw_tokens = raw_tokens[self.cache_idx:]
91+
if self.first_token:
92+
raw_tokens = []
8893
response = self.template.decode(
8994
raw_tokens, is_finished=is_finished, tokenizer_kwargs=self.decode_kwargs, first_token=self.first_token)
90-
self.first_token = False
9195
response = self._align_blank_suffix(response)
9296
return self._get_response(response, is_finished, len(raw_tokens))
9397

0 commit comments

Comments
 (0)