@@ -134,70 +134,95 @@ def sh_ten_merge_fn(sub_state_dict):
             if isinstance(v, ShardedTensorFactory) and 'apply_swiglu_sharded_factory' in v.merge_fn.__qualname__:
                 v.merge_fn = sh_ten_merge_fn
 
+    def _load_adapter_base_checkpoint(self, *_args, **kwargs):
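+        # Load only the weights belonging to `adapter_name` (default: 'ref_adapter'):
+        # keep the matching sharded entries, rename them to the 'default' adapter
+        # (the adapter checkpoint is presumably saved under that name), load them via
+        # the original loader, then map the loaded tensors back to the adapter-scoped keys.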
+        adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter'
+        from megatron.training import checkpointing
+        sharded_state_dict = kwargs.get('sharded_state_dict')
+        if sharded_state_dict is None:
+            return checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict_model = {}
+        mapping = {}
+        for k, v in sharded_state_dict['model'].items():
+            if adapter_name not in k:
+                continue
+            # lora
+            origin_k = k
+            k = k.replace(f'.{adapter_name}.', '.default.')
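+            # e.g. '<module>.lora_A.ref_adapter.weight' -> '<module>.lora_A.default.weight'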
+            mapping[k] = origin_k
+            v.key = v.key.replace(f'.{adapter_name}.', '.default.')
+            state_dict_model[k] = v
+        sharded_state_dict['model'] = state_dict_model
+        self._patch_merge_fn(state_dict_model)
+        res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict = res[0]['model']
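+        # Restore the original adapter-scoped keys in the loaded state dict.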
+        for k, origin_k in mapping.items():
+            v = state_dict.pop(k)
+            state_dict[origin_k] = v
+        return res
+
+    def _load_base_checkpoint(self, *_args, **kwargs):
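+        # For LoRA training, map the PEFT-wrapped parameter names back to the base model's
+        # names so the base checkpoint can be loaded: skip LoRA-only entries, strip
+        # '.base_layer' and '.modules_to_save.default', load, then restore the wrapped names.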
+        from megatron.training import checkpointing
+        sharded_state_dict = kwargs.get('sharded_state_dict')
+        if sharded_state_dict is None:
+            return checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        if self.args.train_type == 'full':
+            self._patch_merge_fn(sharded_state_dict['model'])
+            return checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict_model = {}
+        mapping = {}
+        for k, v in sharded_state_dict['model'].items():
+            if 'lora_A' in k or 'lora_B' in k or 'original_module' in k:
+                continue
+            # lora
+            if '.base_layer' in k:
+                origin_k = k
+                k = k.replace('.base_layer', '')
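+                # e.g. '<linear>.base_layer.weight' -> '<linear>.weight'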
+                mapping[k] = origin_k
+                v.key = v.key.replace('.base_layer', '')
+            elif '.modules_to_save' in k:
+                if '.modules_to_save.default' not in k:
+                    # e.g. ref_adapter
+                    continue
+                # modules to save
+                origin_k = k
+                k = k.replace('.modules_to_save.default', '')
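+                # e.g. '<module>.modules_to_save.default.weight' -> '<module>.weight'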
+                mapping[k] = origin_k
+                v.key = v.key.replace('.modules_to_save.default', '')
+            state_dict_model[k] = v
+        sharded_state_dict['model'] = state_dict_model
+        self._patch_merge_fn(state_dict_model)
+        res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs)
+        state_dict = res[0]['model']
+        for k, origin_k in mapping.items():
+            v = state_dict.pop(k)
+            state_dict[origin_k] = v
+        return res
+
     @contextmanager
-    def _patch_load_state_dict(self):
+    def _patch_load_state_dict(self, load_base_checkpoint):
         from megatron.training import checkpointing
-        origin__load_base_checkpoint = checkpointing._load_base_checkpoint
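+        # Keep the original loader on the checkpointing module so the patched loaders
+        # above can delegate to it, then install the loader passed in by the caller.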
+        checkpointing.origin__load_base_checkpoint = checkpointing._load_base_checkpoint
+        checkpointing._load_base_checkpoint = load_base_checkpoint
 
         args = get_args()
         origin_load_state_dict = torch.nn.Module.load_state_dict
         origin_no_load_optim = args.no_load_optim
         origin_no_load_rng = args.no_load_rng
         origin_finetune = args.finetune
 
-        def _load_base_checkpoint(*_args, **kwargs):
-            sharded_state_dict = kwargs.get('sharded_state_dict')
-            if sharded_state_dict is None:
-                return origin__load_base_checkpoint(*_args, **kwargs)
-            if self.args.train_type == 'full':
-                self._patch_merge_fn(sharded_state_dict['model'])
-                return origin__load_base_checkpoint(*_args, **kwargs)
-            state_dict_model = {}
-            mapping = {}
-            for k, v in sharded_state_dict['model'].items():
-                if 'lora_A' in k or 'lora_B' in k or 'original_module' in k:
-                    continue
-                # lora
-                if '.base_layer' in k:
-                    origin_k = k
-                    k = k.replace('.base_layer', '')
-                    mapping[k] = origin_k
-                    v.key = v.key.replace('.base_layer', '')
-                elif '.modules_to_save' in k:
-                    if '.modules_to_save.default' not in k:
-                        # e.g. ref_adapter
-                        continue
-                    # modules to save
-                    origin_k = k
-                    k = k.replace('.modules_to_save.default', '')
-                    mapping[k] = origin_k
-                    v.key = v.key.replace('.modules_to_save.default', '')
-                state_dict_model[k] = v
-            sharded_state_dict['model'] = state_dict_model
-            self._patch_merge_fn(state_dict_model)
-            res = origin__load_base_checkpoint(*_args, **kwargs)
-            state_dict = res[0]['model']
-            for k, origin_k in mapping.items():
-                v = state_dict.pop(k)
-                state_dict[origin_k] = v
-            return res
-
         def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             strict = False
             return origin_load_state_dict(self, state_dict, strict, *args, **kwargs)
 
-        checkpointing._load_base_checkpoint = _load_base_checkpoint
-
         if args.train_type != 'full':
             torch.nn.Module.load_state_dict = load_state_dict
             args.no_load_optim = True
             args.no_load_rng = True
             args.finetune = True
-
         try:
             yield
         finally:
-            checkpointing._load_base_checkpoint = origin__load_base_checkpoint
+            checkpointing._load_base_checkpoint = checkpointing.origin__load_base_checkpoint
             torch.nn.Module.load_state_dict = origin_load_state_dict
             args.no_load_optim = origin_no_load_optim
             args.no_load_rng = origin_no_load_rng
@@ -210,14 +235,18 @@ def new_model_provider_func(*args, **kwargs):
             self.peft_model = prepare_mcore_model(self.unwrapped_model)
             return self.unwrapped_model
 
-        with self._patch_load_state_dict():
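+        # Install self._load_base_checkpoint as the loader while setting up the model.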
+        with self._patch_load_state_dict(self._load_base_checkpoint):
             model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer(
                 new_model_provider_func, model_type, *_args, **kwargs)
         args = get_args()
         if args.initialize_embedding:
             self._initialize_embedding(self.unwrapped_model)
         if args.train_type != 'full' and args.modules_to_save:
             copy_original_module_weight(self.unwrapped_model)
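+        # If a reference-adapter checkpoint is provided, load it with the adapter-aware
+        # loader so its weights land under the 'ref_adapter' keys.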
+        if args.ref_adapter_load is not None:
+            with self._patch_load_state_dict(self._load_adapter_base_checkpoint):
+                args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
+                    model, optimizer, opt_param_scheduler, load_arg='ref_adapter_load', strict=False)
         if args.adapter_load is not None:
             with adapter_state_dict_context():
                 args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(