Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d1b0622

Browse files
authored
fix load state dict for transformers eval (#534) (#535)
1 parent 3463202 commit d1b0622

File tree

1 file changed

+34
-29
lines changed

1 file changed

+34
-29
lines changed

src/sparseml/transformers/utils/trainer.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -137,46 +137,51 @@ def apply_recipes(self, epoch=0.0):
137137
Applies all recipes from checkpoint_recipes. Runs architecture changing
138138
modifiers to prepare model for state dict loading
139139
"""
140+
# get state dict before recipe application
141+
org_state_dict = self.model.state_dict()
142+
143+
# apply any checkpoint recipes
140144
for checkpoint_recipe in self.checkpoint_recipes:
141145
if checkpoint_recipe is not None:
142146
ScheduledModifierManager.from_yaml(checkpoint_recipe).apply(self.model)
147+
148+
# init current training recipe
143149
if self.manager is not None:
144-
org_state_dict = self.model.state_dict()
145150
self.manager.initialize(
146151
self.model,
147152
epoch=epoch,
148153
distillation_teacher=self.teacher,
149154
loggers=self.loggers,
150155
)
151-
new_state_dict = self.model.state_dict()
152-
new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
153-
154-
if os.path.isdir(self.model_name_or_path):
155-
if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
156-
archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
157-
state_dict = torch.load(archive_file, map_location="cpu")
158-
new_params_to_init = [
159-
p for p in new_params if p in state_dict.keys()
160-
]
161-
if new_params_to_init:
162-
# parameters from dict are dependent on recipe
163-
(
164-
_,
165-
missing_keys,
166-
unexpected_keys,
167-
_,
168-
) = self.model._load_state_dict_into_model(
169-
self.model,
170-
state_dict,
171-
self.model_name_or_path,
172-
_fast_init=False,
156+
157+
# if model structure changed, load in new params from state dict
158+
new_state_dict = self.model.state_dict()
159+
new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
160+
161+
if os.path.isdir(self.model_name_or_path):
162+
if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
163+
archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
164+
state_dict = torch.load(archive_file, map_location="cpu")
165+
new_params_to_init = [p for p in new_params if p in state_dict.keys()]
166+
if new_params_to_init:
167+
# parameters from dict are dependent on recipe
168+
(
169+
_,
170+
missing_keys,
171+
unexpected_keys,
172+
_,
173+
) = self.model._load_state_dict_into_model(
174+
self.model,
175+
state_dict,
176+
self.model_name_or_path,
177+
_fast_init=False,
178+
)
179+
if missing_keys or unexpected_keys:
180+
raise RuntimeError(
181+
"Unexpected or missing keys detected when applying "
182+
f"recipes to models\nMissing keys: {missing_keys}\n"
183+
f"Unexpected keys: {unexpected_keys}\n"
173184
)
174-
if missing_keys or unexpected_keys:
175-
raise RuntimeError(
176-
"Unexpected or missing keys detected when applying "
177-
f"recipes to models\nMissing keys: {missing_keys}\n"
178-
f"Unexpected keys: {unexpected_keys}\n"
179-
)
180185

181186
def create_optimizer(self):
182187
"""

0 commit comments

Comments
 (0)