Skip to content

Commit 339e0ff

Browse files
authored
Fix infer default dtype (#834)
1 parent 08a9552 commit 339e0ff

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

swift/llm/utils/argument.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,8 @@ def load_from_ckpt_dir(self) -> None:
11031103

11041104
if self.model_id_or_path is None:
11051105
self.model_id_or_path = sft_args.get('model_id_or_path')
1106+
if self.dtype == 'AUTO':
1107+
self.dtype = sft_args.get('dtype')
11061108

11071109
@staticmethod
11081110
def check_ckpt_dir_correct(ckpt_dir) -> bool:

tests/llm/test_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def test_basic(self):
6565
train_dataset_sample=200,
6666
predict_with_generate=predict_with_generate,
6767
dataset=[DatasetName.jd_sentiment_zh],
68+
include_num_input_tokens_seen=True,
6869
output_dir=output_dir,
6970
gradient_checkpointing=True)
7071
self.assertTrue(sft_args.gradient_accumulation_steps == 8)

0 commit comments

Comments
 (0)