@@ -137,46 +137,51 @@ def apply_recipes(self, epoch=0.0):
137137 Applies all recipes from checkpoint_recipes. Runs architecture changing
138138 modifiers to prepare model for state dict loading
139139 """
140+ # get state dict before recipe application
141+ org_state_dict = self .model .state_dict ()
142+
143+ # apply any checkpoint recipes
140144 for checkpoint_recipe in self .checkpoint_recipes :
141145 if checkpoint_recipe is not None :
142146 ScheduledModifierManager .from_yaml (checkpoint_recipe ).apply (self .model )
147+
148+ # init current training recipe
143149 if self .manager is not None :
144- org_state_dict = self .model .state_dict ()
145150 self .manager .initialize (
146151 self .model ,
147152 epoch = epoch ,
148153 distillation_teacher = self .teacher ,
149154 loggers = self .loggers ,
150155 )
151- new_state_dict = self .model .state_dict ()
152- new_params = [p for p in new_state_dict .keys () if p not in org_state_dict ]
153-
154- if os .path .isdir (self .model_name_or_path ):
155- if os .path .isfile (os .path .join (self .model_name_or_path , WEIGHTS_NAME )):
156- archive_file = os .path .join (self .model_name_or_path , WEIGHTS_NAME )
157- state_dict = torch .load (archive_file , map_location = "cpu" )
158- new_params_to_init = [
159- p for p in new_params if p in state_dict .keys ()
160- ]
161- if new_params_to_init :
162- # parameters from dict are dependent on recipe
163- (
164- _ ,
165- missing_keys ,
166- unexpected_keys ,
167- _ ,
168- ) = self .model ._load_state_dict_into_model (
169- self .model ,
170- state_dict ,
171- self .model_name_or_path ,
172- _fast_init = False ,
156+
157+ # if model structure changed, load in new params from state dict
158+ new_state_dict = self .model .state_dict ()
159+ new_params = [p for p in new_state_dict .keys () if p not in org_state_dict ]
160+
161+ if os .path .isdir (self .model_name_or_path ):
162+ if os .path .isfile (os .path .join (self .model_name_or_path , WEIGHTS_NAME )):
163+ archive_file = os .path .join (self .model_name_or_path , WEIGHTS_NAME )
164+ state_dict = torch .load (archive_file , map_location = "cpu" )
165+ new_params_to_init = [p for p in new_params if p in state_dict .keys ()]
166+ if new_params_to_init :
167+ # parameters from dict are dependent on recipe
168+ (
169+ _ ,
170+ missing_keys ,
171+ unexpected_keys ,
172+ _ ,
173+ ) = self .model ._load_state_dict_into_model (
174+ self .model ,
175+ state_dict ,
176+ self .model_name_or_path ,
177+ _fast_init = False ,
178+ )
179+ if missing_keys or unexpected_keys :
180+ raise RuntimeError (
181+ "Unexpected or missing keys detected when applying "
182+ f"recipes to models\n Missing keys: { missing_keys } \n "
183+ f"Unexpected keys: { unexpected_keys } \n "
173184 )
174- if missing_keys or unexpected_keys :
175- raise RuntimeError (
176- "Unexpected or missing keys detected when applying "
177- f"recipes to models\n Missing keys: { missing_keys } \n "
178- f"Unexpected keys: { unexpected_keys } \n "
179- )
180185
181186 def create_optimizer (self ):
182187 """
0 commit comments