@@ -40,11 +40,17 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     # initialize the true module if necessary
     model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

+    # Freeze all base model weights before replacing modules if freeze_base_model is True
+    if config.freeze_base_model:
+        for param in model.parameters():
+            param.requires_grad = False
+
     replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)

     metadata = {}
     add_adapter(model, config)
-    update_grads(model, config)
+    # Update gradient settings for LoRA parameters only
+    _update_lora_grads(model, config)

     return model, metadata

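The ordering above matters: base parameters are frozen before `replace_lora_module` runs, so the LoRA parameters created afterwards keep their default `requires_grad=True`. A minimal sketch of that effect in plain PyTorch (toy model and attribute names are hypothetical, not the modelopt API):

import torch.nn as nn

# Toy stand-in for the real network (hypothetical, not the modelopt API).
base = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# Step 1: freeze every existing (base) parameter, mirroring the
# `if config.freeze_base_model:` block above.
for param in base.parameters():
    param.requires_grad = False

# Step 2: create adapter-style parameters afterwards; freshly constructed
# modules default to requires_grad=True, so only they stay trainable.
base[0].lora_a = nn.Linear(16, 8, bias=False)
base[0].lora_b = nn.Linear(8, 16, bias=False)

print([name for name, p in base.named_parameters() if p.requires_grad])
# -> ['0.lora_a.weight', '0.lora_b.weight']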
@@ -169,17 +175,25 @@ def _iter_lora_modules(model, layer_patterns=None):


 def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
-    for _, module in _iter_lora_modules(model, layer_patterns):
-        lora_param_ids = {
-            id(param)
-            for adapter in module._lora_adapters.values()
-            for submodule in ("lora_a", "lora_b")
-            for _, param in adapter[submodule].named_parameters()
-        }
-        for _, param in module.named_parameters():
-            if id(param) in lora_param_ids:
+    # Collect all LoRA parameter IDs across the entire model
+    lora_param_ids = set()
+    for _, module in _iter_lora_modules(model, layer_patterns=None):
+        for adapter in module._lora_adapters.values():
+            for submodule in ("lora_a", "lora_b"):
+                for _, param in adapter[submodule].named_parameters():
+                    lora_param_ids.add(id(param))
+
+    # Set requires_grad for all parameters in the model (excluding LoRA parameters)
+    for name, param in model.named_parameters():
+        # Skip LoRA parameters
+        if id(param) in lora_param_ids:
+            continue
+        # If layer_patterns is specified, only affect matching layers
+        if layer_patterns is not None:
+            module_name = ".".join(name.split(".")[:-1])  # Get module name without param name
+            if not _matches(module_name, layer_patterns):
                 continue
-            param.requires_grad = requires_grad
+        param.requires_grad = requires_grad


 def _iter_adapter_names(module, adapter_patterns=None):
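The rewritten `_set_base_requires_grad` first collects the `id()` of every LoRA parameter model-wide, then toggles `requires_grad` only on parameters outside that set, optionally filtered by `layer_patterns`. Below is a standalone sketch of the same exclusion pattern; `fnmatch` stands in for the module's `_matches` helper and the toy model is hypothetical:

import fnmatch

import torch.nn as nn


def set_base_requires_grad_sketch(model, lora_param_ids, *, requires_grad, layer_patterns=None):
    """Toggle requires_grad on every non-LoRA parameter, optionally filtered by layer name."""
    for name, param in model.named_parameters():
        if id(param) in lora_param_ids:
            continue  # never touch adapter weights here
        if layer_patterns is not None:
            module_name = ".".join(name.split(".")[:-1])  # drop the trailing "weight"/"bias"
            # fnmatch stands in for the module's real _matches() helper
            if not any(fnmatch.fnmatch(module_name, pat) for pat in layer_patterns):
                continue
        param.requires_grad = requires_grad


# Toy usage: pretend the second Linear holds the adapter weights.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
lora_ids = {id(p) for p in model[1].parameters()}
set_base_requires_grad_sketch(model, lora_ids, requires_grad=False)
print([n for n, p in model.named_parameters() if p.requires_grad])
# -> ['1.weight', '1.bias']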
@@ -202,7 +216,8 @@ def _set_lora_requires_grad(
 def freeze_base_weights(model, *, layer_patterns=None):
     """Freeze base model weights to prevent gradient updates during training.

-    This function sets requires_grad=False for all base model parameters in LoRA modules,
+    This function sets requires_grad=False for all base model parameters (including
+    linear weights, embeddings, layer norms, etc.) across the entire model,
     while keeping LoRA adapter parameters trainable. Useful for LoRA fine-tuning where
     only adapter weights should be updated.

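After calling `freeze_base_weights(model)`, only adapter parameters should remain trainable. A small, library-independent helper like the following (hypothetical, not part of this module) makes that easy to verify:

import torch.nn as nn


def count_trainable(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts for a quick sanity check."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# e.g. after freeze_base_weights(model):
#   trainable, total = count_trainable(model)
#   print(f"trainable: {trainable} / {total} ({100 * trainable / total:.2f}%)")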
@@ -218,8 +233,10 @@ def freeze_base_weights(model, *, layer_patterns=None):
 def unfreeze_base_weights(model, *, layer_patterns=None):
     """Unfreeze base model weights to allow gradient updates during training.

-    This function sets requires_grad=True for all base model parameters in LoRA modules.
-    Useful when you want to fine-tune both base model and LoRA adapter weights together.
+    This function sets requires_grad=True for all base model parameters (including
+    linear weights, embeddings, layer norms, etc.) across the entire model,
+    while keeping LoRA adapter parameters unchanged. Useful when you want to fine-tune
+    both base model and LoRA adapter weights together.

     Args:
         model: Model containing LoRA modules whose base weights should be unfrozen
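A common use of `unfreeze_base_weights` is a two-phase schedule: adapters-only training first, then joint fine-tuning of base and adapter weights. The sketch below mimics that schedule with a toy model and plain `requires_grad` toggling in place of the library calls:

import torch
import torch.nn as nn

# Toy stand-ins: "base" layers plus an "adapter" head (hypothetical).
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
adapter_ids = {id(p) for p in model[1].parameters()}

# Phase 1: adapters only -- what freeze_base_weights(model) would arrange.
for p in model.parameters():
    p.requires_grad = id(p) in adapter_ids
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)
# ... train for some steps ...

# Phase 2: joint fine-tuning -- what unfreeze_base_weights(model) would arrange;
# adapter parameters stay trainable, base parameters become trainable again.
for p in model.parameters():
    p.requires_grad = True
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # rebuild the optimizer with all params
# ... continue training ...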
@@ -277,18 +294,20 @@ def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
     )


-def update_grads(model, config: PEFTConfig):
-    """Update gradient computation settings based on PEFTConfig.
+def _update_lora_grads(model, config: PEFTConfig):
+    """Update gradient computation settings for LoRA parameters only (internal function).
+
+    This internal function configures which LoRA adapter parameters should have gradients
+    computed based on the freeze_lora_weights setting in the PEFTConfig. It's typically
+    called during model initialization after LoRA adapters have been added.

-    This function configures which model parameters should have gradients computed
-    based on the freeze settings in the PEFTConfig. It's typically called during
-    model initialization or when switching training configurations.
+    Note: This function only affects LoRA parameters. Base model parameter gradients
+    should be set separately (e.g., in convert_to_peft_model before LoRA module replacement).

     Args:
         model: Model containing LoRA modules to configure
-        config: PEFTConfig instance with freeze_base_model and freeze_lora_weights settings
-            - If config.freeze_base_model is True, base weights will have requires_grad=False
+        config: PEFTConfig instance with freeze_lora_weights setting
             - If config.freeze_lora_weights is True, LoRA weights will have requires_grad=False
+            - If config.freeze_lora_weights is False, LoRA weights will have requires_grad=True
     """
-    _set_base_requires_grad(model, requires_grad=not config.freeze_base_model)
     _set_lora_requires_grad(model, requires_grad=not config.freeze_lora_weights)
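The renamed `_update_lora_grads` now has a single responsibility: LoRA parameters get `requires_grad = not config.freeze_lora_weights`, while base parameters are handled earlier in `convert_to_peft_model`. A minimal sketch of that contract with a stand-in config class (illustration only, not the real PEFTConfig):

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ToyPEFTConfig:  # stand-in for PEFTConfig, illustration only
    freeze_base_model: bool = True
    freeze_lora_weights: bool = False


def update_lora_grads_sketch(lora_params, config: ToyPEFTConfig) -> None:
    # Mirrors the contract of _update_lora_grads: touch only LoRA parameters.
    for p in lora_params:
        p.requires_grad = not config.freeze_lora_weights


lora = nn.Linear(8, 4, bias=False)
update_lora_grads_sketch(lora.parameters(), ToyPEFTConfig(freeze_lora_weights=True))
print(lora.weight.requires_grad)  # False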