Skip to content

Commit ececc24

Browse files
committed
Add megatron lora support
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 43ed09f commit ececc24

File tree

10 files changed

+1201
-1
lines changed

10 files changed

+1201
-1
lines changed

modelopt/torch/peft/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Distillation API subpackage for torch."""
17+
18+
from . import mode
19+
from .config import *
20+
from .convert import *
21+
# isort: off
22+
# Import plugins last to avoid circular imports
23+
# from . import plugins

modelopt/torch/peft/config.py

Lines changed: 434 additions & 0 deletions
Large diffs are not rendered by default.

modelopt/torch/peft/conversion.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Quantization conversion/restore utilities."""
17+
18+
import fnmatch
19+
from collections.abc import Callable
20+
from contextlib import contextmanager
21+
from typing import Any
22+
23+
import torch.nn as nn
24+
25+
from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager
26+
from modelopt.torch.opt.dynamic import _DMRegistryCls
27+
from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
28+
from modelopt.torch.utils import get_unwrapped_name
29+
30+
from .config import (
31+
PEFTConfig,
32+
_QuantizeExportConfig,
33+
)
34+
from .lora.layer import LoRAModuleRegistry
35+
36+
__all__ = [
37+
"replace_lora_module",
38+
]
39+
40+
41+
def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> ConvertReturnType:
42+
"""Convert the model to a quantized one as per `config`."""
43+
# initialize the true module if necessary
44+
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
45+
46+
# TODO: Replace to LoRA module
47+
replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
48+
# set_quantizer_by_cfg(model, config.get("quant_cfg", {}))
49+
50+
metadata = {}
51+
# update_quantize_metadata(model, config, metadata)
52+
53+
return model, metadata
54+
55+
def restore_peft_model(
56+
model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
57+
) -> nn.Module:
58+
#TODO: implemente the restore logic
59+
pass
60+
61+
62+
63+
def update_peft_metadata(
64+
model: nn.Module, config: PEFTConfig, metadata: MetadataDict
65+
) -> None:
66+
"""Update the quantizer state in the metadata dict."""
67+
pass
68+
69+
70+
def replace_lora_module(model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry):
71+
"""Recursively replace the module with quantized module."""
72+
#TODO: register the extra state for megatron-lm
73+
74+
if type(model) in registry:
75+
model = registry.convert(model)
76+
_replace_lora_module(model, version=version, registry=registry)
77+
78+
def export_peft_model(model: nn.Module, config):
79+
"""Export the quantized model to a quantized model."""
80+
raise NotImplementedError("Exporting a quantized model is not supported yet.")
81+
82+
83+
def restore_export_peft_model(
84+
model: nn.Module, config, metadata: MetadataDict
85+
):
86+
"""Restores the quantized model from the given state dict."""
87+
raise NotImplementedError("Restoring a quantized & exported model is not supported yet.")
88+
89+
90+
def _replace_lora_module(model: nn.Module, version=None,registry=LoRAModuleRegistry):
91+
for name, child in model.named_children():
92+
if type(child) in registry:
93+
lora_module = registry.convert(child)
94+
setattr(model, name, lora_module)
95+
96+
_replace_lora_module(getattr(model, name), version=version, registry=registry)
97+
98+
99+
def export_quantized_model(model: nn.Module, config: _QuantizeExportConfig) -> ConvertReturnType:
100+
"""Export the quantized model to a quantized model."""
101+
raise NotImplementedError("Exporting a quantized model is not supported yet.")
102+
103+
104+
def restore_export_quantized_model(
105+
model: nn.Module, config: _QuantizeExportConfig, metadata: MetadataDict
106+
) -> nn.Module:
107+
"""Restores the quantized model from the given state dict."""
108+
raise NotImplementedError("Restoring a quantized & exported model is not supported yet.")

