Commit 3d2bded

[megatron] Fix ref_adapter_load (#5480)
1 parent ed1ea44 commit 3d2bded

File tree: 3 files changed, +76 -52 lines

swift/megatron/trainers/base.py
swift/megatron/trainers/dpo_trainer.py
swift/megatron/utils/__init__.py

swift/megatron/trainers/base.py

Lines changed: 73 additions & 44 deletions
@@ -134,70 +134,95 @@ def sh_ten_merge_fn(sub_state_dict):
             if isinstance(v, ShardedTensorFactory) and 'apply_swiglu_sharded_factory' in v.merge_fn.__qualname__:
                 v.merge_fn = sh_ten_merge_fn
 
+    def _load_adapter_base_checkpoint(self, *_args, **kwargs):
+        adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter'
+        from megatron.training import checkpointing
+        sharded_state_dict = kwargs.get('sharded_state_dict')
+        if sharded_state_dict is None:
+            return checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict_model = {}
+        mapping = {}
+        for k, v in sharded_state_dict['model'].items():
+            if adapter_name not in k:
+                continue
+            # lora
+            origin_k = k
+            k = k.replace(f'.{adapter_name}.', '.default.')
+            mapping[k] = origin_k
+            v.key = v.key.replace(f'.{adapter_name}.', '.default.')
+            state_dict_model[k] = v
+        sharded_state_dict['model'] = state_dict_model
+        self._patch_merge_fn(state_dict_model)
+        res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict = res[0]['model']
+        for k, origin_k in mapping.items():
+            v = state_dict.pop(k)
+            state_dict[origin_k] = v
+        return res
+
+    def _load_base_checkpoint(self, *_args, **kwargs):
+        from megatron.training import checkpointing
+        sharded_state_dict = kwargs.get('sharded_state_dict')
+        if sharded_state_dict is None:
+            return checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        if self.args.train_type == 'full':
+            self._patch_merge_fn(sharded_state_dict['model'])
+            return checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict_model = {}
+        mapping = {}
+        for k, v in sharded_state_dict['model'].items():
+            if 'lora_A' in k or 'lora_B' in k or 'original_module' in k:
+                continue
+            # lora
+            if '.base_layer' in k:
+                origin_k = k
+                k = k.replace('.base_layer', '')
+                mapping[k] = origin_k
+                v.key = v.key.replace('.base_layer', '')
+            elif '.modules_to_save' in k:
+                if '.modules_to_save.default' not in k:
+                    # e.g. ref_adapter
+                    continue
+                # modules to save
+                origin_k = k
+                k = k.replace('.modules_to_save.default', '')
+                mapping[k] = origin_k
+                v.key = v.key.replace('.modules_to_save.default', '')
+            state_dict_model[k] = v
+        sharded_state_dict['model'] = state_dict_model
+        self._patch_merge_fn(state_dict_model)
+        res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict = res[0]['model']
+        for k, origin_k in mapping.items():
+            v = state_dict.pop(k)
+            state_dict[origin_k] = v
+        return res
+
     @contextmanager
-    def _patch_load_state_dict(self):
+    def _patch_load_state_dict(self, load_base_checkpoint):
         from megatron.training import checkpointing
-        origin__load_base_checkpoint = checkpointing._load_base_checkpoint
+        checkpointing.origin__load_base_checkpoint = checkpointing._load_base_checkpoint
+        checkpointing._load_base_checkpoint = load_base_checkpoint
 
         args = get_args()
         origin_load_state_dict = torch.nn.Module.load_state_dict
         origin_no_load_optim = args.no_load_optim
         origin_no_load_rng = args.no_load_rng
         origin_finetune = args.finetune
 
-        def _load_base_checkpoint(*_args, **kwargs):
-            sharded_state_dict = kwargs.get('sharded_state_dict')
-            if sharded_state_dict is None:
-                return origin__load_base_checkpoint(*_args, **kwargs)
-            if self.args.train_type == 'full':
-                self._patch_merge_fn(sharded_state_dict['model'])
-                return origin__load_base_checkpoint(*_args, **kwargs)
-            state_dict_model = {}
-            mapping = {}
-            for k, v in sharded_state_dict['model'].items():
-                if 'lora_A' in k or 'lora_B' in k or 'original_module' in k:
-                    continue
-                # lora
-                if '.base_layer' in k:
-                    origin_k = k
-                    k = k.replace('.base_layer', '')
-                    mapping[k] = origin_k
-                    v.key = v.key.replace('.base_layer', '')
-                elif '.modules_to_save' in k:
-                    if '.modules_to_save.default' not in k:
-                        # e.g. ref_adapter
-                        continue
-                    # modules to save
-                    origin_k = k
-                    k = k.replace('.modules_to_save.default', '')
-                    mapping[k] = origin_k
-                    v.key = v.key.replace('.modules_to_save.default', '')
-                state_dict_model[k] = v
-            sharded_state_dict['model'] = state_dict_model
-            self._patch_merge_fn(state_dict_model)
-            res = origin__load_base_checkpoint(*_args, **kwargs)
-            state_dict = res[0]['model']
-            for k, origin_k in mapping.items():
-                v = state_dict.pop(k)
-                state_dict[origin_k] = v
-            return res
-
         def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             strict = False
             return origin_load_state_dict(self, state_dict, strict, *args, **kwargs)
 
-        checkpointing._load_base_checkpoint = _load_base_checkpoint
-
         if args.train_type != 'full':
             torch.nn.Module.load_state_dict = load_state_dict
             args.no_load_optim = True
             args.no_load_rng = True
             args.finetune = True
-
         try:
             yield
         finally:
-            checkpointing._load_base_checkpoint = origin__load_base_checkpoint
+            checkpointing._load_base_checkpoint = checkpointing.origin__load_base_checkpoint
             torch.nn.Module.load_state_dict = origin_load_state_dict
             args.no_load_optim = origin_no_load_optim
            args.no_load_rng = origin_no_load_rng
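
Note: the hunk above also reworks _patch_load_state_dict to take the loader it should patch in, stashing the original as checkpointing.origin__load_base_checkpoint so both new loaders can delegate to it, and restoring everything on exit. A minimal sketch of that swap-and-restore pattern, using a stand-in namespace instead of megatron.training.checkpointing (illustrative only, not the trainer code):

```python
# Sketch of the patch-and-restore pattern behind _patch_load_state_dict.
# fake_checkpointing stands in for megatron.training.checkpointing; the real
# code also toggles torch.nn.Module.load_state_dict and a few Megatron args.
import types
from contextlib import contextmanager

fake_checkpointing = types.SimpleNamespace(_load_base_checkpoint=lambda: 'original loader')


@contextmanager
def patch_loader(module, loader):
    # stash the original so the replacement loader can still delegate to it
    module.origin__load_base_checkpoint = module._load_base_checkpoint
    module._load_base_checkpoint = loader
    try:
        yield
    finally:
        module._load_base_checkpoint = module.origin__load_base_checkpoint


with patch_loader(fake_checkpointing, lambda: 'adapter-aware loader'):
    assert fake_checkpointing._load_base_checkpoint() == 'adapter-aware loader'
assert fake_checkpointing._load_base_checkpoint() == 'original loader'
```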
@@ -210,14 +235,18 @@ def new_model_provider_func(*args, **kwargs):
             self.peft_model = prepare_mcore_model(self.unwrapped_model)
             return self.unwrapped_model
 
-        with self._patch_load_state_dict():
+        with self._patch_load_state_dict(self._load_base_checkpoint):
             model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer(
                 new_model_provider_func, model_type, *_args, **kwargs)
         args = get_args()
         if args.initialize_embedding:
             self._initialize_embedding(self.unwrapped_model)
         if args.train_type != 'full' and args.modules_to_save:
             copy_original_module_weight(self.unwrapped_model)
+        if args.ref_adapter_load is not None:
+            with self._patch_load_state_dict(self._load_adapter_base_checkpoint):
+                args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
+                    model, optimizer, opt_param_scheduler, load_arg='ref_adapter_load', strict=False)
         if args.adapter_load is not None:
             with adapter_state_dict_context():
                 args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
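
When --ref_adapter_load is set, load_checkpoint runs a second time against that path with _load_adapter_base_checkpoint patched in; that loader renames the requested ref_adapter keys to PEFT's default adapter name so they match what the adapter checkpoint stores, performs the load, and then maps the loaded tensors back onto the original ref_adapter keys. A minimal sketch of that key round trip, with plain dicts and a stub loader standing in for Megatron sharded state dicts (names are hypothetical):

```python
# Sketch of the rename-load-rename-back round trip in _load_adapter_base_checkpoint.
# Plain dicts and a stub load_fn stand in for sharded_state_dict / _load_base_checkpoint.

def load_into_adapter(requested_keys, load_fn, adapter_name='ref_adapter'):
    mapping, to_load = {}, []
    for origin_k in requested_keys:
        if adapter_name not in origin_k:
            continue  # only the ref-adapter entries are requested from this checkpoint
        k = origin_k.replace(f'.{adapter_name}.', '.default.')  # match stored key names
        mapping[k] = origin_k
        to_load.append(k)
    loaded = load_fn(to_load)                          # load under the '.default.' names
    return {mapping[k]: v for k, v in loaded.items()}  # hand back under ref_adapter names


# toy usage: the adapter checkpoint stores LoRA weights under the 'default' adapter
fake_ckpt = {'linear.lora_A.default.weight': 'tensor-A'}
requested = ['linear.lora_A.ref_adapter.weight', 'linear.base_layer.weight']
out = load_into_adapter(requested, lambda keys: {k: fake_ckpt[k] for k in keys})
assert out == {'linear.lora_A.ref_adapter.weight': 'tensor-A'}
```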

swift/megatron/trainers/dpo_trainer.py

Lines changed: 1 addition & 6 deletions
@@ -13,7 +13,6 @@
 
 from swift.trainers import DPOTrainer
 from swift.utils import get_current_device, get_logger
-from ..utils import copy_ref_adapter_weight
 from .trainer import MegatronTrainer
 from .utils import get_batch
 
@@ -52,11 +51,7 @@ def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **k
             self.ref_model.eval()
         else:
             self.ref_model = None
-        model, optimizer, opt_param_scheduler = super().setup_model_and_optimizer(model_provider_func, model_type,
-                                                                                   *_args, **kwargs)
-        if args.ref_adapter_load is not None:
-            copy_ref_adapter_weight(self.unwrapped_model, 'ref_adapter')
-        return model, optimizer, opt_param_scheduler
+        return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs)
 
     @staticmethod
     def _forward_step_helper(model, inputs):
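
With the base trainer now loading ref_adapter weights directly during setup (see base.py above), the DPO trainer no longer needs a post-load copy_ref_adapter_weight step and simply delegates to the parent. An illustrative sketch of that shape with toy classes (not the real trainers):

```python
# Toy illustration of the simplified override: the subclass only prepares its
# reference model and defers all checkpoint/adapter loading to the base trainer.
class ToyBaseTrainer:
    def setup_model_and_optimizer(self, model_provider_func, model_type, *args, **kwargs):
        # in the real base class this is where base weights and ref_adapter_load are handled
        return 'model', 'optimizer', 'opt_param_scheduler'


class ToyDPOTrainer(ToyBaseTrainer):
    def setup_model_and_optimizer(self, model_provider_func, model_type, *args, **kwargs):
        self.ref_model = None  # a separate ref model is only built when configured
        return super().setup_model_and_optimizer(model_provider_func, model_type, *args, **kwargs)


assert ToyDPOTrainer().setup_model_and_optimizer(None, None) == ('model', 'optimizer', 'opt_param_scheduler')
```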

swift/megatron/utils/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2,5 +2,5 @@
 
 from .convert import convert_hf2mcore, convert_mcore2hf
 from .patcher import patch_megatron_tokenizer
-from .utils import (adapter_state_dict_context, copy_original_module_weight, copy_ref_adapter_weight,
-                    prepare_mcore_model, tuners_sharded_state_dict)
+from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model,
+                    tuners_sharded_state_dict)
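
Only copy_ref_adapter_weight leaves the public re-export list; imports of the remaining helpers are unchanged. A hypothetical downstream usage after this change:

```python
# Hypothetical downstream import after this commit (copy_ref_adapter_weight is
# no longer re-exported from swift.megatron.utils).
from swift.megatron.utils import (adapter_state_dict_context, copy_original_module_weight,
                                  prepare_mcore_model, tuners_sharded_state_dict)
```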
