
Commit d6b2e60

Add disable_adapters/enable_adapters support, remove unused code
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 935b07b commit d6b2e60

File tree

2 files changed: +151 -23 lines changed

modelopt/torch/peft/conversion.py
modelopt/torch/peft/convert.py

modelopt/torch/peft/conversion.py

Lines changed: 0 additions & 23 deletions

@@ -191,26 +191,3 @@ def add_adapter(model, config: PEFTConfig):
     )

     return model
-
-
-def _update_peft_metadata_in_state(model: nn.Module) -> None:
-    """Update the PEFT metadata in the ModeloptStateManager.
-
-    This function updates the metadata to reflect the current state of LoRA adapters
-    after they have been added or modified.
-    """
-    if not ModeloptStateManager.is_converted(model):
-        return
-
-    manager = ModeloptStateManager(model)
-
-    current_peft_state = {}
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            from modelopt.torch.utils import get_unwrapped_name
-
-            unwrapped_name = get_unwrapped_name(name)
-            current_peft_state[unwrapped_name] = module.get_peft_state()
-
-    if manager._state and manager._last_metadata is not None:
-        manager._last_metadata["peft_state"] = current_peft_state
modelopt/torch/peft/convert.py

Lines changed: 151 additions & 0 deletions
@@ -15,6 +15,7 @@

 """User-facing PEFT API for LoRA module conversion and adapter management."""

+import fnmatch
 from typing import Any

 import torch.nn as nn
@@ -26,6 +27,14 @@
 from .lora.layer import LoRAModule
 from .mode import PEFTModeRegistry

+__all__ = [
+    "disable_adapters",
+    "enable_adapters",
+    "get_adapter_states",
+    "is_peft_model",
+    "update_model",
+]
+

 def update_model(
     model: nn.Module,
@@ -67,3 +76,145 @@ def is_peft_model(model: nn.Module) -> bool:
         True if the model contains LoRA modules, False otherwise
     """
     return any(isinstance(module, LoRAModule) for _, module in model.named_modules())
+
+
+def _set_adapter_state(model, enable_state, layer_patterns=None, adapter_patterns=None):
+    """Helper function to set adapter states.
+
+    Args:
+        model: Model with LoRA adapters
+        enable_state: Boolean state to set for matching adapters
+        layer_patterns: Optional list of layer name patterns (wildcards or callables)
+        adapter_patterns: Optional list of adapter name patterns (wildcards)
+    """
+    assert is_peft_model(model), "It's not a MO-PEFT model"
+
+    def matches_any_pattern(name, patterns, allow_callable=True):
+        for pattern in patterns:
+            if isinstance(pattern, str):
+                if fnmatch.fnmatch(name, pattern):
+                    return True
+            elif allow_callable and callable(pattern):
+                if pattern(name):
+                    return True
+            else:
+                pattern_type = "pattern" if allow_callable else "adapter pattern"
+                raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}")
+        return False
+
+    for module_name, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            if layer_patterns is not None:
+                if not matches_any_pattern(module_name, layer_patterns, allow_callable=True):
+                    continue
+
+            for adapter_name, adapter_dict in module._lora_adapters.items():
+                if adapter_patterns is not None:
+                    if not matches_any_pattern(
+                        adapter_name, adapter_patterns, allow_callable=False
+                    ):
+                        continue
+
+                adapter_dict["enable"] = enable_state
+
+
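The helper above accepts fnmatch-style wildcard strings for both layer and adapter names, plus arbitrary callables for layer names only (allow_callable=False rejects them for adapters). A minimal, self-contained sketch of those matching semantics; the module names used here are illustrative, not from this commit:

import fnmatch

def matches_any_pattern(name, patterns, allow_callable=True):
    """Return True if name matches any wildcard string (or callable, when allowed)."""
    for pattern in patterns:
        if isinstance(pattern, str):
            if fnmatch.fnmatch(name, pattern):  # shell-style wildcards: *, ?, [...]
                return True
        elif allow_callable and callable(pattern):
            if pattern(name):  # predicate over the dotted module name
                return True
        else:
            raise TypeError(f"Unsupported pattern type: {type(pattern)}")
    return False

# Wildcards match anywhere in the dotted module path.
assert matches_any_pattern("transformer.layers.0.attention", ["*attention*"])
assert not matches_any_pattern("transformer.layers.0.mlp", ["*attention*"])
# Callables pass for layer patterns only; with allow_callable=False,
# anything but a string raises TypeError.
assert matches_any_pattern("transformer.layers.0.mlp", [lambda n: n.endswith("mlp")])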
+def disable_adapters(model, layers_to_disable=None, adapters_to_disable=None):
+    """Disable LoRA adapters in the model.
+
+    Args:
+        model: Model with LoRA adapters
+        layers_to_disable: Optional list of layer name patterns (wildcards or callables)
+            to disable adapters on. If None, disables on all layers.
+        adapters_to_disable: Optional list of adapter name patterns (wildcards) to disable.
+            If None, disables all adapters.
+
+    Examples:
+        # Disable all adapters
+        disable_adapters(model)
+
+        # Disable adapters only on attention layers
+        disable_adapters(model, layers_to_disable=["*attention*"])
+
+        # Disable only "default" adapters
+        disable_adapters(model, adapters_to_disable=["*default*"])
+
+        # Disable "default" adapters on attention layers only
+        disable_adapters(model, layers_to_disable=["*attention*"], adapters_to_disable=["*default*"])
+    """
+    _set_adapter_state(
+        model,
+        enable_state=False,
+        layer_patterns=layers_to_disable,
+        adapter_patterns=adapters_to_disable,
+    )
+
+
+def enable_adapters(model, layers_to_enable=None, adapters_to_enable=None):
+    """Enable LoRA adapters in the model.
+
+    Args:
+        model: Model with LoRA adapters
+        layers_to_enable: Optional list of layer name patterns (wildcards or callables)
+            to enable adapters on. If None, enables on all layers.
+        adapters_to_enable: Optional list of adapter name patterns (wildcards) to enable.
+            If None, enables all adapters.
+
+    Examples:
+        # Enable all adapters
+        enable_adapters(model)
+
+        # Enable adapters only on MLP layers
+        enable_adapters(model, layers_to_enable=["*mlp*"])
+
+        # Enable only "finetuned" adapters
+        enable_adapters(model, adapters_to_enable=["*finetuned*"])
+
+        # Enable "finetuned" adapters on MLP layers only
+        enable_adapters(model, layers_to_enable=["*mlp*"], adapters_to_enable=["*finetuned*"])
+    """
+    _set_adapter_state(
+        model,
+        enable_state=True,
+        layer_patterns=layers_to_enable,
+        adapter_patterns=adapters_to_enable,
+    )
+
+
+def get_adapter_states(model):
+    """Get the current state of all adapters in the model.
+
+    Args:
+        model: Model with LoRA adapters
+
+    Returns:
+        Dict mapping module names to their adapter states
+
+    Example:
+        >>> states = get_adapter_states(model)
+        >>> print(states)
+        {
+            'transformer.layers.0.attention': {
+                'default': {'enabled': True, 'rank': 32},
+                'finetuned': {'enabled': False, 'rank': 64}
+            },
+            'transformer.layers.0.mlp': {
+                'default': {'enabled': True, 'rank': 32}
+            }
+        }
+    """
+    assert is_peft_model(model), "It's not a MO-PEFT model"
+
+    adapter_states = {}
+    for module_name, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            module_adapters = {}
+            for adapter_name, adapter_dict in module._lora_adapters.items():
+                module_adapters[adapter_name] = {
+                    "enabled": adapter_dict.get("enable", True),
+                    "rank": adapter_dict.get("rank", "unknown"),
+                    "scale": adapter_dict.get("scale", 1.0),
+                }
+            if module_adapters:
+                adapter_states[module_name] = module_adapters
+
+    return adapter_states
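Taken together, the three public functions are thin wrappers over each LoRAModule's _lora_adapters dict. A hedged usage sketch, assuming model has already been converted via update_model and carries an adapter named "default" (that adapter name and the attention-layer naming are assumptions for illustration, not part of this commit):

from modelopt.torch.peft.convert import (
    disable_adapters,
    enable_adapters,
    get_adapter_states,
    is_peft_model,
)

assert is_peft_model(model)  # model is assumed to already contain LoRAModule layers

# Turn every adapter off, then re-enable only "default" adapters on attention layers.
disable_adapters(model)
enable_adapters(
    model,
    layers_to_enable=["*attention*"],  # a callable such as lambda n: "attn" in n also works
    adapters_to_enable=["*default*"],
)

# Inspect the resulting per-module flags.
for module_name, adapters in get_adapter_states(model).items():
    for adapter_name, state in adapters.items():
        print(f"{module_name}:{adapter_name} enabled={state['enabled']}")

Because both toggles only flip the "enable" flag stored in each adapter dict, switching adapters on and off is cheap and reversible; no LoRA weights are moved or deleted.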
