 """PEFT conversion and restore utilities for LoRA modules."""

 import fnmatch
+from collections.abc import Callable, Iterable
 from typing import Any

 import torch.nn as nn

 from .lora.layer import LoRAModule, LoRAModuleRegistry

 __all__ = [
+    "freeze_base_weights",
+    "freeze_lora_weights",
     "replace_lora_module",
+    "unfreeze_base_weights",
+    "unfreeze_lora_weights",
     "update_peft_metadata_in_model",
 ]

@@ -42,6 +47,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert

     metadata = {}
     add_adapter(model, config)
+    update_grads(model, config)
     update_peft_metadata(model, config, metadata)

     return model, metadata
@@ -175,17 +181,177 @@ def add_adapter(model, config: PEFTConfig):
     for name, module in model.named_modules():
         if isinstance(module, LoRAModule):
             for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
-                if isinstance(wildcard_or_filter_func, str):
-                    if not fnmatch.fnmatch(name, wildcard_or_filter_func):
-                        continue
-                elif callable(wildcard_or_filter_func):
-                    if not wildcard_or_filter_func(name):
-                        continue
-                else:
-                    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
-                module.update_layer_lora(
-                    adapter_name,
-                    adapter_setting,
-                )
+                if _matches(name, wildcard_or_filter_func):
+                    module.update_layer_lora(
+                        adapter_name,
+                        adapter_setting,
+                    )

     return model
+
+
+def _matches(
+    name: str,
+    patterns: str | Callable[[str], bool] | Iterable[str | Callable[[str], bool]] | None,
+    *,
+    allow_callable: bool = True,
+) -> bool:
+    """Return True if ``name`` matches ``patterns``.
+
+    ``patterns`` may be None (matches everything), a Unix-style wildcard string, a
+    callable taking the module name, or an iterable of these. Set
+    ``allow_callable=False`` to reject callable patterns.
+    """
+    if patterns is None:
+        return True
+
+    if isinstance(patterns, (str, bytes)):
+        patterns_iter: Iterable[str | Callable[[str], bool]] = (patterns,)
+    elif callable(patterns):
+        if not allow_callable:
+            raise TypeError("Callable patterns are not supported in this context.")
+        patterns_iter = (patterns,)
+    elif isinstance(patterns, Iterable):
+        patterns_iter = tuple(patterns)
+    else:
+        raise TypeError(f"Unsupported pattern type: {type(patterns)}")
+
+    for pattern in patterns_iter:
+        if isinstance(pattern, (str, bytes)):
+            if fnmatch.fnmatch(name, pattern):
+                return True
+        elif callable(pattern):
+            if not allow_callable:
+                raise TypeError("Callable patterns are not supported in this context.")
+            if pattern(name):
+                return True
+        else:
+            raise TypeError(f"Unsupported pattern type: {type(pattern)}")
+
+    return False
+
+
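+# A minimal sketch of the pattern forms `_matches` accepts. The module name below is
+# hypothetical and only illustrates the matching rules: a plain string is a Unix-style
+# wildcard, a callable is invoked with the name, an iterable is matched element-wise,
+# and None matches everything.
+#
+#     _matches("decoder.layers.0.linear_fc1", "*linear_fc1")         # True
+#     _matches("decoder.layers.0.linear_fc1", lambda n: "fc2" in n)  # False
+#     _matches("decoder.layers.0.linear_fc1", ["*fc2", "*fc1"])      # True
+#     _matches("decoder.layers.0.linear_fc1", None)                  # True
+
+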
+def _iter_lora_modules(model, layer_patterns=None):
+    """Yield (name, module) pairs for LoRA modules whose name matches ``layer_patterns``."""
+    for module_name, module in model.named_modules():
+        if isinstance(module, LoRAModule) and _matches(module_name, layer_patterns):
+            yield module_name, module
+
+
+def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
+    for _, module in _iter_lora_modules(model, layer_patterns):
+        # Collect the ids of all LoRA adapter parameters so that only the base
+        # parameters of the module are toggled below.
+        lora_param_ids = {
+            id(param)
+            for adapter in module._lora_adapters.values()
+            for submodule in ("lora_a", "lora_b")
+            for _, param in adapter[submodule].named_parameters()
+        }
+        for _, param in module.named_parameters():
+            if id(param) in lora_param_ids:
+                continue
+            param.requires_grad = requires_grad
+
+
+def _iter_adapter_names(module, adapter_patterns=None):
+    """Yield adapter names of ``module`` that match ``adapter_patterns``."""
+    for adapter_name in module._lora_adapters:
+        if _matches(adapter_name, adapter_patterns, allow_callable=False):
+            yield adapter_name
+
+
+def _set_lora_requires_grad(
+    model, *, requires_grad: bool, layer_patterns=None, adapter_patterns=None
+):
+    for _, module in _iter_lora_modules(model, layer_patterns):
+        for adapter_name in _iter_adapter_names(module, adapter_patterns):
+            adapter = module._lora_adapters[adapter_name]
+            for submodule in (adapter["lora_a"], adapter["lora_b"]):
+                for _, param in submodule.named_parameters():
+                    param.requires_grad = requires_grad
+
+
+def freeze_base_weights(model, *, layer_patterns=None):
+    """Freeze base model weights to prevent gradient updates during training.
+
+    This function sets requires_grad=False for all base model parameters in LoRA modules,
+    while leaving LoRA adapter parameters untouched. Useful for LoRA fine-tuning where
+    only adapter weights should be updated.
+
+    Args:
+        model: Model containing LoRA modules whose base weights should be frozen
+        layer_patterns: Optional patterns (str, callable, or an iterable of these) to match
+            specific layer names. If provided, only layers matching these patterns are
+            affected. Strings support Unix-style wildcards (e.g., "*.linear", "transformer.*")
+    """
+    _set_base_requires_grad(model, requires_grad=False, layer_patterns=layer_patterns)
+
+
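+# A minimal usage sketch for `freeze_base_weights`; the wildcard below is hypothetical
+# and only shows how `layer_patterns` narrows the scope of the freeze:
+#
+#     freeze_base_weights(model)                                     # freeze every base weight
+#     freeze_base_weights(model, layer_patterns="*self_attention*")  # freeze only matching layers
+
+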
+def unfreeze_base_weights(model, *, layer_patterns=None):
+    """Unfreeze base model weights to allow gradient updates during training.
+
+    This function sets requires_grad=True for all base model parameters in LoRA modules.
+    Useful when you want to fine-tune both base model and LoRA adapter weights together.
+
+    Args:
+        model: Model containing LoRA modules whose base weights should be unfrozen
+        layer_patterns: Optional patterns (str, callable, or an iterable of these) to match
+            specific layer names. If provided, only layers matching these patterns are
+            affected. Strings support Unix-style wildcards (e.g., "*.linear", "transformer.*")
+    """
+    _set_base_requires_grad(model, requires_grad=True, layer_patterns=layer_patterns)
+
+
+def freeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
+    """Freeze LoRA adapter weights to prevent gradient updates during training.
+
+    This function sets requires_grad=False for LoRA adapter parameters (lora_a and lora_b).
+    Useful when you want to train only the base model weights or evaluate the model
+    without updating LoRA adapters.
+
+    Args:
+        model: Model containing LoRA modules whose adapter weights should be frozen
+        layer_patterns: Optional patterns (str, callable, or an iterable of these) to match
+            specific layer names. If provided, only layers matching these patterns are
+            affected. Strings support Unix-style wildcards (e.g., "*.linear", "transformer.*")
+        adapter_patterns: Optional patterns (str or Iterable) to match specific adapter
+            names. If provided, only adapters matching these patterns are affected.
+            Supports Unix-style wildcards
+    """
+    _set_lora_requires_grad(
+        model,
+        requires_grad=False,
+        layer_patterns=layer_patterns,
+        adapter_patterns=adapter_patterns,
+    )
+
+
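+# A minimal sketch of scoping `freeze_lora_weights`; the adapter name "my_adapter" and the
+# layer wildcards are hypothetical and only show how the two pattern arguments combine:
+#
+#     freeze_lora_weights(model, adapter_patterns="my_adapter")                 # one adapter, all layers
+#     freeze_lora_weights(model, layer_patterns=["*linear_fc1", "*linear_fc2"]) # all adapters, matching layers
+
+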
+def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
+    """Unfreeze LoRA adapter weights to allow gradient updates during training.
+
+    This function sets requires_grad=True for LoRA adapter parameters (lora_a and lora_b).
+    This is the typical setting for LoRA fine-tuning where adapter weights are trained.
+
+    Args:
+        model: Model containing LoRA modules whose adapter weights should be unfrozen
+        layer_patterns: Optional patterns (str, callable, or an iterable of these) to match
+            specific layer names. If provided, only layers matching these patterns are
+            affected. Strings support Unix-style wildcards (e.g., "*.linear", "transformer.*")
+        adapter_patterns: Optional patterns (str or Iterable) to match specific adapter
+            names. If provided, only adapters matching these patterns are affected.
+            Supports Unix-style wildcards
+    """
+    _set_lora_requires_grad(
+        model,
+        requires_grad=True,
+        layer_patterns=layer_patterns,
+        adapter_patterns=adapter_patterns,
+    )
+
+
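+# A minimal sketch of the typical LoRA fine-tuning setup built from these helpers: freeze
+# all base weights and keep only the adapter weights trainable. The trainable-parameter
+# count is just a common sanity check, not part of this module:
+#
+#     freeze_base_weights(model)
+#     unfreeze_lora_weights(model)
+#     num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+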
+def update_grads(model, config: PEFTConfig):
+    """Update gradient computation settings based on PEFTConfig.
+
+    This function configures which model parameters should have gradients computed
+    based on the freeze settings in the PEFTConfig. It's typically called during
+    model initialization or when switching training configurations.
+
+    Args:
+        model: Model containing LoRA modules to configure
+        config: PEFTConfig instance with freeze_base_model and freeze_lora_weights settings
+            - If config.freeze_base_model is True, base weights will have requires_grad=False
+            - If config.freeze_lora_weights is True, LoRA weights will have requires_grad=False
+    """
+    _set_base_requires_grad(model, requires_grad=not config.freeze_base_model)
+    _set_lora_requires_grad(model, requires_grad=not config.freeze_lora_weights)
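+
+
+# A minimal sketch of `update_grads` driven by the config flags. The PEFTConfig
+# construction is elided here because its full argument set is not shown in this file;
+# only the two freeze flags referenced above matter:
+#
+#     config = PEFTConfig(...)  # e.g. with freeze_base_model=True, freeze_lora_weights=False
+#     update_grads(model, config)
+#     # Base weights inside LoRA modules now have requires_grad=False, while the
+#     # lora_a / lora_b adapter parameters keep requires_grad=True.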