Skip to content

Commit 39f81e5

Browse files
[1/N] ModelOPT PEFT mode support for the megatron-lm (#342)
Signed-off-by: Jingyu Xin <[email protected]> Signed-off-by: jingyu-ml <[email protected]> Co-authored-by: Keval Morabia <[email protected]>
1 parent 557633c commit 39f81e5

File tree

17 files changed

+2350
-1
lines changed

17 files changed

+2350
-1
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ modelopt/torch/distill @NVIDIA/modelopt-torch-distill-codeowners
2222
modelopt/torch/export @NVIDIA/modelopt-torch-export-codeowners
2323
modelopt/torch/nas @NVIDIA/modelopt-torch-nas-prune-codeowners
2424
modelopt/torch/opt @NVIDIA/modelopt-torch-opt-codeowners
25+
modelopt/torch/peft @NVIDIA/modelopt-torch-peft-codeowners
2526
modelopt/torch/prune @NVIDIA/modelopt-torch-nas-prune-codeowners
2627
modelopt/torch/quantization @NVIDIA/modelopt-torch-quantization-codeowners
2728
modelopt/torch/sparsity @NVIDIA/modelopt-torch-sparsity-codeowners

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Model Optimizer Changelog (Linux)
99
**New Features**
1010

1111
- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
12+
- Add LoRA mode support for MCore in a new peft submodule: ``modelopt.torch.peft.update_model(model, LORA_CFG)``.
1213
- Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.
1314

1415
0.37 (2025-09-xx)

modelopt/torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from packaging.version import Version as _Version
2121
from torch import __version__ as _torch_version
2222

23-
from . import distill, nas, opt, prune, quantization, sparsity, speculative, utils
23+
from . import distill, nas, opt, peft, prune, quantization, sparsity, speculative, utils
2424

2525
if _Version(_torch_version) < _Version("2.7"):
2626
_warnings.warn(

modelopt/torch/peft/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""PEFT API subpackage for torch."""
17+
18+
from . import mode
19+
from .config import *
20+
from .conversion import *
21+
from .convert import *

modelopt/torch/peft/config.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Configuration classes for PEFT methods."""
17+
18+
import importlib
19+
import inspect
20+
from collections.abc import Callable
21+
from typing import Annotated, Any
22+
23+
import torch.nn.init as init
24+
from pydantic import PlainSerializer, WithJsonSchema, field_validator
25+
26+
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
27+
28+
# Names exported when this module is star-imported.
__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]
29+
30+
# Runtime type of an initializer field; left as ``Any`` here, so any checking of
# the value has to happen in validators rather than through this annotation.
InitRuntimeType = Any
31+
32+
33+
def _qualname(fn) -> str:
34+
m = inspect.getmodule(fn)
35+
return f"{m.__name__}.{fn.__name__}" if m else getattr(fn, "__name__", str(fn))
36+
37+
38+
# JSON-schema view of an initializer field: it is advertised as a plain string.
_INIT_FIELD_JSON_SCHEMA = {
    "type": "string",
    "title": "torch initializer",
    "description": (
        "Fully-qualified callable from ``torch.nn.init``. "
        "Must be in-place (name ends with ``\\_``)."
    ),
    "examples": ["torch.nn.init.zeros\\_", "torch.nn.init.kaiming_uniform\\_"],
}

# Annotated field type for initializers: the runtime value is ``InitRuntimeType``,
# the JSON schema describes it as a string, and serialization always emits the
# dotted name produced by ``_qualname``.
InitField = Annotated[
    InitRuntimeType,
    WithJsonSchema(_INIT_FIELD_JSON_SCHEMA),
    PlainSerializer(_qualname, return_type=str, when_used="always"),
]
53+
54+
55+
class PEFTAttributeConfig(ModeloptBaseConfig):
    """Configuration for PEFT adapter attributes."""

    enable: bool = ModeloptField(
        default=True,
        title="Enable adapter",
        description="If True, enables the adapter. If False, by-passes the adapter.",
    )

    rank: int = ModeloptField(
        default=64,
        title="LoRA rank",
        description=(
            "The rank (dimension) of the LoRA matrices. "
            "Higher rank allows more expressiveness but uses more memory."
        ),
    )

    scale: float = ModeloptField(
        default=1.0,
        title="LoRA scaling factor",
        description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
    )

    lora_a_init: InitField = ModeloptField(
        default=init.kaiming_uniform_,
        title="LoRA A matrix initializer",
        description="Initializer from ``torch.nn.init`` (in-place; name ends with ``\\_``).",
    )

    lora_b_init: InitField = ModeloptField(
        default=init.zeros_,
        title="LoRA B matrix initializer",
        description="Initializer from ``torch.nn.init`` (in-place; name ends with ``\\_``).",
    )

    @field_validator("lora_a_init", "lora_b_init", mode="before")
    @classmethod
    def _parse_init_callable(cls, v):
        """Resolve a dotted-path string into the object it names.

        Non-string values are returned unchanged. A string is split at its last
        dot into a module path and an attribute name, the module is imported,
        and the attribute is returned.

        Raises:
            ValueError: If the string cannot be split, imported or resolved.
        """
        if not isinstance(v, str):
            return v
        # NOTE(review): this imports whichever module the string names, and it
        # runs in "before" mode, i.e. ahead of the origin check performed by
        # ``validate_init_method`` — confirm config strings are trusted.
        try:
            owner_path, attr_name = v.rsplit(".", 1)
            return getattr(importlib.import_module(owner_path), attr_name)
        except Exception as err:
            raise ValueError(
                f"Could not resolve initializer '{v}' into a callable "
                "(expected a dotted path like 'torch.nn.init.zeros_')."
            ) from err

    @field_validator("lora_a_init", "lora_b_init")
    @classmethod
    def validate_init_method(cls, v):
        """Check that the initializer is an in-place function from ``torch.nn.init``.

        Raises:
            ValueError: If the value is not callable, if ``inspect.getmodule``
                does not report ``torch.nn.init`` for it, or if its name does
                not end with an underscore.
        """
        if not callable(v):
            raise ValueError(
                f"Initialization method must be a callable function from torch.nn.init, got {type(v)}"
            )

        owner = inspect.getmodule(v)
        if owner is not init:
            origin = owner.__name__ if owner else "unknown"
            raise ValueError(
                f"Callable initialization method must be from torch.nn.init, got {origin}"
            )

        # The trailing underscore is the naming convention used for in-place initializers.
        if not getattr(v, "__name__", "").endswith("_"):
            raise ValueError(
                "Initialization method must be in-place (name ends with '_'). "
                "For example: ``torch.nn.init.kaiming_uniform\\_`` not "
                "``torch.nn.init.kaiming_uniform``."
            )
        return v

    @field_validator("rank")
    @classmethod
    def validate_rank(cls, v):
        """Reject ranks smaller than 1.

        Raises:
            ValueError: If ``v`` is less than 1.
        """
        if v < 1:
            raise ValueError("rank must be a positive integer")
        return v

    @field_validator("scale")
    @classmethod
    def validate_scale(cls, v):
        """Reject zero or negative scaling factors.

        Raises:
            ValueError: If ``v`` is less than or equal to 0.
        """
        if v <= 0:
            raise ValueError("scale must be a positive number")
        return v
145+
146+
147+
# Type alias for adapter configuration
148+
# Keys are strings or callables; values are either ``PEFTAttributeConfig``
# instances or plain dicts.
PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
149+
150+
151+
class PEFTConfig(ModeloptBaseConfig):
    """Default configuration for ``peft`` mode.

    For adapter_cfg, later patterns override earlier ones, for example::

        "adapter_cfg": {
            "*": {
                "rank": 32,
                "scale": 1,
                "enable": True,
            },
            "*output_layer*": {"enable": False},
        }

    If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``.
    """

    adapter_name: str = ModeloptField(
        default="default",
        title="Adapter name",
        description="Name of the adapter to create or update.",
        validate_default=True,
    )

    adapter_cfg: PEFTAdapterCfgType = ModeloptField(
        default={"*": {"rank": 64}},
        title="Adapter configuration",
        description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.",
        validate_default=True,
    )

    adapter_type: str = ModeloptField(
        default="lora",
        title="Adapter type",
        description="Type of PEFT adapter to use. Currently only 'lora' is supported.",
        validate_default=True,
    )

    freeze_base_model: bool = ModeloptField(
        default=True,
        title="Freeze base weights during training",
        description="Whether to freeze the base model weights; in most cases, this should be set to True.",
        validate_default=True,
    )

    freeze_lora_weights: bool = ModeloptField(
        default=False,
        title="Freeze lora weights during training",
        description="Whether to freeze the lora model weights; in most cases, this should be set to False.",
        validate_default=True,
    )

    @field_validator("adapter_type")
    @classmethod
    def validate_adapter_type(cls, v):
        """Validate adapter type.

        Raises:
            ValueError: If ``v`` is anything other than ``"lora"``.
        """
        if v not in ("lora",):
            raise ValueError(f"Unsupported adapter type: {v}. Only 'lora' is currently supported.")
        return v

    @field_validator("adapter_cfg")
    @classmethod
    def validate_adapter_cfg(cls, v):
        """Validate and convert adapter configurations.

        Every plain-dict value is converted into a ``PEFTAttributeConfig`` so
        that its field validators run; all other values are kept as they are.

        Returns:
            A new dict with the same keys, in the same order, as ``v``.

        Raises:
            ValueError: If a dict value cannot be turned into a
                ``PEFTAttributeConfig``; the original error is chained as the cause.
        """
        validated_cfg = {}
        for key, value in v.items():
            if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
                # Convert dict to PEFTAttributeConfig to trigger validation
                try:
                    validated_cfg[key] = PEFTAttributeConfig(**value)
                except Exception as e:
                    # Chain explicitly so the underlying validation error is
                    # reported as the direct cause of this one.
                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}") from e
            else:
                validated_cfg[key] = value
        return validated_cfg
226+
227+
228+
class ExportPEFTConfig(ModeloptBaseConfig):
    """An empty config: it declares no fields of its own."""

0 commit comments

Comments
 (0)