@@ -139,6 +139,12 @@ def _patch_load_state_dict(self):
         from megatron.training import checkpointing
         origin__load_base_checkpoint = checkpointing._load_base_checkpoint

+        args = get_args()
+        origin_load_state_dict = torch.nn.Module.load_state_dict
+        origin_no_load_optim = args.no_load_optim
+        origin_no_load_rng = args.no_load_rng
+        origin_finetune = args.finetune
+
         def _load_base_checkpoint(*_args, **kwargs):
             sharded_state_dict = kwargs.get('sharded_state_dict')
             if sharded_state_dict is None:
@@ -176,20 +182,17 @@ def _load_base_checkpoint(*_args, **kwargs):
                 state_dict[origin_k] = v
             return res

-        origin_load_state_dict = torch.nn.Module.load_state_dict
-
         def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             strict = False
             return origin_load_state_dict(self, state_dict, strict, *args, **kwargs)

         checkpointing._load_base_checkpoint = _load_base_checkpoint
-        torch.nn.Module.load_state_dict = load_state_dict

-        args = get_args()
-        origin_no_load_optim = args.no_load_optim
-        origin_no_load_rng = args.no_load_rng
-        args.no_load_optim = True
-        args.no_load_rng = True
+        if args.train_type != 'full':
+            torch.nn.Module.load_state_dict = load_state_dict
+            args.no_load_optim = True
+            args.no_load_rng = True
+            args.finetune = True

         try:
             yield
@@ -198,6 +201,7 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             torch.nn.Module.load_state_dict = origin_load_state_dict
             args.no_load_optim = origin_no_load_optim
             args.no_load_rng = origin_no_load_rng
+            args.finetune = origin_finetune

     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
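For readers unfamiliar with the pattern, the hunks above amount to a temporary monkey-patch: the originals are saved up front, torch.nn.Module.load_state_dict and the relevant Megatron args are swapped only when train_type is not 'full', and everything is restored in the finally block. Below is a minimal, self-contained sketch of that save/patch/yield/restore pattern, assuming only PyTorch and contextlib; the context-manager name and the demo at the bottom are illustrative and not code from this repository.

# Minimal sketch of the save/patch/yield/restore pattern used in the diff above.
# The name _non_strict_load_state_dict and the demo are illustrative only.
from contextlib import contextmanager

import torch


@contextmanager
def _non_strict_load_state_dict():
    # Save the original class-level function so it can always be restored.
    origin_load_state_dict = torch.nn.Module.load_state_dict

    def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
        # Force non-strict loading: missing/unexpected keys are reported, not raised.
        return origin_load_state_dict(self, state_dict, False, *args, **kwargs)

    torch.nn.Module.load_state_dict = load_state_dict
    try:
        yield
    finally:
        # Restore even if loading fails, mirroring the finally block in the patch.
        torch.nn.Module.load_state_dict = origin_load_state_dict


if __name__ == '__main__':
    model = torch.nn.Linear(4, 4)
    partial = {'weight': torch.zeros(4, 4)}  # 'bias' is intentionally missing
    with _non_strict_load_state_dict():
        model.load_state_dict(partial)  # would raise outside the context (strict=True)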