
Commit 178033d

fix chatglm3 template bug (#298)
1 parent fa5b3b1 · commit 178033d

7 files changed: +23 −14 lines changed


swift/llm/sft.py

Lines changed: 4 additions & 1 deletion
@@ -181,6 +181,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         greater_is_better=args.predict_with_generate,
         sortish_sampler=True,
         optim=args.optim,
+        adam_beta1=args.adam_beta1,
+        adam_beta2=args.adam_beta2,
         hub_model_id=args.hub_model_id,
         hub_private_repo=args.hub_private_repo,
         push_hub_strategy=args.push_hub_strategy,
@@ -200,7 +202,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         disable_tqdm=args.disable_tqdm,
         save_on_each_node=args.save_on_each_node,
         acc_strategy=args.acc_strategy,
-        save_safetensors=args.save_safetensors)
+        save_safetensors=args.save_safetensors,
+        logging_first_step=True)
 
     if args.gradient_checkpointing:
         model.config.use_cache = False  # fix transformers==4.36
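
Note: adam_beta1, adam_beta2, and logging_first_step are standard transformers.TrainingArguments fields, so the new SftArguments values are forwarded to the trainer configuration as plain keyword arguments. A minimal standalone sketch of the same pass-through (not the project's code, which uses its own Seq2SeqTrainingArguments subclass):

# Minimal sketch, independent of swift: the forwarded keywords are ordinary
# transformers.TrainingArguments fields.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='output',
    optim='adamw_torch',
    adam_beta1=0.9,
    adam_beta2=0.95,          # a lower beta2, as exercised in tests/llm/test_run.py
    logging_first_step=True,  # also log metrics at the very first step
)
print(training_args.adam_beta1, training_args.adam_beta2)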

swift/llm/utils/argument.py

Lines changed: 2 additions & 0 deletions
@@ -104,6 +104,8 @@ class SftArguments:
     # if max_steps >= 0, override num_train_epochs
     max_steps: int = -1
     optim: str = 'adamw_torch'
+    adam_beta1: float = 0.9
+    adam_beta2: float = 0.999
     learning_rate: Optional[float] = None
     weight_decay: float = 0.01
     gradient_accumulation_steps: Optional[int] = None
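
The two new dataclass fields follow the usual argument-dataclass pattern, so they can be overridden the same way as the existing optim field. A minimal sketch with a hypothetical trimmed-down stand-in class parsed via HfArgumentParser (not swift's own argument parsing):

# Hypothetical stand-in for SftArguments, only to show how new float fields
# surface as command-line flags when a dataclass is parsed this way.
from dataclasses import dataclass
from transformers import HfArgumentParser


@dataclass
class MiniSftArguments:
    optim: str = 'adamw_torch'
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999


args, = HfArgumentParser(MiniSftArguments).parse_args_into_dataclasses(
    ['--adam_beta2', '0.95'])
print(args.adam_beta2)  # 0.95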

swift/llm/utils/dataset.py

Lines changed: 9 additions & 6 deletions
@@ -409,7 +409,7 @@ def _repair_agent_conversations(conversations: str,
 
 advertise_gen_prompt = """Task: Generating advertisements based on keywords.
 Keywords: {query}
-Advertisements: """
+Advertisements:"""
 register_dataset(
     DatasetName.advertise_gen_zh,
     'lvjianjin/AdvertiseGen', ['train'], ['validation'],
@@ -513,7 +513,7 @@ def _preprocess_dureader_robust(dataset: HfDataset) -> HfDataset:
     prompt = """Task: Question Generation
 Context: {context}
 Answer: {answer}
-Question: """
+Question:"""
     query = []
     response = []
     for d in dataset:
@@ -850,7 +850,7 @@ def _preprocess_hc3(dataset: HfDataset) -> HfDataset:
 Question: {question}
 Answer: {answer}
 Category: Human, ChatGPT
-Output: """
+Output:"""
     query = []
     response = []
     for d in dataset:
@@ -978,6 +978,9 @@ def add_self_cognition_dataset(
     return concatenate_datasets([train_dataset, dataset])
 
 
+NoneType = type(None)
+
+
 def _check_dataset(
     dataset: Optional[None],
     check_dataset_strategy: Literal['none', 'discard', 'error', 'warning']
@@ -1003,7 +1006,7 @@ def _check_dataset(
                 continue
             else:
                 raise ValueError(f"d['response']: {d['response']}, i: {i}")
-        if has_query and not isinstance(d['response'], str):
+        if has_query and not isinstance(d['query'], (str, NoneType)):
             is_modified = True
             if check_dataset_strategy == 'discard':
                 continue
@@ -1012,7 +1015,7 @@ def _check_dataset(
                 continue
             else:
                 raise ValueError(f"d['query']: {d['query']}, i: {i}")
-        if has_history and not isinstance(d['history'], (list, type(None))):
+        if has_history and not isinstance(d['history'], (list, NoneType)):
             is_modified = True
             if check_dataset_strategy == 'discard':
                 continue
@@ -1021,7 +1024,7 @@ def _check_dataset(
                 continue
            else:
                 raise ValueError(f"d['history']: {d['history']}, i: {i}")
-        if has_system and not isinstance(d['system'], str):
+        if has_system and not isinstance(d['system'], (str, NoneType)):
             is_modified = True
             if check_dataset_strategy == 'discard':
                 continue
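
The _check_dataset change relaxes the per-row validation: query and system may now be None or a string, history may be None or a list, and the previous check also inspected d['response'] in the query branch, which is corrected here. A standalone sketch of the relaxed check (not the project's code):

# Standalone sketch mirroring the (str, NoneType) / (list, NoneType)
# isinstance pattern used in _check_dataset above.
NoneType = type(None)


def row_ok(d: dict) -> bool:
    return (isinstance(d.get('response'), str)
            and isinstance(d.get('query'), (str, NoneType))
            and isinstance(d.get('history'), (list, NoneType))
            and isinstance(d.get('system'), (str, NoneType)))


print(row_ok({'response': 'hi', 'query': None, 'history': None, 'system': None}))  # True
print(row_ok({'response': 'hi', 'query': 123, 'history': None, 'system': None}))   # False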

swift/llm/utils/preprocess.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def __init__(self, labels: List[str], task_name: str,
         self.prompt = f"""Task: {task_name}
 {inputs}
 Category: {category}
-Output: """
+Output:"""
         self.task_name = task_name
         self.is_pair_seq = is_pair_seq
 

swift/llm/utils/template.py

Lines changed: 3 additions & 3 deletions
@@ -618,13 +618,13 @@ def register_template(template_type: str,
 
 register_template(
     TemplateType.chatglm3,
-    Template([[64790, 64792]], [[64795], '\n {{QUERY}}', [64796], '\n '], [],
+    Template([[64790, 64792]], [[64795], '\n {{QUERY}}', [64796], '\n'], [],
              [['eos_token_id']], None,
              [[64790, 64792, 64794], '\n {{SYSTEM}}']))
 
 register_template(
     TemplateType.deepseek,
-    Template([['bos_token_id']], ['User: {{QUERY}}\n\nAssistant: '],
+    Template([['bos_token_id']], ['User: {{QUERY}}\n\nAssistant:'],
              [['eos_token_id']], [['eos_token_id']], None,
              [['bos_token_id'], '{{SYSTEM}}\n\n']))
 
@@ -660,7 +660,7 @@ def register_template(template_type: str,
 )
 register_template(
     TemplateType.openbuddy,
-    Template([['bos_token_id']], ['User: {{QUERY}}\nAssistant: '], ['\n'],
+    Template([['bos_token_id']], ['User: {{QUERY}}\nAssistant:'], ['\n'],
              [['eos_token_id']], OPENBUDDY_DEFAULT_SYSTEM,
              [['bos_token_id'], '{{SYSTEM}}\n\n']))
 
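
This is the fix named in the commit title: the trailing space in the chatglm3 suffix ('\n ' instead of '\n'), and in the deepseek/openbuddy 'Assistant: ' prompts, was encoded as an extra token, so swift's prompt ended one token longer than the official chat prompt. With the space removed the encodings match exactly, which is why the [:-1] workarounds in the tests below can be dropped. A small illustration built from the expected ids in tests/llm/test_template.py:

# Token-level view of the chatglm3 fix, using the expected ids from
# tests/llm/test_template.py: the old suffix '\n ' carried a trailing 30910.
old_tail = [30910, 13, 30910]  # tokens after [64796] for '\n ' (trailing space)
new_tail = [30910, 13]         # tokens after [64796] for '\n'
assert old_tail[:len(new_tail)] == new_tail
assert old_tail != new_tail    # the extra 30910 broke parity with the official prompt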

tests/llm/test_run.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ def test_basic(self):
             quantization_bit=quantization_bit,
             batch_size=2,
             eval_steps=5,
+            adam_beta2=0.95,
             check_dataset_strategy='warning',
             train_dataset_sample=200,
             predict_with_generate=predict_with_generate,

tests/llm/test_template.py

Lines changed: 3 additions & 3 deletions
@@ -85,7 +85,7 @@ def test_chatglm3_template(self):
             64790, 64792, 64794, 30910, 13, 344, 383, 260, 6483, 9319, 30992,
             64795, 30910, 13, 30910, 30939, 30943, 30966, 30972, 30970, 31011,
             30943, 30966, 30972, 30980, 31514, 64796
-        ] + [30910, 13, 30910]
+        ] + [30910, 13]
         input_ids_swift = template.encode({
             'query': query,
             'system': system
@@ -439,7 +439,7 @@ def test_openbuddy_template(self):
         #
         input_ids_official = inputs[0].tolist()
         input_ids_swift = template.encode({'query': query})['input_ids']
-        self.assertTrue(input_ids_swift[:-1] == input_ids_official)
+        self.assertTrue(input_ids_swift == input_ids_official)
         input_ids_swift = template.encode({
             'query': query,
             'history': [['1234', 'avdc']]
@@ -577,7 +577,7 @@ def test_deepseek_template(self):
         response = tokenizer.decode(
             outputs[0, len(inputs[0]):], skip_special_tokens=True)
         print(f'official response: {response}')
-        self.assertTrue(input_ids_swift[:-1] == input_ids_official)
+        self.assertTrue(input_ids_swift == input_ids_official)
 
     @unittest.skipIf(
         SKPT_TEST,
