Skip to content

Commit 77c48fe

Browse files
committed
bug fix
Signed-off-by: realAsma <[email protected]>
1 parent 7d612fc commit 77c48fe

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _save_modelopt_state_with_weights(self):
180180
if "kd_loss" not in state and "export_student" not in state
181181
]
182182
modelopt_full_state = {
183-
"modelopt_state_dict": modelopt_state["modelopt_state_dict"],
183+
"modelopt_state": modelopt_state,
184184
"modelopt_state_weights": get_quantizer_state_dict(self.model),
185185
}
186186

@@ -189,7 +189,7 @@ def _save_modelopt_state_with_weights(self):
189189

190190
def _restore_modelopt_state_with_weights(self):
191191
modelopt_full_state = torch.load(self._modelopt_state_path, weights_only=False)
192-
restore_from_modelopt_state(self.model, modelopt_full_state["modelopt_state_dict"])
192+
restore_from_modelopt_state(self.model, modelopt_full_state["modelopt_state"])
193193
set_quantizer_state_dict(self.model, modelopt_full_state["modelopt_state_weights"])
194194

195195
def _quantize_model(self):

0 commit comments

Comments (0)