
Commit c6fb911

[bugfix] fix megatron load/finetune (#5481)
1 parent 5cc33a0 commit c6fb911

File tree

2 files changed (+15, -8 lines)


swift/megatron/trainers/base.py

Lines changed: 12 additions & 8 deletions
@@ -140,6 +140,12 @@ def _patch_load_state_dict(self):
         from megatron.training import checkpointing
         origin__load_base_checkpoint = checkpointing._load_base_checkpoint
 
+        args = get_args()
+        origin_load_state_dict = torch.nn.Module.load_state_dict
+        origin_no_load_optim = args.no_load_optim
+        origin_no_load_rng = args.no_load_rng
+        origin_finetune = args.finetune
+
         def _load_base_checkpoint(*_args, **kwargs):
             sharded_state_dict = kwargs.get('sharded_state_dict')
             if sharded_state_dict is None:
@@ -174,20 +180,17 @@ def _load_base_checkpoint(*_args, **kwargs):
                 state_dict[origin_k] = v
             return res
 
-        origin_load_state_dict = torch.nn.Module.load_state_dict
-
         def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             strict = False
             return origin_load_state_dict(self, state_dict, strict, *args, **kwargs)
 
         checkpointing._load_base_checkpoint = _load_base_checkpoint
-        torch.nn.Module.load_state_dict = load_state_dict
 
-        args = get_args()
-        origin_no_load_optim = args.no_load_optim
-        origin_no_load_rng = args.no_load_rng
-        args.no_load_optim = True
-        args.no_load_rng = True
+        if args.train_type != 'full':
+            torch.nn.Module.load_state_dict = load_state_dict
+            args.no_load_optim = True
+            args.no_load_rng = True
+            args.finetune = True
 
         try:
             yield
@@ -196,6 +199,7 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             torch.nn.Module.load_state_dict = origin_load_state_dict
             args.no_load_optim = origin_no_load_optim
             args.no_load_rng = origin_no_load_rng
+            args.finetune = origin_finetune
 
     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
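The hunks above follow a patch-and-restore pattern: capture the originals (including torch.nn.Module.load_state_dict and args.finetune) before patching anything, apply the monkey-patches only when args.train_type != 'full', and restore every attribute after the yield. Below is a minimal, runnable sketch of that pattern; the Args dataclass is a hypothetical stand-in for the namespace Megatron's get_args() returns, not the repo's actual code.

from contextlib import contextmanager
from dataclasses import dataclass

import torch


@dataclass
class Args:
    # Hypothetical stand-in for Megatron's global args object.
    train_type: str = 'lora'
    no_load_optim: bool = False
    no_load_rng: bool = False
    finetune: bool = False


@contextmanager
def patch_load_state_dict(args: Args):
    # Capture originals up front so the finally-block can always restore
    # them, even when the conditional below decides not to patch.
    origin_load_state_dict = torch.nn.Module.load_state_dict
    origin_no_load_optim = args.no_load_optim
    origin_no_load_rng = args.no_load_rng
    origin_finetune = args.finetune

    def load_state_dict(self, state_dict, strict: bool = True, *a, **kw):
        # Force non-strict loading so partial checkpoints (e.g. adapter-only
        # weights) do not raise on missing keys.
        return origin_load_state_dict(self, state_dict, False, *a, **kw)

    if args.train_type != 'full':
        # Only non-full (adapter-style) training gets the relaxed behavior.
        torch.nn.Module.load_state_dict = load_state_dict
        args.no_load_optim = True
        args.no_load_rng = True
        args.finetune = True
    try:
        yield
    finally:
        # Undo each patch, including the finetune flag this fix now restores.
        torch.nn.Module.load_state_dict = origin_load_state_dict
        args.no_load_optim = origin_no_load_optim
        args.no_load_rng = origin_no_load_rng
        args.finetune = origin_finetune


# Usage: loading inside the block tolerates partial state dicts.
with patch_load_state_dict(Args(train_type='lora')):
    torch.nn.Linear(2, 2).load_state_dict({})  # strict=True would raise outside the block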

swift/megatron/utils/convert.py

Lines changed: 3 additions & 0 deletions
@@ -143,6 +143,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float
         'no_save_rng': True,
         'no_load_optim': True,
         'no_load_rng': True,
+        'finetune': True,
         'attention_backend': 'unfused',
     }
 
@@ -217,6 +218,8 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args)
 
     mg_model = megatron_model_meta.model_provider()
+    if megatron_args.load is None:
+        raise ValueError('Please specify `--mcore_model`.')
     load_checkpoint([mg_model], None, None, strict=True)
     if megatron_args.adapter_load is not None:
         peft_model = prepare_mcore_model(mg_model)
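The second hunk adds a fail-fast guard: convert_mcore2hf now raises a clear error before Megatron's load_checkpoint would fail opaquely on a None path. A tiny sketch of the same check; check_load_path is a hypothetical helper name and SimpleNamespace stands in for the parsed Megatron arguments.

from types import SimpleNamespace


def check_load_path(megatron_args) -> None:
    # Refuse to proceed before load_checkpoint trips over a None path.
    if megatron_args.load is None:
        raise ValueError('Please specify `--mcore_model`.')


check_load_path(SimpleNamespace(load='/path/to/mcore-checkpoint'))  # passes
try:
    check_load_path(SimpleNamespace(load=None))
except ValueError as err:
    print(err)  # Please specify `--mcore_model`.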
