Commit e70abb3

Update

Signed-off-by: Jingyu Xin <[email protected]>
1 parent 5e9f7e6 commit e70abb3

12 files changed: +1078 -174 lines changed

modelopt/torch/peft/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 from . import mode
 from .config import *
 from .convert import *
+
 # isort: off
 # Import plugins last to avoid circular imports
-# from . import plugins
+from . import plugins

modelopt/torch/peft/conversion.py

Lines changed: 91 additions & 28 deletions
@@ -15,26 +15,20 @@

 """Quantization conversion/restore utilities."""

-import fnmatch
-from collections.abc import Callable
-from contextlib import contextmanager
 from typing import Any

 import torch.nn as nn

 from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager
-from modelopt.torch.opt.dynamic import _DMRegistryCls
 from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
 from modelopt.torch.utils import get_unwrapped_name

-from .config import (
-    PEFTConfig,
-    _QuantizeExportConfig,
-)
-from .lora.layer import LoRAModuleRegistry
+from .config import PEFTConfig, _QuantizeExportConfig
+from .lora.layer import LoRAModule, LoRAModuleRegistry

 __all__ = [
     "replace_lora_module",
+    "update_peft_metadata_in_model",
 ]
@@ -48,46 +42,88 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     # set_quantizer_by_cfg(model, config.get("quant_cfg", {}))

     metadata = {}
-    # update_quantize_metadata(model, config, metadata)
+    update_peft_metadata(model, config, metadata)

     return model, metadata

+
 def restore_peft_model(
     model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
 ) -> nn.Module:
-    #TODO: implemente the restore logic
-    pass
-
-
-
-def update_peft_metadata(
-    model: nn.Module, config: PEFTConfig, metadata: MetadataDict
-) -> None:
-    """Update the quantizer state in the metadata dict."""
-    pass
-
+    convert_to_peft_model(model, config)
+    return restore_peft_state(model, metadata)
+
+
+def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict):
+    """Restore PEFT state from metadata or extra_state.
+    For backward compatibility, we check metadata first. For distributed
+    checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule
+    and will be restored automatically via set_extra_state() during load_state_dict().
+
+    Args:
+        model: Model with LoRA modules to restore
+        metadata: Metadata dictionary that may contain peft_state
+    Returns:
+        The model with restored PEFT state
+    """
+    if "peft_state" not in metadata:
+        # For distributed checkpoints (NeMo-MCore), peft_state is stored
+        # in each LoRAModule's extra_state and will be restored via
+        # set_extra_state() during load_state_dict()
+        return model
+
+    # Legacy path: restore from metadata
+    peft_state_dict = metadata["peft_state"]
+    for name, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            unwrapped_name = get_unwrapped_name(name)
+            if unwrapped_name in peft_state_dict:
+                try:
+                    module.set_from_peft_state(peft_state_dict[unwrapped_name])
+                except Exception as e:
+                    raise ApplyModeError(f"Failed to restore PEFT state for module {name}: {e}")
+
+    return model
+
+
+def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None:
+    """Update the PEFT/LoRA state in the metadata dict."""
+    metadata["peft_state"] = peft_state(model)
+
+
+def peft_state(model: nn.Module) -> dict[str, Any]:
+    return {
+        get_unwrapped_name(n): m.get_peft_state()
+        for n, m in model.named_modules()
+        if isinstance(m, LoRAModule)
+    }
+
+
+def replace_lora_module(
+    model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry
+):
+    """Recursively replace the module with LoRA module."""
+    # Register custom plugins (e.g., for Megatron distributed checkpointing)
+    from .custom import register_custom_model_plugins_on_the_fly

-def replace_lora_module(model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry):
-    """Recursively replace the module with quantized module."""
-    #TODO: register the extra state for megatron-lm
+    register_custom_model_plugins_on_the_fly(model)

     if type(model) in registry:
         model = registry.convert(model)
     _replace_lora_module(model, version=version, registry=registry)

+
 def export_peft_model(model: nn.Module, config):
     """Export the quantized model to a quantized model."""
     raise NotImplementedError("Exporting a quantized model is not supported yet.")


-def restore_export_peft_model(
-    model: nn.Module, config, metadata: MetadataDict
-):
+def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict):
     """Restores the quantized model from the given state dict."""
     raise NotImplementedError("Restoring a quantized & exported model is not supported yet.")


-def _replace_lora_module(model: nn.Module, version=None,registry=LoRAModuleRegistry):
+def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry):
     for name, child in model.named_children():
         if type(child) in registry:
             lora_module = registry.convert(child)
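
The restore path above leans on LoRAModule state methods that this commit uses but does not show: get_peft_state(), set_from_peft_state(), and the extra_state hooks that carry the same payload through NeMo-MCore distributed checkpoints. Below is a minimal sketch of how those pieces could fit together; the method names come from this diff, while the adapter layout and hook bodies are illustrative assumptions.

import torch.nn as nn


class SketchLoRAModule(nn.Module):
    """Illustrative stand-in for the LoRAModule state round-trip (not the real class)."""

    def __init__(self):
        super().__init__()
        self._adapters: dict[str, dict] = {}  # adapter name -> config, e.g. {"rank": 32}

    def update_layer_lora(self, adapter_name: str, rank: int) -> None:
        # The real LoRAModule would also allocate the low-rank weights here.
        self._adapters[adapter_name] = {"rank": rank}

    def get_peft_state(self) -> dict:
        # Structural config only; adapter weights travel in the state_dict itself.
        return {name: dict(cfg) for name, cfg in self._adapters.items()}

    def set_from_peft_state(self, peft_state: dict) -> None:
        # Re-create adapters so checkpointed weights find matching shapes.
        for name, cfg in peft_state.items():
            self.update_layer_lora(name, cfg["rank"])

    # For distributed checkpoints the same dict rides along per module, letting
    # load_state_dict() restore it via set_extra_state() without top-level metadata.
    def get_extra_state(self) -> dict:
        return self.get_peft_state()

    def set_extra_state(self, state: dict) -> None:
        self.set_from_peft_state(state)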
@@ -106,3 +142,30 @@ def restore_export_quantized_model(
 ) -> nn.Module:
     """Restores the quantized model from the given state dict."""
     raise NotImplementedError("Restoring a quantized & exported model is not supported yet.")
+
+
+def update_peft_metadata_in_model(model: nn.Module) -> None:
+    """Update the PEFT metadata in the model's ModeloptStateManager.
+    This function should be called after manually modifying LoRA adapters to ensure
+    the metadata stored in the ModeloptStateManager reflects the current state.
+
+    Args:
+        model: Model with LoRA modules whose metadata needs updating
+    Example:
+        >>> # After manually adding/modifying adapters
+        >>> for module in model.modules():
+        ...     if isinstance(module, LoRAModule):
+        ...         module.update_layer_lora("custom_adapter", rank=32)
+        >>> # Update metadata to reflect changes
+        >>> update_peft_metadata_in_model(model)
+    """
+    # Check if model has ModeloptStateManager (has been converted with peft mode)
+    if not ModeloptStateManager.is_converted(model):
+        return
+
+    # Get the state manager
+    manager = ModeloptStateManager(model)
+
+    # Update the metadata with current PEFT state
+    if manager._state and manager._last_metadata is not None:
+        manager._last_metadata["peft_state"] = peft_state(model)
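
Because ModeloptStateManager serializes this metadata with the model, adapters added or resized after the peft mode was applied would otherwise be missing from a reloaded checkpoint. A hedged usage sketch of the intended flow follows; it assumes model has already been converted via the peft mode and that modelopt.torch.opt.save is the usual ModelOpt checkpoint entry point.

import modelopt.torch.opt as mto
from modelopt.torch.peft.conversion import update_peft_metadata_in_model
from modelopt.torch.peft.lora.layer import LoRAModule

# Hand-tune adapters after the initial conversion...
for module in model.modules():
    if isinstance(module, LoRAModule):
        module.update_layer_lora("custom_adapter", rank=32)

# ...then refresh the stored peft_state so it survives save/restore.
update_peft_metadata_in_model(model)
mto.save(model, "model_with_adapters.pth")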

modelopt/torch/peft/convert.py

Lines changed: 44 additions & 11 deletions
@@ -16,27 +16,28 @@
 """User-facing quantization API."""

 import fnmatch
-import inspect
-import warnings
-from collections.abc import Callable, Iterable
 from typing import Any

-import torch
 import torch.nn as nn

 # import modelopt.torch.quantization as mtq
 from modelopt.torch.opt import apply_mode
+
+# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
+from modelopt.torch.opt.conversion import ModeloptStateManager
+
 # from modelopt.torch.opt.searcher import ForwardLoop
 # from modelopt.torch.opt.utils import forward_with_reshard
 from modelopt.torch.peft.config import PEFTConfig
-# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
+
+from .lora.layer import LoRAModule

 # from . import config
 # from .algorithms import AutoQuantizeSearcher
 # from .config import QuantizeAlgoCfgType
 # from .conversion import set_quantizer_attribute
 from .mode import PEFTModeRegistry
-from .lora.layer import LoRAModule
+
 # from .nn import QuantModule, TensorQuantizer

 # __all__ = [
@@ -50,17 +51,19 @@
 #     "quantize",
 # ]

+
 def update_model(
     model: nn.Module,
     config: dict[str, Any | PEFTConfig],
 ):
-    #TODO: deal with extra state, how to save the model
-    #TODO: sharded dict
-    #TODO: metadate
-    #TODO: how to restore the model
+    # TODO: deal with extra state, how to save the model
+    # TODO: sharded dict
+    # TODO: metadata
+    # TODO: how to restore the model
     apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
     return add_adapter(model, config)

+
 def add_adapter(model, config):
     adapter_cfg = config["adapter_cfg"]
     adapter_name = config["adapter_name"]
@@ -77,4 +80,34 @@ def add_adapter(model, config):
         else:
             raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
         module.update_layer_lora(adapter_name, adapter_setting["rank"])
-    return model
+
+    # Update the metadata in ModeloptStateManager after adding adapters
+    _update_peft_metadata_in_state(model)
+    return model
+
+
+def _update_peft_metadata_in_state(model: nn.Module) -> None:
+    """Update the PEFT metadata in the ModeloptStateManager.
+
+    This function updates the metadata to reflect the current state of LoRA adapters
+    after they have been added or modified.
+    """
+    # Check if model has ModeloptStateManager (has been converted with peft mode)
+    if not ModeloptStateManager.is_converted(model):
+        return
+
+    # Get the state manager
+    manager = ModeloptStateManager(model)
+
+    # Get current PEFT state from all LoRA modules
+    current_peft_state = {}
+    for name, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            from modelopt.torch.utils import get_unwrapped_name
+
+            unwrapped_name = get_unwrapped_name(name)
+            current_peft_state[unwrapped_name] = module.get_peft_state()
+
+    # Update the metadata in the last mode state (which should be 'peft')
+    if manager._state and manager._last_metadata is not None:
+        manager._last_metadata["peft_state"] = current_peft_state
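
Pieced together from add_adapter(), the expected config carries an adapter_name plus an adapter_cfg that maps fnmatch-style wildcards (or filter callables) over module names to per-match settings containing a "rank". A usage sketch under that reading; the pattern and rank are made-up values, and the exact callable signature is not visible in this hunk.

from modelopt.torch.peft.convert import update_model

config = {
    "adapter_name": "my_adapter",
    "adapter_cfg": {
        # fnmatch wildcard tested against LoRAModule names
        "*self_attention*": {"rank": 16},
    },
}

# Applies the "peft" mode (swapping matching layers for LoRAModule), adds the
# named adapter to every matching module, and refreshes the stored metadata.
model = update_model(model, config)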

modelopt/torch/peft/custom.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom PEFT/LoRA plugins registry."""
+
+# Registry for custom model plugins
+CUSTOM_MODEL_PLUGINS = set()
+
+
+def register_custom_model_plugins_on_the_fly(model):
+    """Registers custom PEFT/LoRA plugins on the fly.
+
+    This is called before LoRAModule replacement to allow plugins
+    to configure the model (e.g., for distributed checkpointing).
+    """
+    for callback in CUSTOM_MODEL_PLUGINS:
+        callback(model)
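
Plugins hook in by adding a callable to CUSTOM_MODEL_PLUGINS; replace_lora_module() in conversion.py invokes each one with the root model before any LoRAModule conversion happens. A minimal sketch of a registration (the plugin body is a placeholder assumption):

from modelopt.torch.peft.custom import CUSTOM_MODEL_PLUGINS


def _prepare_distributed_checkpointing(model):
    # Placeholder: a real plugin might register the extra_state hooks that
    # Megatron/NeMo distributed checkpoints need on each target module.
    print(f"PEFT plugin ran on {type(model).__name__}")


CUSTOM_MODEL_PLUGINS.add(_prepare_distributed_checkpointing)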
