
Commit 63e64c5

Add more functions
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 3c17447 commit 63e64c5

File tree

7 files changed: +388 -91 lines


modelopt/torch/peft/config.py

Lines changed: 8 additions & 3 deletions

@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Callable
 from typing import Literal
 
 from pydantic import ValidationInfo, field_validator, model_validator
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 from modelopt.torch.utils.network import ConstructorLike
+
 BiasType = Literal["static", "dynamic"]
 BiasMethod = Literal["mean", "max_min"]
 
+
 class QuantizerAttributeConfig(ModeloptBaseConfig):
     """Quantizer attribute type."""
 
@@ -358,9 +359,10 @@ class SVDQuantConfig(QuantizeAlgorithmConfig):
 
 # QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None
 
-#TODO Jingyu Xin
+
+# TODO Jingyu Xin
 class PEFTConfig(ModeloptBaseConfig):
-    """Default configuration for ``quantize`` mode."""
+    """Default configuration for ``peft`` mode."""
 
     adapter_name: str = ModeloptField(
         default="default",
@@ -380,8 +382,11 @@ class PEFTConfig(ModeloptBaseConfig):
         validate_default=True,
     )
 
+
 class ExportPEFTConfig(ModeloptBaseConfig):
     """An empty config."""
+
+
 class CompressConfig(ModeloptBaseConfig):
     """Default configuration for ``compress`` mode."""
 
modelopt/torch/peft/conversion.py

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     # set_quantizer_by_cfg(model, config.get("quant_cfg", {}))
 
     metadata = {}
+    # Should return adapters, active_adapters
     update_peft_metadata(model, config, metadata)
 
     return model, metadata

modelopt/torch/peft/convert.py

Lines changed: 35 additions & 30 deletions

@@ -20,47 +20,32 @@
 
 import torch.nn as nn
 
-# import modelopt.torch.quantization as mtq
 from modelopt.torch.opt import apply_mode
-
-# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
 from modelopt.torch.opt.conversion import ModeloptStateManager
-
-# from modelopt.torch.opt.searcher import ForwardLoop
-# from modelopt.torch.opt.utils import forward_with_reshard
 from modelopt.torch.peft.config import PEFTConfig
 
 from .lora.layer import LoRAModule
-
-# from . import config
-# from .algorithms import AutoQuantizeSearcher
-# from .config import QuantizeAlgoCfgType
-# from .conversion import set_quantizer_attribute
 from .mode import PEFTModeRegistry
 
-# from .nn import QuantModule, TensorQuantizer
-
-# __all__ = [
-#     "auto_quantize",
-#     "calibrate",
-#     "disable_quantizer",
-#     "enable_quantizer",
-#     "fold_weight",
-#     "postprocess_amax",
-#     "print_quant_summary",
-#     "quantize",
-# ]
-
 
 def update_model(
     model: nn.Module,
     config: dict[str, Any | PEFTConfig],
 ):
-    # TODO: deal with extra state, how to save the model
-    # TODO: sharded dict
-    # TODO: metadate
-    # TODO: how to restore the model
-    apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
+    """Update model with PEFT/LoRA adapters.
+    This function handles both initial PEFT conversion and adding additional adapters:
+    - First call: Converts modules to LoRAModules and adds the first adapter
+    - Subsequent calls: Adds new adapters to existing LoRAModules
+    Args:
+        model: The model to update
+        config: PEFT configuration containing adapter settings
+    Returns:
+        The updated model with LoRA adapters
+    """
+    # Check if model is already in PEFT mode by looking for LoRA modules
+    if not is_peft_model(model):
+        # First time - need to convert to PEFT mode
+        apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
     return add_adapter(model, config)
 
 
@@ -79,7 +64,9 @@ def add_adapter(model, config):
                        continue
                else:
                    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
-                module.update_layer_lora(adapter_name, adapter_setting["rank"])
+                module.update_layer_lora(
+                    adapter_name, adapter_setting["rank"], adapter_setting.get("scale", 1.0)
+                )
 
     # Update the metadata in ModeloptStateManager after adding adapters
     _update_peft_metadata_in_state(model)
@@ -111,3 +98,21 @@ def _update_peft_metadata_in_state(model: nn.Module) -> None:
     # Update the metadata in the last mode state (which should be 'peft')
     if manager._state and manager._last_metadata is not None:
         manager._last_metadata["peft_state"] = current_peft_state
+
+
+def is_peft_model(model: nn.Module) -> bool:
+    """Check if the model has been converted to PEFT/LoRA model.
+
+    This function checks if any modules in the model are LoRAModule instances,
+    which indicates the model has already been converted to PEFT mode.
+
+    Args:
+        model: The model to check
+
+    Returns:
+        True if the model contains LoRA modules, False otherwise
+    """
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            return True
+    return False
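Taken together, the rewritten update_model and the new is_peft_model helper make conversion idempotent: the ``peft`` mode is applied only when no LoRAModule is found, and every call ends in add_adapter. A minimal usage sketch, assuming ModelOpt is installed and the model contains layers registered in LoRAModuleRegistry; the config layout follows the hypothetical example shown earlier.

# Sketch only: assumes `model` is an nn.Module containing layers that
# LoRAModuleRegistry knows how to wrap (e.g. Megatron ColumnParallelLinear),
# and that the dict keys below match the assumed adapter_cfg layout.
from modelopt.torch.peft.convert import is_peft_model, update_model

model = update_model(model, {"adapter_name": "default", "adapter_cfg": {"*linear*": {"rank": 64}}})
assert is_peft_model(model)  # eligible modules are now LoRAModules

# A second call skips apply_mode() and only registers another adapter.
model = update_model(
    model, {"adapter_name": "small_lora", "adapter_cfg": {"*linear*": {"rank": 16, "scale": 0.25}}}
)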

modelopt/torch/peft/lora/layer.py

Lines changed: 38 additions & 6 deletions

@@ -73,7 +73,7 @@ def deactivate_all_adapters(self) -> None:
         self._active_adapters.clear()
 
     def _register_adapter(
-        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int
+        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float = 1.0
     ) -> None:
         """Register a new LoRA adapter with explicit rank tracking.
 
@@ -82,6 +82,7 @@ def _register_adapter(
             lora_a: LoRA A module (down-projection)
            lora_b: LoRA B module (up-projection)
            rank: Rank of the LoRA decomposition
+            scale: Scale factor for the LoRA output
         """
         # Add as submodules for proper parameter registration
         self.add_module(f"lora_a_{adapter_name}", lora_a)
@@ -92,13 +93,14 @@ def _register_adapter(
             "lora_a": lora_a,
             "lora_b": lora_b,
             "rank": rank,  # Store rank explicitly for reliability
+            "scale": scale,
         }
 
         # Automatically activate new adapters
         self.activate_adapter(adapter_name)
 
     @abstractmethod
-    def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None:
+    def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.0) -> None:
         """Create and register a new LoRA adapter.
 
         This method must be implemented by subclasses to create the appropriate
@@ -107,6 +109,7 @@ def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None:
         Args:
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition (default: 64)
+            scale: Scale factor for the LoRA output (default: 1.0)
         """
         raise NotImplementedError("Subclasses must implement update_layer_lora")
 
@@ -148,14 +151,12 @@ def get_peft_state(self) -> dict[str, Any]:
                 "is_active": adapter_name in self._active_adapters,
                 "lora_a_type": type(lora_a).__name__,
                 "lora_b_type": type(lora_b).__name__,
+                "scale": adapter_modules.get("scale", 1.0),
             }
 
         modelopt_state["adapters"] = adapters_config
         modelopt_state["active_adapters"] = list(self._active_adapters)
 
-        # Store the base module type for validation
-        modelopt_state["base_module_type"] = type(self).__name__
-
         return modelopt_state
 
     def get_extra_state(self) -> dict[str, Any]:
@@ -177,6 +178,36 @@ def get_extra_state(self) -> dict[str, Any]:
 
         return {"modelopt_peft_state": peft_state}
 
+    def set_from_peft_state(self, peft_state: dict[str, Any]) -> None:
+        """Restore LoRA adapters from saved PEFT state.
+
+        This method recreates LoRA adapters based on their saved configuration.
+        Note: This only restores the adapter structure, not the weights.
+
+        Args:
+            peft_state: Dictionary containing adapter configurations
+        """
+        adapters_config = peft_state.get("adapters", {})
+
+        # Clear existing adapters first
+        self._lora_adapters.clear()
+        self._active_adapters.clear()
+
+        # Recreate each adapter based on saved configuration
+        for adapter_name, config in adapters_config.items():
+            rank = config.get("rank")
+            scale = config.get("scale", 1.0)
+
+            if rank is not None:
+                # Create the adapter with saved configuration
+                self.update_layer_lora(adapter_name, rank=rank, scale=scale)
+
+                # Set activation state
+                if config.get("is_active", False):
+                    self.activate_adapter(adapter_name)
+                else:
+                    self.deactivate_adapter(adapter_name)
+
     def set_extra_state(self, state: dict[str, Any]) -> None:
         """Restore extra state for distributed checkpointing.
 
@@ -245,7 +276,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
                 if isinstance(lora_b_output, tuple):
                     lora_b_output = lora_b_output[0]
 
-                result = result + lora_b_output
+                scale = adapter.get("scale", 1.0)
+                result = result + scale * lora_b_output
 
         # Return output in the same format as the base layer
         if other_outputs:
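The forward change is the core of this file: the LoRA branch is now multiplied by the stored per-adapter scale before being added to the base output, i.e. y = base(x) + scale * lora_b(lora_a(x)). The snippet below is an illustrative, self-contained rendering of that arithmetic with plain nn.Linear layers; it is not the LoRAModule implementation, which wraps arbitrary registered base layers and tracks multiple named adapters.

import torch
import torch.nn as nn

class TinyScaledLoRA(nn.Module):
    """Illustration of y = base(x) + scale * B(A(x)); not the ModelOpt class."""

    def __init__(self, features: int, rank: int = 4, scale: float = 0.5):
        super().__init__()
        self.base = nn.Linear(features, features)
        self.lora_a = nn.Linear(features, rank, bias=False)   # down-projection
        self.lora_b = nn.Linear(rank, features, bias=False)   # up-projection
        self.scale = scale
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mirrors the diff: result = result + scale * lora_b_output
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

layer = TinyScaledLoRA(16)
out = layer(torch.randn(2, 16))  # equals the base output until lora_b is trained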

modelopt/torch/peft/lora/tp_layer.py

Lines changed: 43 additions & 10 deletions

@@ -9,8 +9,20 @@
 
 from .layer import LoRAModule, LoRAModuleRegistry
 
-# Default rank for LoRA decomposition
+try:
+    from modelopt.torch.quantization.plugins.megatron import (
+        _MegatronColumnParallelLinear as QuantColumnParallelLinear,
+    )
+    from modelopt.torch.quantization.plugins.megatron import (
+        _MegatronRowParallelLinear as QuantRowParallelLinear,
+    )
+
+    QUANT_MODULES_AVAILABLE = True
+except ImportError:
+    QUANT_MODULES_AVAILABLE = False
+
 DEFAULT_LORA_RANK = 64
+DEFAULT_SCALE = 1.0
 
 
 class _MegatronParallelLoRABase(LoRAModule):
@@ -33,7 +45,7 @@ def _get_init_methods(self) -> tuple[Callable, Callable]:
         return lora_a_init, lora_b_init
 
     def _register_adapter_with_device(
-        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int
+        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float
     ) -> None:
         """Register LoRA adapter modules and ensure correct device placement.
 
@@ -43,23 +55,29 @@ def _register_adapter_with_device(
             lora_b: LoRA B module (up-projection)
            rank: Rank of the LoRA decomposition
        """
-        # Move LoRA modules to the same device as the parent module
-        # Try to get device from parent module's parameters or buffers
+        # Move LoRA modules to the same device and dtype as the parent module
+        # Try to get device and dtype from parent module's parameters or buffers
         device = None
+        dtype = None
         for p in self.parameters():
             device = p.device
+            dtype = p.dtype
             break
         if device is None:
             for b in self.buffers():
                 device = b.device
+                dtype = b.dtype
                 break
 
-        # If we found a device, move LoRA modules to it
+        # If we found a device and dtype, move LoRA modules to match
         if device is not None:
             lora_a = lora_a.to(device)
             lora_b = lora_b.to(device)
+            if dtype is not None:
+                lora_a = lora_a.to(dtype)
+                lora_b = lora_b.to(dtype)
 
-        super()._register_adapter(adapter_name, lora_a, lora_b, rank)
+        super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale)
 
 
 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
@@ -70,7 +88,9 @@ class _MegatronColumnParallelLinear(_MegatronParallelLoRABase):
     the parallelization scheme of the base layer.
     """
 
-    def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None:
+    def update_layer_lora(
+        self, adapter_name: str, rank: int = DEFAULT_LORA_RANK, scale: float = DEFAULT_SCALE
+    ) -> None:
         """Create and register a new LoRA adapter for ColumnParallelLinear.
 
         Args:
@@ -100,7 +120,7 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) ->
             init_method=lora_b_init,
         )
 
-        self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank)
+        self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank, scale)
 
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
@@ -111,7 +131,9 @@ class _MegatronRowParallelLinear(_MegatronParallelLoRABase):
     the parallelization scheme of the base layer.
     """
 
-    def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None:
+    def update_layer_lora(
+        self, adapter_name: str, rank: int = DEFAULT_LORA_RANK, scale: float = DEFAULT_SCALE
+    ) -> None:
         """Create and register a new LoRA adapter for RowParallelLinear.
 
         Args:
@@ -141,4 +163,15 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) ->
             init_method=lora_b_init,
         )
 
-        self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank)
+        self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank, scale)
+
+
+# Register quantized versions if available
+if QUANT_MODULES_AVAILABLE:
+    # Register the same LoRA implementations for quantized modules
+    LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
+        _MegatronColumnParallelLinear
+    )
+    LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
+        _MegatronRowParallelLinear
+    )
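With scale threaded through _register_adapter_with_device, each Megatron parallel LoRA adapter is created with both a rank and an output multiplier, and the trailing block reuses the same classes for the quantized ColumnParallelLinear/RowParallelLinear variants whenever the quantization plugins import cleanly. A hedged sketch of adding adapters directly on such a module; it assumes Megatron-Core is installed and that `module` is already one of the LoRA classes above (for example after the ``peft`` conversion).

# Sketch only: `module` is assumed to be a _MegatronColumnParallelLinear or
# _MegatronRowParallelLinear instance produced by the "peft" mode conversion.
module.update_layer_lora("default")                          # rank=DEFAULT_LORA_RANK, scale=DEFAULT_SCALE
module.update_layer_lora("small_lora", rank=8, scale=0.125)  # low-rank adapter with a damped output

# The per-adapter scale is now part of the PEFT state (see layer.py above),
# so it survives get_extra_state()/set_extra_state() checkpointing round-trips.
state = module.get_peft_state()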
