
Commit ac164c0

Update the grad for loras
Signed-off-by: Jingyu Xin <[email protected]>
1 parent: ea167ec · commit: ac164c0

File tree: 2 files changed (+192 / -12 lines)


modelopt/torch/peft/config.py

Lines changed: 14 additions & 0 deletions
@@ -145,6 +145,20 @@ class PEFTConfig(ModeloptBaseConfig):
         validate_default=True,
     )

+    freeze_base_model: bool = ModeloptField(
+        default=True,
+        title="Freeze base weights during training",
+        description="Whether to freeze the base model weights; in most cases, this should be set to True.",
+        validate_default=True,
+    )
+
+    freeze_lora_weights: bool = ModeloptField(
+        default=True,
+        title="Placeholder",
+        description="Placeholder",
+        validate_default=True,
+    )
+
     @field_validator("adapter_type")
     @classmethod
     def validate_adapter_type(cls, v):
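
Note that the title and description of freeze_lora_weights are still literal "Placeholder" strings in this commit; by symmetry with freeze_base_model they would describe whether the LoRA adapter weights are frozen. Both fields are read by update_grads() in conversion.py below, each mapping to the inverse requires_grad value, and both default to True, so a training setup has to set freeze_lora_weights=False explicitly to keep the adapters trainable. A minimal usage sketch, not part of the commit; any other fields the full PEFTConfig requires are omitted:

# Hypothetical sketch of how the two new flags are consumed (see update_grads()
# in conversion.py below). Other PEFTConfig fields are omitted for brevity and
# may be required by the real class.
from modelopt.torch.peft.config import PEFTConfig

config = PEFTConfig(
    freeze_base_model=True,    # update_grads(): base weights  -> requires_grad=False
    freeze_lora_weights=False, # update_grads(): LoRA adapters -> requires_grad=True
)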

modelopt/torch/peft/conversion.py

Lines changed: 178 additions & 12 deletions
@@ -16,6 +16,7 @@
 """PEFT conversion and restore utilities for LoRA modules."""

 import fnmatch
+from collections.abc import Callable, Iterable
 from typing import Any

 import torch.nn as nn
@@ -28,7 +29,11 @@
 from .lora.layer import LoRAModule, LoRAModuleRegistry

 __all__ = [
+    "freeze_base_weights",
+    "freeze_lora_weights",
     "replace_lora_module",
+    "unfreeze_base_weights",
+    "unfreeze_lora_weights",
     "update_peft_metadata_in_model",
 ]

@@ -42,6 +47,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert

     metadata = {}
     add_adapter(model, config)
+    update_grads(model, config)
     update_peft_metadata(model, config, metadata)

     return model, metadata
@@ -175,17 +181,177 @@ def add_adapter(model, config: PEFTConfig):
     for name, module in model.named_modules():
         if isinstance(module, LoRAModule):
             for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
-                if isinstance(wildcard_or_filter_func, str):
-                    if not fnmatch.fnmatch(name, wildcard_or_filter_func):
-                        continue
-                elif callable(wildcard_or_filter_func):
-                    if not wildcard_or_filter_func(name):
-                        continue
-                else:
-                    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
-                module.update_layer_lora(
-                    adapter_name,
-                    adapter_setting,
-                )
+                if _matches(name, wildcard_or_filter_func):
+                    module.update_layer_lora(
+                        adapter_name,
+                        adapter_setting,
+                    )

     return model
+
+
+def _matches(
+    name: str,
+    patterns: str | Callable[[str], bool] | Iterable[str | Callable[[str], bool]] | None,
+    *,
+    allow_callable: bool = True,
+) -> bool:
+    if patterns is None:
+        return True
+
+    if isinstance(patterns, (str, bytes)):
+        patterns_iter: Iterable[str | Callable[[str], bool]] = (patterns,)
+    elif callable(patterns):
+        if not allow_callable:
+            raise TypeError("Callable patterns are not supported in this context.")
+        patterns_iter = (patterns,)
+    elif isinstance(patterns, Iterable):
+        patterns_iter = tuple(patterns)
+    else:
+        raise TypeError(f"Unsupported pattern type: {type(patterns)}")
+
+    for pattern in patterns_iter:
+        if isinstance(pattern, (str, bytes)):
+            if fnmatch.fnmatch(name, pattern):
+                return True
+        elif callable(pattern):
+            if not allow_callable:
+                raise TypeError("Callable patterns are not supported in this context.")
+            if pattern(name):
+                return True
+        else:
+            raise TypeError(f"Unsupported pattern type: {type(pattern)}")
+
+    return False
+
+
+def _iter_lora_modules(model, layer_patterns=None):
+    for module_name, module in model.named_modules():
+        if isinstance(module, LoRAModule) and _matches(module_name, layer_patterns):
+            yield module_name, module
+
+
+def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
+    for _, module in _iter_lora_modules(model, layer_patterns):
+        lora_param_ids = {
+            id(param)
+            for adapter in module._lora_adapters.values()
+            for submodule in ("lora_a", "lora_b")
+            for _, param in adapter[submodule].named_parameters()
+        }
+        for _, param in module.named_parameters():
+            if id(param) in lora_param_ids:
+                continue
+            param.requires_grad = requires_grad
+
+
+def _iter_adapter_names(module, adapter_patterns=None):
+    for adapter_name in module._lora_adapters:
+        if _matches(adapter_name, adapter_patterns, allow_callable=False):
+            yield adapter_name
+
+
+def _set_lora_requires_grad(
+    model, *, requires_grad: bool, layer_patterns=None, adapter_patterns=None
+):
+    for _, module in _iter_lora_modules(model, layer_patterns):
+        for adapter_name in _iter_adapter_names(module, adapter_patterns):
+            adapter = module._lora_adapters[adapter_name]
+            for submodule in (adapter["lora_a"], adapter["lora_b"]):
+                for _, param in submodule.named_parameters():
+                    param.requires_grad = requires_grad
+
+
+def freeze_base_weights(model, *, layer_patterns=None):
+    """Freeze base model weights to prevent gradient updates during training.
+
+    This function sets requires_grad=False for all base model parameters in LoRA modules,
+    while keeping LoRA adapter parameters trainable. Useful for LoRA fine-tuning where
+    only adapter weights should be updated.
+
+    Args:
+        model: Model containing LoRA modules whose base weights should be frozen
+        layer_patterns: Optional patterns (str, bytes, or Iterable) to match specific
+            layer names. If provided, only layers matching these patterns will be affected.
+            Supports Unix-style wildcards (e.g., "*.linear", "transformer.*")
+    """
+    _set_base_requires_grad(model, requires_grad=False, layer_patterns=layer_patterns)
+
+
+def unfreeze_base_weights(model, *, layer_patterns=None):
+    """Unfreeze base model weights to allow gradient updates during training.
+
+    This function sets requires_grad=True for all base model parameters in LoRA modules.
+    Useful when you want to fine-tune both base model and LoRA adapter weights together.
+
+    Args:
+        model: Model containing LoRA modules whose base weights should be unfrozen
+        layer_patterns: Optional patterns (str, bytes, or Iterable) to match specific
+            layer names. If provided, only layers matching these patterns will be affected.
+            Supports Unix-style wildcards (e.g., "*.linear", "transformer.*")
+    """
+    _set_base_requires_grad(model, requires_grad=True, layer_patterns=layer_patterns)
+
+
+def freeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
+    """Freeze LoRA adapter weights to prevent gradient updates during training.
+
+    This function sets requires_grad=False for LoRA adapter parameters (lora_a and lora_b).
+    Useful when you want to train only the base model weights or evaluate the model
+    without updating LoRA adapters.
+
+    Args:
+        model: Model containing LoRA modules whose adapter weights should be frozen
+        layer_patterns: Optional patterns (str, bytes, or Iterable) to match specific
+            layer names. If provided, only layers matching these patterns will be affected.
+            Supports Unix-style wildcards (e.g., "*.linear", "transformer.*")
+        adapter_patterns: Optional patterns (str or Iterable) to match specific adapter
+            names. If provided, only adapters matching these patterns will be affected.
+            Supports Unix-style wildcards
+    """
+    _set_lora_requires_grad(
+        model,
+        requires_grad=False,
+        layer_patterns=layer_patterns,
+        adapter_patterns=adapter_patterns,
+    )
+
+
+def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
+    """Unfreeze LoRA adapter weights to allow gradient updates during training.
+
+    This function sets requires_grad=True for LoRA adapter parameters (lora_a and lora_b).
+    This is the typical setting for LoRA fine-tuning where adapter weights are trained.
+
+    Args:
+        model: Model containing LoRA modules whose adapter weights should be unfrozen
+        layer_patterns: Optional patterns (str, bytes, or Iterable) to match specific
+            layer names. If provided, only layers matching these patterns will be affected.
+            Supports Unix-style wildcards (e.g., "*.linear", "transformer.*")
+        adapter_patterns: Optional patterns (str or Iterable) to match specific adapter
+            names. If provided, only adapters matching these patterns will be affected.
+            Supports Unix-style wildcards
+    """
+    _set_lora_requires_grad(
+        model,
+        requires_grad=True,
+        layer_patterns=layer_patterns,
+        adapter_patterns=adapter_patterns,
+    )
+
+
+def update_grads(model, config: PEFTConfig):
+    """Update gradient computation settings based on PEFTConfig.
+
+    This function configures which model parameters should have gradients computed
+    based on the freeze settings in the PEFTConfig. It's typically called during
+    model initialization or when switching training configurations.
+
+    Args:
+        model: Model containing LoRA modules to configure
+        config: PEFTConfig instance with freeze_base_model and freeze_lora_weights settings
+            - If config.freeze_base_model is True, base weights will have requires_grad=False
+            - If config.freeze_lora_weights is True, LoRA weights will have requires_grad=False
+    """
+    _set_base_requires_grad(model, requires_grad=not config.freeze_base_model)
+    _set_lora_requires_grad(model, requires_grad=not config.freeze_lora_weights)
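
For reference, a short usage sketch of the new public helpers, not part of the commit. It assumes `model` already contains LoRAModule layers (e.g. after add_adapter), and the pattern strings and the adapter name "default" below are purely illustrative. Per _matches, layer_patterns accepts glob strings, callables, an iterable of either, or None (match everything), while adapter_patterns rejects callables (allow_callable=False):

# Hypothetical usage sketch of the helpers added in this commit; `model` and the
# pattern / adapter names are placeholders.
from modelopt.torch.peft.conversion import (
    freeze_base_weights,
    freeze_lora_weights,
    unfreeze_lora_weights,
)

# Typical LoRA fine-tuning: freeze every base weight, train the adapters.
freeze_base_weights(model)
unfreeze_lora_weights(model)

# Pattern-scoped variants: glob strings match module / adapter names;
# callables are accepted for layer_patterns only.
freeze_lora_weights(model, layer_patterns="*.output_layer", adapter_patterns="default")
freeze_base_weights(model, layer_patterns=lambda name: "attention" in name)

update_grads() composes the same two primitives from the PEFTConfig flags, so convert_to_peft_model() now leaves the parameters in the trainability state the config requests.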
