Commit 9272369

added check for frozen base model
Signed-off-by: Suguna Velury <[email protected]>
1 parent: 1853f43

File tree

1 file changed: +4 -0 lines changed

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 4 additions & 0 deletions
@@ -283,6 +283,10 @@ def _load_best_model(self, *args, **kwargs):
         if is_lora and not self.is_fsdp_enabled:
             # Custom logic for loading best model with LoRA
             # TODO: Remove once we migrate to using get_peft_model()
+            # This custom logic only loads best adapters. Ensure base model is frozen
+            assert all(
+                param.requires_grad is False for param in self.model.base_model.parameters()
+            ), "Base model must be frozen for lora"
             adapter_name = self.model.active_adapter()
             self.model.delete_adapter(adapter_name)
             self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
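
For context, here is a minimal sketch of the state the new assertion expects, in plain PyTorch. ToyLoraModel is a hypothetical stand-in for the trainer's wrapped model (it is not part of modelopt): a base_model frozen before training plus a small trainable adapter. The final assert uses the same predicate as the check added in this commit.

from torch import nn

# Hypothetical stand-in for the trainer's wrapped model: a base model whose
# weights would come from a pretrained checkpoint, plus a LoRA-style adapter.
class ToyLoraModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(16, 16)  # stands in for the pretrained base
        self.adapter = nn.Linear(16, 16)     # stands in for the LoRA adapter

    def forward(self, x):
        return self.base_model(x) + self.adapter(x)

model = ToyLoraModel()

# Freeze the base model so only the adapter receives gradient updates.
for param in model.base_model.parameters():
    param.requires_grad = False

# Same predicate as the added check: every base-model parameter must be
# frozen before the best adapter is deleted and reloaded.
assert all(
    param.requires_grad is False for param in model.base_model.parameters()
), "Base model must be frozen for lora"

Per the new comment in the diff, this code path restores only the best adapter from self.state.best_model_checkpoint; if base-model weights had also been training, reloading the adapter alone would not recover the best base weights, so the assertion fails fast instead.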

0 commit comments
