"""PEFT conversion and restore utilities for LoRA modules."""

import fnmatch
+from collections.abc import Callable, Iterable
from typing import Any

import torch.nn as nn

from .lora.layer import LoRAModule, LoRAModuleRegistry

__all__ = [
+    "freeze_base_weights",
+    "freeze_lora_weights",
    "replace_lora_module",
+    "unfreeze_base_weights",
+    "unfreeze_lora_weights",
    "update_peft_metadata_in_model",
]

@@ -42,6 +47,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert

    metadata = {}
    add_adapter(model, config)
+    update_grads(model, config)
    update_peft_metadata(model, config, metadata)

    return model, metadata
@@ -175,17 +181,177 @@ def add_adapter(model, config: PEFTConfig):
    for name, module in model.named_modules():
        if isinstance(module, LoRAModule):
            for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
-                if isinstance(wildcard_or_filter_func, str):
-                    if not fnmatch.fnmatch(name, wildcard_or_filter_func):
-                        continue
-                elif callable(wildcard_or_filter_func):
-                    if not wildcard_or_filter_func(name):
-                        continue
-                else:
-                    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
-                module.update_layer_lora(
-                    adapter_name,
-                    adapter_setting,
-                )
+                if _matches(name, wildcard_or_filter_func):
+                    module.update_layer_lora(
+                        adapter_name,
+                        adapter_setting,
+                    )

    return model
+
+
+def _matches(
+    name: str,
+    patterns: str | Callable[[str], bool] | Iterable[str | Callable[[str], bool]] | None,
+    *,
+    allow_callable: bool = True,
+) -> bool:
+    """Return True if ``name`` matches any of the given wildcard patterns or filter callables."""
+    if patterns is None:
+        return True
+
+    if isinstance(patterns, (str, bytes)):
+        patterns_iter: Iterable[str | Callable[[str], bool]] = (patterns,)
+    elif callable(patterns):
+        if not allow_callable:
+            raise TypeError("Callable patterns are not supported in this context.")
+        patterns_iter = (patterns,)
+    elif isinstance(patterns, Iterable):
+        patterns_iter = tuple(patterns)
+    else:
+        raise TypeError(f"Unsupported pattern type: {type(patterns)}")
+
+    for pattern in patterns_iter:
+        if isinstance(pattern, (str, bytes)):
+            if fnmatch.fnmatch(name, pattern):
+                return True
+        elif callable(pattern):
+            if not allow_callable:
+                raise TypeError("Callable patterns are not supported in this context.")
+            if pattern(name):
+                return True
+        else:
+            raise TypeError(f"Unsupported pattern type: {type(pattern)}")
+
+    return False
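
A quick illustration of how the _matches patterns resolve (the module names below are made up for the example):

    # Single wildcard string: Unix-style matching via fnmatch.
    _matches("decoder.layers.0.attn.q_proj", "*q_proj")           # True
    # Callable filter: any predicate over the name.
    _matches("decoder.layers.0.mlp.fc1", lambda n: "mlp" in n)    # True
    # Iterable: True if any entry (string or callable) matches.
    _matches("decoder.layers.0.mlp.fc1", ["*q_proj", "*k_proj"])  # False
    # None matches everything.
    _matches("anything", None)                                    # True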
+
+
+def _iter_lora_modules(model, layer_patterns=None):
+    """Yield (name, module) pairs for LoRA modules whose names match ``layer_patterns``."""
+    for module_name, module in model.named_modules():
+        if isinstance(module, LoRAModule) and _matches(module_name, layer_patterns):
+            yield module_name, module
+
+
+def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
+    """Set ``requires_grad`` on all non-LoRA parameters of the matching LoRA modules."""
+    for _, module in _iter_lora_modules(model, layer_patterns):
+        # Collect the ids of the adapter parameters so they can be skipped below.
+        lora_param_ids = {
+            id(param)
+            for adapter in module._lora_adapters.values()
+            for submodule in ("lora_a", "lora_b")
+            for _, param in adapter[submodule].named_parameters()
+        }
+        for _, param in module.named_parameters():
+            if id(param) in lora_param_ids:
+                continue
+            param.requires_grad = requires_grad
+
+
+def _iter_adapter_names(module, adapter_patterns=None):
+    """Yield the adapter names of ``module`` that match ``adapter_patterns``."""
+    for adapter_name in module._lora_adapters:
+        if _matches(adapter_name, adapter_patterns, allow_callable=False):
+            yield adapter_name
+
+
+def _set_lora_requires_grad(
+    model, *, requires_grad: bool, layer_patterns=None, adapter_patterns=None
+):
+    """Set ``requires_grad`` on the lora_a/lora_b parameters of the matching adapters."""
+    for _, module in _iter_lora_modules(model, layer_patterns):
+        for adapter_name in _iter_adapter_names(module, adapter_patterns):
+            adapter = module._lora_adapters[adapter_name]
+            for submodule in (adapter["lora_a"], adapter["lora_b"]):
+                for _, param in submodule.named_parameters():
+                    param.requires_grad = requires_grad
+
+
+def freeze_base_weights(model, *, layer_patterns=None):
+    """Freeze base model weights to prevent gradient updates during training.
+
+    This function sets requires_grad=False for all base model parameters in LoRA modules,
+    while leaving the LoRA adapter parameters untouched. Useful for LoRA fine-tuning where
+    only adapter weights should be updated.
+
+    Args:
+        model: Model containing LoRA modules whose base weights should be frozen
+        layer_patterns: Optional patterns (string wildcard, callable filter, or an iterable
+            of either) to match specific layer names. If provided, only layers matching
+            these patterns will be affected. String patterns support Unix-style wildcards
+            (e.g., "*.linear", "transformer.*").
+    """
+    _set_base_requires_grad(model, requires_grad=False, layer_patterns=layer_patterns)
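
A minimal usage sketch of the classic LoRA setup (the model and the layer pattern are assumptions for illustration):

    # Freeze every base weight so that only LoRA adapters receive gradients.
    freeze_base_weights(model)

    # Or restrict the freeze to attention projections only.
    freeze_base_weights(model, layer_patterns="*attn*")

    # Sanity check: how many parameters are still trainable?
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters after freezing: {trainable}")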
+
+
+def unfreeze_base_weights(model, *, layer_patterns=None):
+    """Unfreeze base model weights to allow gradient updates during training.
+
+    This function sets requires_grad=True for all base model parameters in LoRA modules.
+    Useful when you want to fine-tune both the base model and the LoRA adapter weights together.
+
+    Args:
+        model: Model containing LoRA modules whose base weights should be unfrozen
+        layer_patterns: Optional patterns (string wildcard, callable filter, or an iterable
+            of either) to match specific layer names. If provided, only layers matching
+            these patterns will be affected. String patterns support Unix-style wildcards
+            (e.g., "*.linear", "transformer.*").
+    """
+    _set_base_requires_grad(model, requires_grad=True, layer_patterns=layer_patterns)
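
For example, to move from adapter-only training to partial full fine-tuning of a single block (the layer name pattern is hypothetical):

    # Re-enable gradients for the base weights of one block only;
    # layers outside the pattern keep their current requires_grad state.
    unfreeze_base_weights(model, layer_patterns="*layers.31.*")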
+
+
+def freeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
+    """Freeze LoRA adapter weights to prevent gradient updates during training.
+
+    This function sets requires_grad=False for LoRA adapter parameters (lora_a and lora_b).
+    Useful when you want to train only the base model weights or evaluate the model
+    without updating LoRA adapters.
+
+    Args:
+        model: Model containing LoRA modules whose adapter weights should be frozen
+        layer_patterns: Optional patterns (string wildcard, callable filter, or an iterable
+            of either) to match specific layer names. If provided, only layers matching
+            these patterns will be affected. String patterns support Unix-style wildcards
+            (e.g., "*.linear", "transformer.*").
+        adapter_patterns: Optional patterns (string wildcard or an iterable of wildcards)
+            to match specific adapter names. If provided, only adapters matching these
+            patterns will be affected.
+    """
+    _set_lora_requires_grad(
+        model,
+        requires_grad=False,
+        layer_patterns=layer_patterns,
+        adapter_patterns=adapter_patterns,
+    )
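
With two adapters attached (the names "default" and "math" are made up for this sketch), a single adapter can be frozen while the other keeps training:

    # Freeze only the "default" adapter across all LoRA layers.
    freeze_lora_weights(model, adapter_patterns="default")

    # Or freeze every adapter in the MLP layers, regardless of adapter name.
    freeze_lora_weights(model, layer_patterns="*mlp*")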
+
+
+def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
+    """Unfreeze LoRA adapter weights to allow gradient updates during training.
+
+    This function sets requires_grad=True for LoRA adapter parameters (lora_a and lora_b).
+    This is the typical setting for LoRA fine-tuning, where the adapter weights are trained.
+
+    Args:
+        model: Model containing LoRA modules whose adapter weights should be unfrozen
+        layer_patterns: Optional patterns (string wildcard, callable filter, or an iterable
+            of either) to match specific layer names. If provided, only layers matching
+            these patterns will be affected. String patterns support Unix-style wildcards
+            (e.g., "*.linear", "transformer.*").
+        adapter_patterns: Optional patterns (string wildcard or an iterable of wildcards)
+            to match specific adapter names. If provided, only adapters matching these
+            patterns will be affected.
+    """
+    _set_lora_requires_grad(
+        model,
+        requires_grad=True,
+        layer_patterns=layer_patterns,
+        adapter_patterns=adapter_patterns,
+    )
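
Combined with the helpers above, this supports training adapters one at a time (the adapter name is again assumed):

    # Freeze all adapters, then re-enable gradients for just "math".
    freeze_lora_weights(model)
    unfreeze_lora_weights(model, adapter_patterns="math")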
+
+
+def update_grads(model, config: PEFTConfig):
+    """Update gradient computation settings based on PEFTConfig.
+
+    This function configures which model parameters should have gradients computed
+    based on the freeze settings in the PEFTConfig. It's typically called during
+    model initialization or when switching training configurations.
+
+    Args:
+        model: Model containing LoRA modules to configure
+        config: PEFTConfig instance with freeze_base_model and freeze_lora_weights settings
+            - If config.freeze_base_model is True, base weights will have requires_grad=False
+            - If config.freeze_lora_weights is True, LoRA weights will have requires_grad=False
+    """
+    _set_base_requires_grad(model, requires_grad=not config.freeze_base_model)
+    _set_lora_requires_grad(model, requires_grad=not config.freeze_lora_weights)
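
Given an existing PEFTConfig instance (constructed elsewhere; only the two freeze flags matter here), the same settings are applied in one call, which is what convert_to_peft_model now does right after add_adapter:

    # peft_config.freeze_base_model is True   -> base weights get requires_grad=False
    # peft_config.freeze_lora_weights is False -> adapter weights stay trainable
    update_grads(model, peft_config)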