
Commit ca9698d

Update more, config + conversion
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 9b006f9 commit ca9698d

5 files changed: +219 -111 lines changed

modelopt/torch/peft/config.py

Lines changed: 106 additions & 4 deletions
@@ -15,30 +15,132 @@
 
 """Configuration classes for PEFT methods."""
 
+import math
+from collections.abc import Callable
+from collections.abc import Callable as CallableType
+
+import torch.nn.init as init
+from pydantic import field_validator, model_validator
+
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 
+__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]
+
+
+class PEFTAttributeConfig(ModeloptBaseConfig):
+    """Configuration for PEFT adapter attributes."""
+
+    enable: bool = ModeloptField(
+        default=True,
+        title="Enable adapter",
+        description="If True, enables the adapter. If False, bypasses the adapter.",
+    )
+
+    rank: int = ModeloptField(
+        default=64,
+        title="LoRA rank",
+        description=(
+            "The rank (dimension) of the LoRA matrices. "
+            "Higher rank allows more expressiveness but uses more memory."
+        ),
+    )
+
+    scale: float = ModeloptField(
+        default=1.0,
+        title="LoRA scaling factor",
+        description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
+    )
+
+    lora_a_init: Callable[[object], None] | None = ModeloptField(
+        default=lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)),
+        title="LoRA A matrix initializer",
+        description="Custom initialization function for LoRA A matrix. Defaults to Kaiming uniform initialization.",
+    )
+
+    lora_b_init: Callable[[object], None] | None = ModeloptField(
+        default=lambda weight: init.zeros_(weight),
+        title="LoRA B matrix initializer",
+        description="Custom initialization function for LoRA B matrix. Defaults to zero initialization.",
+    )
+
+    @field_validator("rank")
+    @classmethod
+    def validate_rank(cls, v):
+        """Validate rank is positive."""
+        if v < 1:
+            raise ValueError("rank must be a positive integer")
+        return v
+
+    @field_validator("scale")
+    @classmethod
+    def validate_scale(cls, v):
+        """Validate scale is positive."""
+        if v <= 0:
+            raise ValueError("scale must be a positive number")
+        return v
+
+    @model_validator(mode="after")
+    def validate_init_functions(self):
+        """Validate initialization functions are callable."""
+        if self.lora_a_init is not None and not callable(self.lora_a_init):
+            raise ValueError("lora_a_init must be callable")
+        if self.lora_b_init is not None and not callable(self.lora_b_init):
+            raise ValueError("lora_b_init must be callable")
+        return self
+
+
+# Type alias for adapter configuration
+PEFTAdapterCfgType = dict[str | CallableType, PEFTAttributeConfig | dict]
+
 
 class PEFTConfig(ModeloptBaseConfig):
     """Default configuration for ``peft`` mode."""
 
     adapter_name: str = ModeloptField(
         default="default",
-        title="Placeholder",
+        title="Adapter name",
+        description="Name of the adapter to create or update.",
         validate_default=True,
     )
 
-    adapter_cfg: dict = ModeloptField(
+    adapter_cfg: PEFTAdapterCfgType = ModeloptField(
         default={"default": {"rank": 128}},
-        title="Placeholder",
+        title="Adapter configuration",
+        description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.",
         validate_default=True,
     )
 
     adapter_type: str = ModeloptField(
         default="lora",
-        title="Placeholder",
+        title="Adapter type",
+        description="Type of PEFT adapter to use. Currently only 'lora' is supported.",
         validate_default=True,
     )
 
+    @field_validator("adapter_type")
+    @classmethod
+    def validate_adapter_type(cls, v):
+        """Validate adapter type."""
+        if v not in ["lora"]:
+            raise ValueError(f"Unsupported adapter type: {v}. Only 'lora' is currently supported.")
+        return v
+
+    @field_validator("adapter_cfg")
+    @classmethod
+    def validate_adapter_cfg(cls, v):
+        """Validate and convert adapter configurations."""
+        validated_cfg = {}
+        for key, value in v.items():
+            if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
+                # Convert dict to PEFTAttributeConfig to trigger validation
+                try:
+                    validated_cfg[key] = PEFTAttributeConfig(**value)
+                except Exception as e:
+                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+            else:
+                validated_cfg[key] = value
+        return validated_cfg
+
 
 class ExportPEFTConfig(ModeloptBaseConfig):
     """An empty config."""

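A quick sanity check on what the new config classes buy: adapter settings become validated objects instead of bare dicts. This is an illustrative sketch, not part of the commit; it assumes the import path shown in this diff, and the adapter name and module patterns are made up.

from modelopt.torch.peft.config import PEFTAttributeConfig, PEFTConfig

# Plain dicts in adapter_cfg are coerced into PEFTAttributeConfig by validate_adapter_cfg,
# so bad values surface when the config is built rather than deep inside conversion.
cfg = PEFTConfig(
    adapter_name="demo_adapter",                      # hypothetical name
    adapter_type="lora",
    adapter_cfg={
        "*self_attention*": {"rank": 32, "scale": 0.5},   # plain dict, gets coerced
        "*mlp*": PEFTAttributeConfig(rank=16),            # already a validated object
    },
)
print(type(cfg.adapter_cfg["*self_attention*"]))  # expected: PEFTAttributeConfig

# Out-of-range values are rejected by the field validators.
try:
    PEFTAttributeConfig(rank=0)
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print("rejected:", err)
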
modelopt/torch/peft/conversion.py

Lines changed: 59 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 """PEFT conversion and restore utilities for LoRA modules."""
 
+import fnmatch
 from typing import Any
 
 import torch.nn as nn
@@ -41,6 +42,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
 
     metadata = {}
+    add_adapter(model, config)
     # Should return adapaters, active_adapters
     update_peft_metadata(model, config, metadata)
 
@@ -157,3 +159,60 @@ def update_peft_metadata_in_model(model: nn.Module) -> None:
     # Update the metadata with current PEFT state
     if manager._state and manager._last_metadata is not None:
         manager._last_metadata["peft_state"] = peft_state(model)
+
+
+def add_adapter(model, config: PEFTConfig):
+    """Add a new LoRA adapter to the model.
+
+    Args:
+        model: Model with LoRA modules to add adapters to
+        config: PEFTConfig instance containing adapter_cfg and adapter_name
+
+    Returns:
+        The model with the new adapter added
+    """
+    adapter_cfg = config.adapter_cfg
+    adapter_name = config.adapter_name
+
+    for name, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
+                if isinstance(wildcard_or_filter_func, str):
+                    if not fnmatch.fnmatch(name, wildcard_or_filter_func):
+                        continue
+                elif callable(wildcard_or_filter_func):
+                    if not wildcard_or_filter_func(name):
+                        continue
+                else:
+                    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
+                if adapter_setting.enable:  # type: ignore[union-attr]
+                    module.update_layer_lora(
+                        adapter_name,
+                        adapter_setting,
+                    )
+
+    _update_peft_metadata_in_state(model)
+    return model
+
+
+def _update_peft_metadata_in_state(model: nn.Module) -> None:
+    """Update the PEFT metadata in the ModeloptStateManager.
+
+    This function updates the metadata to reflect the current state of LoRA adapters
+    after they have been added or modified.
+    """
+    if not ModeloptStateManager.is_converted(model):
+        return
+
+    manager = ModeloptStateManager(model)
+
+    current_peft_state = {}
+    for name, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            from modelopt.torch.utils import get_unwrapped_name
+
+            unwrapped_name = get_unwrapped_name(name)
+            current_peft_state[unwrapped_name] = module.get_peft_state()
+
+    if manager._state and manager._last_metadata is not None:
+        manager._last_metadata["peft_state"] = current_peft_state

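The new add_adapter walks model.named_modules() and matches each LoRAModule name against the adapter_cfg keys, which may be fnmatch-style wildcards or callables. The standalone sketch below shows only that matching rule; the module names are invented for illustration.

import fnmatch

adapter_cfg = {
    "*self_attention*": {"rank": 32},                          # wildcard pattern
    (lambda name: name.endswith(".mlp.fc1")): {"rank": 8},     # callable filter
}

def matches(name, key):
    if isinstance(key, str):
        return fnmatch.fnmatch(name, key)
    if callable(key):
        return bool(key(name))
    raise NotImplementedError(f"Unsupported type {type(key)}")

for name in ["decoder.layers.0.self_attention.linear_qkv", "decoder.layers.0.mlp.fc1"]:
    hits = [setting for key, setting in adapter_cfg.items() if matches(name, key)]
    print(name, "->", hits)
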
modelopt/torch/peft/convert.py

Lines changed: 9 additions & 64 deletions
@@ -15,22 +15,21 @@
 
 """User-facing PEFT API for LoRA module conversion and adapter management."""
 
-import fnmatch
 from typing import Any
 
 import torch.nn as nn
 
 from modelopt.torch.opt import apply_mode
-from modelopt.torch.opt.conversion import ModeloptStateManager
 from modelopt.torch.peft.config import PEFTConfig
+from modelopt.torch.peft.conversion import add_adapter
 
 from .lora.layer import LoRAModule
 from .mode import PEFTModeRegistry
 
 
 def update_model(
     model: nn.Module,
-    config: dict[str, Any | PEFTConfig],
+    config: dict[str, Any] | PEFTConfig,
 ):
     """Update model with PEFT/LoRA adapters.
 
@@ -40,78 +39,24 @@ def update_model(
 
     Args:
         model: The model to update
-        config: PEFT configuration containing adapter settings
+        config: PEFT configuration dict or PEFTConfig instance
 
     Returns:
         The updated model with LoRA adapters
     """
+    # Validate config by converting to PEFTConfig if needed
+
     # Check if model is already in PEFT mode by looking for LoRA modules
     if not is_peft_model(model):
         # First time - need to convert to PEFT mode
         apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
-    return add_adapter(model, config)
-
-
-def add_adapter(model, config):
-    """Add a new LoRA adapter to the model.
-
-    Args:
-        model: Model with LoRA modules to add adapters to
-        config: Configuration dict containing adapter_cfg and adapter_name
-
-    Returns:
-        The model with the new adapter added
-    """
-    adapter_cfg = config["adapter_cfg"]
-    adapter_name = config["adapter_name"]
-
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
-                if isinstance(wildcard_or_filter_func, str):
-                    if not fnmatch.fnmatch(name, wildcard_or_filter_func):
-                        continue
-                elif callable(wildcard_or_filter_func):
-                    if not wildcard_or_filter_func(name):
-                        continue
-                else:
-                    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
-                module.update_layer_lora(
-                    adapter_name, adapter_setting["rank"], adapter_setting.get("scale", 1.0)
-                )
-
-    # Update the metadata in ModeloptStateManager after adding adapters
-    _update_peft_metadata_in_state(model)
+    else:
+        if not isinstance(config, PEFTConfig):
+            config = PEFTConfig(**config)
+        add_adapter(model, config)
     return model
 
 
-def _update_peft_metadata_in_state(model: nn.Module) -> None:
-    """Update the PEFT metadata in the ModeloptStateManager.
-
-    This function updates the metadata to reflect the current state of LoRA adapters
-    after they have been added or modified.
-    """
-    # Check if model has ModeloptStateManager (has been converted with peft mode)
-    if not ModeloptStateManager.is_converted(model):
-        return
-
-    # Get the state manager
-    manager = ModeloptStateManager(model)
-
-    # Get current PEFT state from all LoRA modules
-    current_peft_state = {}
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            from modelopt.torch.utils import get_unwrapped_name
-
-            unwrapped_name = get_unwrapped_name(name)
-            current_peft_state[unwrapped_name] = module.get_peft_state()
-
-    # Update the metadata in the last mode state (which should be 'peft')
-    if manager._state and manager._last_metadata is not None:
-        manager._last_metadata["peft_state"] = current_peft_state
-
-
 def is_peft_model(model: nn.Module) -> bool:
     """Check if the model has been converted to PEFT/LoRA model.
 

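With this change update_model() converts the model into PEFT mode on first use and, on later calls, validates the dict into a PEFTConfig and delegates to conversion.add_adapter. A hedged usage sketch follows; the import path is inferred from the file layout in this commit, the layer choice is arbitrary, and which submodules actually receive adapters depends on what is registered in LoRAModuleRegistry.

import torch.nn as nn

from modelopt.torch.peft.convert import update_model  # path inferred from this commit

# Any nn.Module works as a stand-in here; supported layer types are determined
# by LoRAModuleRegistry, so treat this as a sketch rather than a recipe.
model = nn.TransformerEncoderLayer(d_model=64, nhead=4)

lora_cfg = {
    "adapter_name": "finetune_v1",
    "adapter_type": "lora",
    "adapter_cfg": {"*linear*": {"rank": 64, "scale": 1.0}},
}

# First call: enters "peft" mode (LoRA module replacement) and adds "finetune_v1".
model = update_model(model, lora_cfg)

# Later call: the model is already a PEFT model, so the dict is coerced to PEFTConfig
# and add_adapter() attaches a second adapter in place.
model = update_model(model, {**lora_cfg, "adapter_name": "finetune_v2"})
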
modelopt/torch/peft/lora/layer.py

Lines changed: 16 additions & 15 deletions
@@ -9,6 +9,8 @@
 
 from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
 
+from ..config import PEFTAttributeConfig
+
 __all__ = [
     "LoRAModule",
     "LoRAModuleRegistry",
@@ -100,7 +102,11 @@ def _register_adapter(
         self.activate_adapter(adapter_name)
 
     @abstractmethod
-    def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.0) -> None:
+    def update_layer_lora(
+        self,
+        adapter_name: str,
+        attr_config: PEFTAttributeConfig,
+    ) -> None:
         """Create and register a new LoRA adapter.
 
         This method must be implemented by subclasses to create the appropriate
@@ -110,6 +116,8 @@ def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.
            adapter_name: Name for the new adapter
            rank: Rank of the LoRA decomposition (default: 64)
            scale: Scale factor for the LoRA output (default: 1.0)
+           lora_a_init: Optional initialization function for LoRA A matrix
+           lora_b_init: Optional initialization function for LoRA B matrix
        """
        raise NotImplementedError("Subclasses must implement update_layer_lora")
 
@@ -189,24 +197,17 @@ def set_from_peft_state(self, peft_state: dict[str, Any]) -> None:
        """
        adapters_config = peft_state.get("adapters", {})
 
-        # Clear existing adapters first
        self._lora_adapters.clear()
        self._active_adapters.clear()
 
-        # Recreate each adapter based on saved configuration
        for adapter_name, config in adapters_config.items():
-            rank = config.get("rank")
-            scale = config.get("scale", 1.0)
-
-            if rank is not None:
-                # Create the adapter with saved configuration
-                self.update_layer_lora(adapter_name, rank=rank, scale=scale)
+            self.update_layer_lora(adapter_name, config)
 
-                # Set activation state
-                if config.get("is_active", False):
-                    self.activate_adapter(adapter_name)
-                else:
-                    self.deactivate_adapter(adapter_name)
+            # Set activation state
+            if config.get("is_active", False):
+                self.activate_adapter(adapter_name)
+            else:
+                self.deactivate_adapter(adapter_name)
 
    def set_extra_state(self, state: dict[str, Any]) -> None:
        """Restore extra state for distributed checkpointing.
@@ -281,7 +282,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
 
        # Return output in the same format as the base layer
        if other_outputs:
-            return (result,) + other_outputs
+            return (result, *other_outputs)
        else:
            return result
 

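The abstract update_layer_lora() now receives one PEFTAttributeConfig instead of separate rank/scale arguments, which also lets subclasses honor the configured initializers. Below is a self-contained toy sketch of what a concrete implementation could look like; it is not the repository's LoRA layer, and the internal storage shown here is an assumption made for illustration.

import torch.nn as nn

class ToyLoRALinear(nn.Module):
    """Toy stand-in for a LoRAModule subclass wrapping an nn.Linear base layer."""

    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base
        self._lora_adapters = nn.ModuleDict()   # adapter name -> ModuleDict with "a" and "b"
        self._scales = {}                       # adapter name -> scale factor
        self._active_adapters = set()

    def update_layer_lora(self, adapter_name, attr_config):
        out_features, in_features = self.base.weight.shape
        lora_a = nn.Linear(in_features, attr_config.rank, bias=False)
        lora_b = nn.Linear(attr_config.rank, out_features, bias=False)
        # The config carries its own initializers (Kaiming-uniform A, zero B by default).
        if attr_config.lora_a_init is not None:
            attr_config.lora_a_init(lora_a.weight)
        if attr_config.lora_b_init is not None:
            attr_config.lora_b_init(lora_b.weight)
        self._lora_adapters[adapter_name] = nn.ModuleDict({"a": lora_a, "b": lora_b})
        self._scales[adapter_name] = attr_config.scale
        if attr_config.enable:
            self._active_adapters.add(adapter_name)

    def forward(self, x):
        result = self.base(x)
        for name in self._active_adapters:
            pair = self._lora_adapters[name]
            result = result + self._scales[name] * pair["b"](pair["a"](x))
        return result

Because the default B initializer is zeros, a freshly added adapter leaves the output identical to the base layer until training moves the LoRA weights.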