modelopt/torch/peft/convert.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""User-facing quantization API."""
17+
18+
import fnmatch
19+
import inspect
20+
import warnings
21+
from collections.abc import Callable, Iterable
22+
from typing import Any
23+
24+
import torch
25+
import torch.nn as nn
26+
27+
# import modelopt.torch.quantization as mtq
28+
from modelopt.torch.opt import apply_mode
29+
# from modelopt.torch.opt.searcher import ForwardLoop
30+
# from modelopt.torch.opt.utils import forward_with_reshard
31+
from modelopt.torch.peft.config import PEFTConfig
32+
# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
33+
34+
# from . import config
35+
# from .algorithms import AutoQuantizeSearcher
36+
# from .config import QuantizeAlgoCfgType
37+
# from .conversion import set_quantizer_attribute
38+
from .mode import PEFTModeRegistry
39+
from .lora.layer import LoRAModule
40+
# from .nn import QuantModule, TensorQuantizer
41+
42+
# __all__ = [
43+
# "auto_quantize",
44+
# "calibrate",
45+
# "disable_quantizer",
46+
# "enable_quantizer",
47+
# "fold_weight",
48+
# "postprocess_amax",
49+
# "print_quant_summary",
50+
# "quantize",
51+
# ]
52+
53+
def update_model(
54+
model: nn.Module,
55+
config: dict[str, Any | PEFTConfig],
56+
):
57+
#TODO: deal with extra state, how to save the model
58+
#TODO: sharded dict
59+
#TODO: metadate
60+
#TODO: how to restore the model
61+
apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
62+
return add_adapter(model, config)
63+
64+
def add_adapter(model, config):
65+
adapter_cfg = config["adapter_cfg"]
66+
adapter_name = config["adapter_name"]
67+
68+
for name, module in model.named_modules():
69+
if isinstance(module, LoRAModule):
70+
for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
71+
if isinstance(wildcard_or_filter_func, str):
72+
if not fnmatch.fnmatch(name, wildcard_or_filter_func):
73+
continue
74+
elif callable(wildcard_or_filter_func):
75+
if not wildcard_or_filter_func(name):
76+
continue
77+
else:
78+
raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
79+
module.update_layer_lora(adapter_name, adapter_setting["rank"])
80+
return model
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import layer
2+
from . import tp_layer
3+
# from . import linear_layer

modelopt/torch/peft/lora/layer.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""LoRA (Low-Rank Adaptation) module implementation."""
2+
3+
from abc import abstractmethod
4+
from typing import Dict, Tuple, Any, Optional
5+
import torch
6+
import torch.nn as nn
7+
8+
from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
9+
10+
__all__ = [
11+
"LoRAModule",
12+
"LoRAModuleRegistry",
13+
]
14+
15+
16+
class LoRAModule(DynamicModule):
17+
"""Base class for LoRA (Low-Rank Adaptation) modules.
18+
19+
This module wraps existing layers and adds trainable low-rank decomposition
20+
matrices (LoRA adapters) that are added to the original layer's output.
21+
22+
Attributes:
23+
_lora_adapters: Dictionary mapping adapter names to their LoRA A and B matrices
24+
_active_adapters: Set of currently active adapter names
25+
"""
26+
27+
def _setup(self) -> None:
28+
"""Initialize LoRA-specific attributes."""
29+
self._lora_adapters: Dict[str, Dict[str, nn.Module]] = {}
30+
self._active_adapters: set = set()
31+
32+
@property
33+
def adapter_names(self) -> set:
34+
"""Return the set of all registered adapter names."""
35+
return set(self._lora_adapters.keys())
36+
37+
@property
38+
def active_adapters(self) -> set:
39+
"""Return the set of currently active adapter names."""
40+
return self._active_adapters.copy()
41+
42+
def activate_adapter(self, adapter_name: str) -> None:
43+
"""Activate a specific adapter.
44+
45+
Args:
46+
adapter_name: Name of the adapter to activate
47+
48+
Raises:
49+
ValueError: If adapter_name is not registered
50+
"""
51+
if adapter_name not in self._lora_adapters:
52+
raise ValueError(f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}")
53+
self._active_adapters.add(adapter_name)
54+
55+
def deactivate_adapter(self, adapter_name: str) -> None:
56+
"""Deactivate a specific adapter.
57+
58+
Args:
59+
adapter_name: Name of the adapter to deactivate
60+
"""
61+
self._active_adapters.discard(adapter_name)
62+
63+
def activate_all_adapters(self) -> None:
64+
"""Activate all registered adapters."""
65+
self._active_adapters = self.adapter_names.copy()
66+
67+
def deactivate_all_adapters(self) -> None:
68+
"""Deactivate all adapters."""
69+
self._active_adapters.clear()
70+
71+
@abstractmethod
72+
def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None:
73+
"""Create and register a new LoRA adapter.
74+
75+
This method must be implemented by subclasses to create the appropriate
76+
LoRA A and B matrices for the specific layer type.
77+
78+
Args:
79+
adapter_name: Name for the new adapter
80+
rank: Rank of the LoRA decomposition (default: 64)
81+
"""
82+
raise NotImplementedError("Subclasses must implement update_layer_lora")
83+
84+
def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
85+
"""Forward pass with LoRA adaptation.
86+
87+
Args:
88+
x: Input tensor
89+
*args: Additional positional arguments for the base layer
90+
**kwargs: Additional keyword arguments for the base layer
91+
92+
Returns:
93+
Output from the base layer plus active LoRA adaptations
94+
"""
95+
# Call the base layer's forward method
96+
output = super().forward(x, *args, **kwargs)
97+
98+
# Handle different output types from base layer
99+
if isinstance(output, tuple):
100+
# If output is a tuple, assume first element is the main result
101+
result = output[0]
102+
other_outputs = output[1:]
103+
else:
104+
# If output is a single tensor
105+
result = output
106+
other_outputs = ()
107+
108+
# Apply active LoRA adapters
109+
if self._active_adapters and self._lora_adapters:
110+
for adapter_name in self._active_adapters:
111+
if adapter_name in self._lora_adapters:
112+
adapter = self._lora_adapters[adapter_name]
113+
# LoRA computation: result = result + B(A(x))
114+
lora_a = adapter['lora_a']
115+
lora_b = adapter['lora_b']
116+
117+
# Handle different forward signatures
118+
lora_a_output = lora_a(x)
119+
if isinstance(lora_a_output, tuple):
120+
lora_a_output = lora_a_output[0]
121+
122+
lora_b_output = lora_b(lora_a_output)
123+
if isinstance(lora_b_output, tuple):
124+
lora_b_output = lora_b_output[0]
125+
126+
result = result + lora_b_output
127+
128+
# Return output in the same format as the base layer
129+
if other_outputs:
130+
return (result,) + other_outputs
131+
else:
132+
return result
133+
134+
135+
LoRAModuleRegistry = _DMRegistryCls("LoRA", LoRAModule)

0 commit comments

Comments
 (0)