@@ -140,6 +140,12 @@ def _patch_load_state_dict(self):
         from megatron.training import checkpointing
         origin__load_base_checkpoint = checkpointing._load_base_checkpoint

+        args = get_args()
+        origin_load_state_dict = torch.nn.Module.load_state_dict
+        origin_no_load_optim = args.no_load_optim
+        origin_no_load_rng = args.no_load_rng
+        origin_finetune = args.finetune
+
         def _load_base_checkpoint(*_args, **kwargs):
             sharded_state_dict = kwargs.get('sharded_state_dict')
             if sharded_state_dict is None:
@@ -174,20 +180,17 @@ def _load_base_checkpoint(*_args, **kwargs):
                 state_dict[origin_k] = v
             return res

-        origin_load_state_dict = torch.nn.Module.load_state_dict
-
         def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             strict = False
             return origin_load_state_dict(self, state_dict, strict, *args, **kwargs)

         checkpointing._load_base_checkpoint = _load_base_checkpoint
-        torch.nn.Module.load_state_dict = load_state_dict

-        args = get_args()
-        origin_no_load_optim = args.no_load_optim
-        origin_no_load_rng = args.no_load_rng
-        args.no_load_optim = True
-        args.no_load_rng = True
+        if args.train_type != 'full':
+            torch.nn.Module.load_state_dict = load_state_dict
+            args.no_load_optim = True
+            args.no_load_rng = True
+            args.finetune = True

         try:
             yield
@@ -196,6 +199,7 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             torch.nn.Module.load_state_dict = origin_load_state_dict
             args.no_load_optim = origin_no_load_optim
             args.no_load_rng = origin_no_load_rng
+            args.finetune = origin_finetune

     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):

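For context, the hunks above monkey-patch `torch.nn.Module.load_state_dict` and a few Megatron args inside a context manager, applying the patches only for non-full training, and restore the originals in `finally`. Below is a minimal, runnable sketch of that save/patch/restore pattern; it is standalone and does not require Megatron, and the names `DummyArgs` and `patch_non_strict_loading` are illustrative stand-ins, not identifiers from this repository.

```python
from contextlib import contextmanager

import torch


class DummyArgs:
    # Illustrative stand-in for Megatron's get_args() namespace.
    no_load_optim = False
    no_load_rng = False
    finetune = False
    train_type = 'lora'


@contextmanager
def patch_non_strict_loading(args):
    # Save the originals so they can always be restored, even on error.
    origin_load_state_dict = torch.nn.Module.load_state_dict
    origin_no_load_optim = args.no_load_optim
    origin_no_load_rng = args.no_load_rng
    origin_finetune = args.finetune

    def load_state_dict(self, state_dict, strict: bool = True, *a, **kw):
        # Force non-strict loading so a partial (e.g. adapter-only)
        # checkpoint does not raise on missing keys.
        strict = False
        return origin_load_state_dict(self, state_dict, strict, *a, **kw)

    if args.train_type != 'full':
        torch.nn.Module.load_state_dict = load_state_dict
        args.no_load_optim = True
        args.no_load_rng = True
        args.finetune = True
    try:
        yield
    finally:
        # Undo the monkey patches regardless of how the body exits.
        torch.nn.Module.load_state_dict = origin_load_state_dict
        args.no_load_optim = origin_no_load_optim
        args.no_load_rng = origin_no_load_rng
        args.finetune = origin_finetune


args = DummyArgs()
with patch_non_strict_loading(args):
    # An empty state dict no longer raises: strict loading is disabled.
    torch.nn.Linear(2, 2).load_state_dict({})
    assert args.no_load_optim is True
assert args.no_load_optim is False  # originals restored after the block
```

Doing the restoration in `finally` guarantees the global patches are rolled back even if checkpoint loading fails inside the block.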