Skip to content

Commit 8f22319

Browse files
authored
fix megatron model args (#5677)
1 parent aa29050 commit 8f22319

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

swift/llm/argument/base_args/base_args.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -266,9 +266,6 @@ def load_args_from_ckpt(self) -> None:
266266
'use_chat_template',
267267
'response_prefix',
268268
]
269-
if 'megatron' in self.__class__.__name__.lower():
270-
force_load_keys = []
271-
load_keys.remove('use_chat_template')
272269
data_keys = list(f.name for f in fields(DataArguments))
273270
for key, old_value in old_args.items():
274271
if old_value is None:

swift/megatron/argument/train_args.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
import os
44
from dataclasses import dataclass
55

6+
import json
7+
68
from swift.llm import BaseArguments
79
from swift.llm.argument.base_args import to_abspath
810
from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
@@ -15,6 +17,7 @@
1517
@dataclass
1618
class MegatronTrainArguments(MegatronArguments, BaseArguments):
1719
add_version: bool = True
20+
load_args: bool = False
1821

1922
def init_model_args(self, tokenizer, config):
2023
kwargs = self.megatron_model_meta.convert_hf_config(config)
@@ -42,6 +45,16 @@ def _init_save(self):
4245
if is_master():
4346
os.makedirs(self.save, exist_ok=True)
4447

48+
def _init_ckpt_dir(self, adapters=None):
49+
super()._init_ckpt_dir(adapters)
50+
if self.ckpt_dir and self.model is None:
51+
args_path = os.path.join(self.ckpt_dir, 'args.json')
52+
if not os.path.exists(args_path):
53+
return
54+
with open(args_path, 'r', encoding='utf-8') as f:
55+
old_args = json.load(f)
56+
self.model = old_args.get('model')
57+
4558
def __post_init__(self):
4659
self.sequence_parallel_size = self.context_parallel_size
4760
if self.packing:

0 commit comments

Comments
 (0)