@@ -40,11 +40,17 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     # initialize the true module if necessary
     model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
 
+    # Freeze all base model weights before replacing modules if freeze_base_model is True
+    if config.freeze_base_model:
+        for param in model.parameters():
+            param.requires_grad = False
+
     replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
 
     metadata = {}
     add_adapter(model, config)
-    update_grads(model, config)
+    # Update gradient settings for LoRA parameters only
+    _update_lora_grads(model, config)
 
     return model, metadata
 
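Taken together, the hunk above moves base-weight freezing ahead of module replacement and narrows the post-`add_adapter` step to LoRA parameters only. A minimal standalone sketch of that ordering in plain PyTorch (toy module names, not the modelopt API), shown only to illustrate the freeze-then-re-enable pattern:

```python
# Standalone sketch (plain PyTorch, not the modelopt API) of the ordering this
# hunk implements: freeze every base parameter up front, then leave only the
# adapter parameters added afterwards trainable.
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# Step 1: freeze the whole base model (mirrors the freeze_base_model branch).
for param in model.parameters():
    param.requires_grad = False

# Step 2: attach a toy "adapter" and keep only its parameters trainable
# (mirrors add_adapter followed by _update_lora_grads).
adapter = nn.Linear(4, 4)
model.add_module("adapter", adapter)
for param in adapter.parameters():
    param.requires_grad = True

print([name for name, p in model.named_parameters() if p.requires_grad])
# -> ['adapter.weight', 'adapter.bias']
```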
@@ -169,17 +175,25 @@ def _iter_lora_modules(model, layer_patterns=None):
 
 
 def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
-    for _, module in _iter_lora_modules(model, layer_patterns):
-        lora_param_ids = {
-            id(param)
-            for adapter in module._lora_adapters.values()
-            for submodule in ("lora_a", "lora_b")
-            for _, param in adapter[submodule].named_parameters()
-        }
-        for _, param in module.named_parameters():
-            if id(param) in lora_param_ids:
+    # Collect all LoRA parameter IDs across the entire model
+    lora_param_ids = set()
+    for _, module in _iter_lora_modules(model, layer_patterns=None):
+        for adapter in module._lora_adapters.values():
+            for submodule in ("lora_a", "lora_b"):
+                for _, param in adapter[submodule].named_parameters():
+                    lora_param_ids.add(id(param))
+
+    # Set requires_grad for all parameters in the model (excluding LoRA parameters)
+    for name, param in model.named_parameters():
+        # Skip LoRA parameters
+        if id(param) in lora_param_ids:
+            continue
+        # If layer_patterns is specified, only affect matching layers
+        if layer_patterns is not None:
+            module_name = ".".join(name.split(".")[:-1])  # Get module name without param name
+            if not _matches(module_name, layer_patterns):
                 continue
-            param.requires_grad = requires_grad
+        param.requires_grad = requires_grad
 
 
 def _iter_adapter_names(module, adapter_patterns=None):
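The rewritten `_set_base_requires_grad` works in two passes: it first records the `id()` of every LoRA tensor across the whole model (note `layer_patterns=None` in that pass), then walks `model.named_parameters()` and flips `requires_grad` only on non-LoRA parameters, optionally restricted to module names accepted by `_matches(module_name, layer_patterns)`. A standalone sketch of the same two-pass technique, using `fnmatch` as a stand-in for `_matches` (whose exact pattern semantics are an assumption here):

```python
# Standalone sketch of the two-pass requires_grad update used by
# _set_base_requires_grad: collect adapter parameter ids first, then flip
# everything else, optionally filtered by a module-name pattern.
from fnmatch import fnmatch

import torch.nn as nn


class ToyLoraLinear(nn.Module):
    def __init__(self, dim, rank=2):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.lora_a = nn.Linear(dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, dim, bias=False)


model = nn.ModuleDict({"attn": ToyLoraLinear(8), "mlp": ToyLoraLinear(8)})

# Pass 1: record every adapter parameter by identity, across the whole model.
lora_param_ids = {
    id(p)
    for m in model.modules()
    if isinstance(m, ToyLoraLinear)
    for p in list(m.lora_a.parameters()) + list(m.lora_b.parameters())
}

# Pass 2: freeze non-adapter parameters, here only in modules matching "attn*"
# (fnmatch stands in for the repository's _matches helper -- an assumption).
layer_patterns = ["attn*"]
for name, param in model.named_parameters():
    if id(param) in lora_param_ids:
        continue
    module_name = ".".join(name.split(".")[:-1])
    if layer_patterns is not None and not any(
        fnmatch(module_name, pat) for pat in layer_patterns
    ):
        continue
    param.requires_grad = False

print([n for n, p in model.named_parameters() if not p.requires_grad])
# -> ['attn.base.weight', 'attn.base.bias']
```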
@@ -202,7 +216,8 @@ def _set_lora_requires_grad(
 def freeze_base_weights(model, *, layer_patterns=None):
     """Freeze base model weights to prevent gradient updates during training.
 
-    This function sets requires_grad=False for all base model parameters in LoRA modules,
+    This function sets requires_grad=False for all base model parameters (including
+    linear weights, embeddings, layer norms, etc.) across the entire model,
     while keeping LoRA adapter parameters trainable. Useful for LoRA fine-tuning where
     only adapter weights should be updated.
 
@@ -218,8 +233,10 @@ def freeze_base_weights(model, *, layer_patterns=None):
 def unfreeze_base_weights(model, *, layer_patterns=None):
     """Unfreeze base model weights to allow gradient updates during training.
 
-    This function sets requires_grad=True for all base model parameters in LoRA modules.
-    Useful when you want to fine-tune both base model and LoRA adapter weights together.
+    This function sets requires_grad=True for all base model parameters (including
+    linear weights, embeddings, layer norms, etc.) across the entire model,
+    while keeping LoRA adapter parameters unchanged. Useful when you want to fine-tune
+    both base model and LoRA adapter weights together.
 
     Args:
         model: Model containing LoRA modules whose base weights should be unfrozen
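With the broadened docstrings, `freeze_base_weights` and `unfreeze_base_weights` now read as model-wide toggles rather than per-LoRA-module ones. A hedged usage sketch follows; the import path and the `peft_model` variable are assumptions, while the function names and the `layer_patterns` keyword come from this diff:

```python
# Hedged usage sketch -- the import path below is an assumption, and
# `peft_model` stands for the output of convert_to_peft_model(...) above.
from modelopt.torch.peft import freeze_base_weights, unfreeze_base_weights  # assumed path

# Adapter-only fine-tuning: freeze every non-LoRA parameter in the model.
freeze_base_weights(peft_model)

# Joint fine-tuning of selected layers: unfreeze only base weights whose module
# names match the pattern (pattern syntax follows the repo's _matches helper).
unfreeze_base_weights(peft_model, layer_patterns=["*self_attention*"])
```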
@@ -277,18 +294,20 @@ def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
     )
 
 
-def update_grads(model, config: PEFTConfig):
-    """Update gradient computation settings based on PEFTConfig.
+def _update_lora_grads(model, config: PEFTConfig):
+    """Update gradient computation settings for LoRA parameters only (internal function).
+
+    This internal function configures which LoRA adapter parameters should have gradients
+    computed based on the freeze_lora_weights setting in the PEFTConfig. It's typically
+    called during model initialization after LoRA adapters have been added.
 
-    This function configures which model parameters should have gradients computed
-    based on the freeze settings in the PEFTConfig. It's typically called during
-    model initialization or when switching training configurations.
+    Note: This function only affects LoRA parameters. Base model parameter gradients
+    should be set separately (e.g., in convert_to_peft_model before LoRA module replacement).
 
     Args:
         model: Model containing LoRA modules to configure
-        config: PEFTConfig instance with freeze_base_model and freeze_lora_weights settings
-            - If config.freeze_base_model is True, base weights will have requires_grad=False
+        config: PEFTConfig instance with freeze_lora_weights setting
             - If config.freeze_lora_weights is True, LoRA weights will have requires_grad=False
+            - If config.freeze_lora_weights is False, LoRA weights will have requires_grad=True
     """
-    _set_base_requires_grad(model, requires_grad=not config.freeze_base_model)
     _set_lora_requires_grad(model, requires_grad=not config.freeze_lora_weights)
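After this change the renamed helper is a thin wrapper: it maps the single `freeze_lora_weights` flag onto `requires_grad` for adapter tensors and nothing else, while base weights are handled earlier in `convert_to_peft_model`. A toy sketch of that flag-to-gradient mapping, with bare parameters standing in for the model's `lora_a`/`lora_b` tensors:

```python
# Toy sketch of the mapping performed by _update_lora_grads: requires_grad is
# simply `not freeze_lora_weights`, applied to adapter tensors only.
import torch
import torch.nn as nn

lora_a = nn.Parameter(torch.zeros(4, 2))
lora_b = nn.Parameter(torch.zeros(2, 4))

freeze_lora_weights = False  # as it would appear on a PEFTConfig
for p in (lora_a, lora_b):
    p.requires_grad = not freeze_lora_weights

print(lora_a.requires_grad, lora_b.requires_grad)  # True True
```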