Commit ed1ea44

[bugfix] fix megatron load/finetune (#5481)
1 parent fa3d2d6 commit ed1ea44

2 files changed: 15 additions & 8 deletions

swift/megatron/trainers/base.py

Lines changed: 12 additions & 8 deletions
```diff
@@ -139,6 +139,12 @@ def _patch_load_state_dict(self):
         from megatron.training import checkpointing
         origin__load_base_checkpoint = checkpointing._load_base_checkpoint
 
+        args = get_args()
+        origin_load_state_dict = torch.nn.Module.load_state_dict
+        origin_no_load_optim = args.no_load_optim
+        origin_no_load_rng = args.no_load_rng
+        origin_finetune = args.finetune
+
         def _load_base_checkpoint(*_args, **kwargs):
             sharded_state_dict = kwargs.get('sharded_state_dict')
             if sharded_state_dict is None:
@@ -176,20 +182,17 @@ def _load_base_checkpoint(*_args, **kwargs):
                 state_dict[origin_k] = v
             return res
 
-        origin_load_state_dict = torch.nn.Module.load_state_dict
-
         def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             strict = False
             return origin_load_state_dict(self, state_dict, strict, *args, **kwargs)
 
         checkpointing._load_base_checkpoint = _load_base_checkpoint
-        torch.nn.Module.load_state_dict = load_state_dict
 
-        args = get_args()
-        origin_no_load_optim = args.no_load_optim
-        origin_no_load_rng = args.no_load_rng
-        args.no_load_optim = True
-        args.no_load_rng = True
+        if args.train_type != 'full':
+            torch.nn.Module.load_state_dict = load_state_dict
+            args.no_load_optim = True
+            args.no_load_rng = True
+            args.finetune = True
 
         try:
             yield
@@ -198,6 +201,7 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             torch.nn.Module.load_state_dict = origin_load_state_dict
             args.no_load_optim = origin_no_load_optim
             args.no_load_rng = origin_no_load_rng
+            args.finetune = origin_finetune
 
     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
```
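
The base.py change follows a save/patch/restore shape: capture every global before touching it, apply the patches only when needed, and undo them in `finally`. Below is a minimal, self-contained sketch of that pattern; the `args` namespace and the `train_type` value are illustrative stand-ins for Megatron's `get_args()`, not the repository's exact code.

```python
from contextlib import contextmanager
from types import SimpleNamespace

import torch

# Hypothetical stand-in for the namespace returned by Megatron's get_args().
args = SimpleNamespace(train_type='lora', no_load_optim=False,
                       no_load_rng=False, finetune=False)


@contextmanager
def patch_load_state_dict():
    # Save every global we are about to touch, once, up front.
    origin_load_state_dict = torch.nn.Module.load_state_dict
    origin_no_load_optim = args.no_load_optim
    origin_no_load_rng = args.no_load_rng
    origin_finetune = args.finetune

    def load_state_dict(self, state_dict, strict=True, *a, **kw):
        # Force non-strict loading so a partial checkpoint (e.g. adapter
        # weights only) does not raise on missing/unexpected keys.
        return origin_load_state_dict(self, state_dict, False, *a, **kw)

    # Mirror the fix: only patch when not doing full-parameter training.
    if args.train_type != 'full':
        torch.nn.Module.load_state_dict = load_state_dict
        args.no_load_optim = True
        args.no_load_rng = True
        args.finetune = True
    try:
        yield
    finally:
        # Restore even if loading raised, so the patch cannot leak.
        torch.nn.Module.load_state_dict = origin_load_state_dict
        args.no_load_optim = origin_no_load_optim
        args.no_load_rng = origin_no_load_rng
        args.finetune = origin_finetune


# Usage: inside the context, load_state_dict tolerates missing keys.
with patch_load_state_dict():
    torch.nn.Linear(2, 2).load_state_dict({})  # would raise outside
```

Two things in the commit itself are worth noting: the originals (including `args.finetune`) are now saved before any mutation, and both the non-strict `load_state_dict` patch and the flag overrides apply only when `args.train_type != 'full'`, so full-parameter runs keep strict loading and can still resume optimizer and RNG state.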

swift/megatron/utils/convert.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -143,6 +143,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float
         'no_save_rng': True,
         'no_load_optim': True,
         'no_load_rng': True,
+        'finetune': True,
         'attention_backend': 'unfused',
     }
 
@@ -219,6 +220,8 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args)
 
     mg_model = megatron_model_meta.model_provider()
+    if megatron_args.load is None:
+        raise ValueError('Please specify `--mcore_model`.')
     load_checkpoint([mg_model], None, None, strict=True)
     if megatron_args.adapter_load is not None:
         peft_model = prepare_mcore_model(mg_model)
```
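
The second hunk in `convert_mcore2hf` is a fail-fast guard: check that a checkpoint path was actually configured before calling `load_checkpoint`, so the user gets an actionable message naming the CLI flag instead of an obscure failure deeper inside Megatron's checkpointing code. A minimal sketch of the same guard shape; `check_load_path` is a hypothetical helper, not part of the repository.

```python
from types import SimpleNamespace


def check_load_path(megatron_args) -> None:
    # Hypothetical helper mirroring the guard added in convert_mcore2hf:
    # fail early and name the flag the user needs to set.
    if megatron_args.load is None:
        raise ValueError('Please specify `--mcore_model`.')


check_load_path(SimpleNamespace(load='/ckpt/mcore'))  # passes silently
try:
    check_load_path(SimpleNamespace(load=None))
except ValueError as err:
    print(err)  # Please specify `--mcore_model`.
```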
