
Commit 98ef9fb

Update: removed the per-module restore and state

Signed-off-by: Jingyu Xin <[email protected]>
1 parent 1bb3985 commit 98ef9fb

File tree

4 files changed, +32 -227 lines changed


modelopt/torch/peft/conversion.py

Lines changed: 4 additions & 78 deletions
@@ -17,13 +17,11 @@

 import fnmatch
 from collections.abc import Callable, Iterable
-from typing import Any

 import torch.nn as nn

-from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager
+from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager
 from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
-from modelopt.torch.utils import get_unwrapped_name

 from .config import PEFTConfig
 from .lora.layer import LoRAModule, LoRAModuleRegistry
@@ -34,7 +32,6 @@
     "replace_lora_module",
     "unfreeze_base_weights",
     "unfreeze_lora_weights",
-    "update_peft_metadata_in_model",
 ]


@@ -48,64 +45,17 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> ConvertReturnType:
     metadata = {}
     add_adapter(model, config)
     update_grads(model, config)
-    update_peft_metadata(model, config, metadata)

     return model, metadata


 def restore_peft_model(
     model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
 ) -> nn.Module:
-    convert_to_peft_model(model, config)
-    return restore_peft_state(model, metadata)
-
-
-def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict):
-    """Restore PEFT state from metadata or extra_state.
-
-    For backward compatibility, we check metadata first. For distributed
-    checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule
-    and will be restored automatically via set_extra_state() during load_state_dict().
-
-    Args:
-        model: Model with LoRA modules to restore
-        metadata: Metadata dictionary that may contain peft_state
-    Returns:
-        The model with restored PEFT state
-    """
-    if "peft_state" not in metadata:
-        # For distributed checkpoints (NeMo-MCore), peft_state is stored
-        # in each LoRAModule's extra_state and will be restored via
-        # set_extra_state() during load_state_dict()
-        return model
-
-    # Legacy path: restore from metadata
-    peft_state_dict = metadata["peft_state"]
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            unwrapped_name = get_unwrapped_name(name)
-            if unwrapped_name in peft_state_dict:
-                try:
-                    module.set_from_peft_state(peft_state_dict[unwrapped_name])
-                except Exception as e:
-                    raise ApplyModeError(f"Failed to restore PEFT state for module {name}: {e}")
-
+    model, _ = convert_to_peft_model(model, config)
     return model


-def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None:
-    """Update the PEFT/LoRA state in the metadata dict."""
-    metadata["peft_state"] = peft_state(model)
-
-
-def peft_state(model: nn.Module) -> dict[str, Any]:
-    return {
-        get_unwrapped_name(n): m.get_peft_state()
-        for n, m in model.named_modules()
-        if isinstance(m, LoRAModule)
-    }
-
-
 def replace_lora_module(
     model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry
 ):
@@ -137,32 +87,8 @@ def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry):
         _replace_lora_module(getattr(model, name), version=version, registry=registry)


-def update_peft_metadata_in_model(model: nn.Module) -> None:
-    """Update the PEFT metadata in the model's ModeloptStateManager.
-
-    This function should be called after manually modifying LoRA adapters to ensure
-    the metadata stored in the ModeloptStateManager reflects the current state.
-
-    Args:
-        model: Model with LoRA modules whose metadata needs updating
-    Example:
-        >>> # After manually adding/modifying adapters
-        >>> for module in model.modules():
-        ...     if isinstance(module, LoRAModule):
-        ...         module.update_layer_lora("custom_adapter", rank=32)
-        >>> # Update metadata to reflect changes
-        >>> update_peft_metadata_in_model(model)
-    """
-    # Check if model has ModeloptStateManager (has been converted with peft mode)
-    if not ModeloptStateManager.is_converted(model):
-        return
-
-    # Get the state manager
-    manager = ModeloptStateManager(model)
-
-    # Update the metadata with current PEFT state
-    if manager._state and manager._last_metadata is not None:
-        manager._last_metadata["peft_state"] = peft_state(model)
+def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None:
+    """Placeholder for the metadata-related function; not needed in this mode."""


 def add_adapter(model, config: PEFTConfig):
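
Note on the simplified restore path: restore_peft_model now just re-runs the conversion, so the adapter structure comes from the PEFTConfig and the adapter weights travel through the regular state_dict load. A minimal caller-side sketch of that round trip follows; the helper names are illustrative, these entrypoints are normally driven by the modelopt mode framework rather than called directly, and passing a plain nn.Module here is an assumption.

import torch
import torch.nn as nn

from modelopt.torch.peft.config import PEFTConfig
from modelopt.torch.peft.conversion import convert_to_peft_model


def save_lora_checkpoint(model: nn.Module, path: str) -> None:
    # Only weights go into the checkpoint; the adapter layout is implied by
    # the PEFTConfig used at convert time, not by any per-module metadata.
    torch.save(model.state_dict(), path)


def restore_lora_model(base_model: nn.Module, config: PEFTConfig, path: str) -> nn.Module:
    # Re-run the conversion to recreate the LoRA modules from the config,
    # then load the trained adapter (and base) weights from the checkpoint.
    model, _ = convert_to_peft_model(base_model, config)
    model.load_state_dict(torch.load(path))
    return model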

modelopt/torch/peft/convert.py

Lines changed: 0 additions & 41 deletions
@@ -39,7 +39,6 @@
 __all__ = [
     "disable_adapters",
     "enable_adapters",
-    "get_adapter_states",
     "is_peft_model",
     "update_model",
 ]
@@ -192,46 +191,6 @@ def enable_adapters(model, layers_to_enable=None, adapters_to_enable=None):
        )


-def get_adapter_states(model):
-    """Get the current state of all adapters in the model.
-
-    Args:
-        model: Model with LoRA adapters
-
-    Returns:
-        Dict mapping module names to their adapter states
-
-    Example:
-        >>> states = get_adapter_states(model)
-        >>> print(states)
-        {
-            'transformer.layers.0.attention': {
-                'default': {'enabled': True, 'rank': 32},
-                'finetuned': {'enabled': False, 'rank': 64}
-            },
-            'transformer.layers.0.mlp': {
-                'default': {'enabled': True, 'rank': 32}
-            }
-        }
-    """
-    assert is_peft_model(model), "It's not a MO-PEFT model"
-
-    adapter_states = {}
-    for module_name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            module_adapters = {}
-            for adapter_name, adapter_dict in module._lora_adapters.items():
-                module_adapters[adapter_name] = {
-                    "enabled": adapter_dict.get("enable", True),
-                    "rank": adapter_dict.get("rank", "unknown"),
-                    "scale": adapter_dict.get("scale", 1.0),
-                }
-            if module_adapters:
-                adapter_states[module_name] = module_adapters
-
-    return adapter_states
-
-
 def is_megatron_core_model(model) -> bool:
     if MEGATRON_LAYERS:
         for m in model.modules():
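
Note: with get_adapter_states removed from the public API, the same summary can still be assembled directly from the LoRA modules. A minimal sketch along the lines of the deleted helper, assuming _lora_adapters keeps the enable/rank/scale keys shown in the removed code above:

from modelopt.torch.peft.lora.layer import LoRAModule


def summarize_adapters(model):
    # Collect per-adapter settings from every LoRAModule in the model.
    summary = {}
    for name, module in model.named_modules():
        if isinstance(module, LoRAModule):
            summary[name] = {
                adapter: {
                    "enabled": cfg.get("enable", True),
                    "rank": cfg.get("rank"),
                    "scale": cfg.get("scale", 1.0),
                }
                for adapter, cfg in module._lora_adapters.items()
            }
    return summary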

modelopt/torch/peft/lora/layer.py

Lines changed: 5 additions & 103 deletions
@@ -1,7 +1,6 @@
11
"""LoRA (Low-Rank Adaptation) module implementation."""
22

33
import math
4-
import warnings
54
from abc import abstractmethod
65
from typing import Any
76

@@ -28,6 +27,11 @@ def get_init_methods(init_method: str = "kaiming_init"):
2827
) # LoRA A: Kaiming uniform
2928
elif init_method == "zero_init":
3029
return lambda weight: init.zeros_(weight) # LoRA B: zeros
30+
else:
31+
raise ValueError(
32+
f"Unsupported initialization method: '{init_method}'. "
33+
"Supported methods: 'kaiming_init', 'zero_init'"
34+
)
3135

3236

3337
class LoRAModule(DynamicModule):
@@ -98,108 +102,6 @@ def update_layer_lora(
98102
"""
99103
raise NotImplementedError("Subclasses must implement update_layer_lora")
100104

101-
def get_peft_state(self) -> dict[str, Any]:
102-
"""Get PEFT/LoRA state to be saved in checkpoint.
103-
104-
This method returns the configuration and state of all LoRA adapters
105-
without including the actual weight tensors.
106-
107-
Returns:
108-
Dictionary containing:
109-
- adapters: Dict mapping adapter names to their configuration
110-
"""
111-
modelopt_state = {}
112-
113-
# Store adapter configurations
114-
adapters_config = {}
115-
for adapter_name, adapter_modules in self._lora_adapters.items():
116-
lora_a = adapter_modules["lora_a"]
117-
lora_b = adapter_modules["lora_b"]
118-
119-
# Get explicitly stored rank for reliability
120-
rank = adapter_modules.get("rank", None)
121-
122-
# If rank is not stored (legacy case), try to infer it
123-
if rank is None:
124-
if hasattr(lora_a, "output_size"):
125-
rank = lora_a.output_size
126-
elif hasattr(lora_b, "input_size"):
127-
rank = lora_b.input_size
128-
elif hasattr(lora_a, "out_features"):
129-
rank = lora_a.out_features
130-
elif hasattr(lora_b, "in_features"):
131-
rank = lora_b.in_features
132-
133-
adapters_config[adapter_name] = {
134-
"rank": rank,
135-
"enable": adapter_modules.get("enable", True),
136-
"scale": adapter_modules.get("scale", 1.0),
137-
}
138-
139-
modelopt_state["adapters"] = adapters_config
140-
141-
return modelopt_state
142-
143-
def get_extra_state(self) -> dict[str, Any]:
144-
"""Get extra state for distributed checkpointing.
145-
146-
For distributed/sharded checkpoints (like NeMo-MCore), we store the PEFT state
147-
as extra_state instead of in metadata. This handles cases where module names
148-
change with different parallelism settings (TP, PP, EP).
149-
150-
Returns:
151-
Dictionary containing the PEFT/LoRA adapter state
152-
"""
153-
# Only return state if we have adapters
154-
if not self._lora_adapters:
155-
return {}
156-
157-
# Get the current PEFT state
158-
peft_state = self.get_peft_state()
159-
160-
return {"modelopt_peft_state": peft_state}
161-
162-
def set_from_peft_state(self, peft_state: dict[str, Any]) -> None:
163-
"""Restore LoRA adapters from saved PEFT state.
164-
165-
This method recreates LoRA adapters based on their saved configuration.
166-
Note: This only restores the adapter structure, not the weights.
167-
168-
Args:
169-
peft_state: Dictionary containing adapter configurations
170-
"""
171-
adapters_config = peft_state.get("adapters", {})
172-
173-
for adapter_name, config in adapters_config.items():
174-
if adapter_name not in self._lora_adapters:
175-
self.update_layer_lora(adapter_name, config)
176-
177-
def set_extra_state(self, state: dict[str, Any]) -> None:
178-
"""Restore extra state for distributed checkpointing.
179-
180-
This method is called during load_state_dict() to restore the PEFT/LoRA state
181-
from distributed checkpoints. It handles the adapter configuration but not
182-
the actual weights (which are restored through the normal state_dict mechanism).
183-
184-
Args:
185-
state: Dictionary containing the extra state to restore
186-
"""
187-
if state is None:
188-
return
189-
190-
peft_state = state.get("modelopt_peft_state")
191-
if peft_state is None:
192-
return
193-
194-
# Restore the PEFT state
195-
try:
196-
self.set_from_peft_state(peft_state)
197-
except Exception as e:
198-
warnings.warn(
199-
f"Failed to restore PEFT state from extra_state: {e}. "
200-
"This might happen if the model structure has changed."
201-
)
202-
203105
def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
204106
"""Forward pass with LoRA adaptation.
205107
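
Note: the new else branch makes get_init_methods fail loudly instead of silently returning None for an unknown method name. A quick sketch of the resulting behavior, assuming the function is importable from the module as shown:

from modelopt.torch.peft.lora.layer import get_init_methods

init_a = get_init_methods("kaiming_init")  # Kaiming-uniform init for LoRA A
init_b = get_init_methods("zero_init")     # zero init for LoRA B

try:
    get_init_methods("xavier_init")        # unsupported name now raises
except ValueError as err:
    print(err)  # Unsupported initialization method: 'xavier_init'. ...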

modelopt/torch/peft/mode.py

Lines changed: 23 additions & 5 deletions
@@ -1,51 +1,69 @@
1+
"""PEFT mode descriptors for model optimization."""
2+
13
from modelopt.torch.opt.config import ModeloptBaseConfig
24
from modelopt.torch.opt.mode import (
35
ConvertEntrypoint,
4-
ConvertReturnType,
5-
ModeConfigList,
66
ModeDescriptor,
77
RestoreEntrypoint,
88
UpdateEntrypoint,
99
_ModeRegistryCls,
1010
)
11-
from .config import PEFTConfig, ExportPEFTConfig
12-
from .conversion import convert_to_peft_model, restore_peft_model, update_peft_metadata, export_peft_model, restore_export_peft_model
11+
12+
from .config import ExportPEFTConfig, PEFTConfig
13+
from .conversion import (
14+
convert_to_peft_model,
15+
export_peft_model,
16+
restore_export_peft_model,
17+
restore_peft_model,
18+
update_peft_metadata,
19+
)
1320

1421
PEFTModeRegistry = _ModeRegistryCls("PEFT")
1522

23+
1624
@PEFTModeRegistry.register_mode
1725
class PEFTModeDescriptor(ModeDescriptor):
26+
"""Mode descriptor for PEFT (Parameter-Efficient Fine-Tuning) mode."""
27+
1828
@property
1929
def name(self) -> str:
30+
"""Returns the value (str representation) of the mode."""
2031
return "peft"
2132

2233
@property
2334
def config_class(self) -> type[ModeloptBaseConfig]:
35+
"""Specifies the config class for the mode."""
2436
return PEFTConfig
2537

2638
@property
2739
def export_mode(self) -> str | None:
40+
"""Specifies the export mode name for this mode."""
2841
return "export_peft"
2942

3043
@property
3144
def convert(self) -> ConvertEntrypoint:
45+
"""The mode's entrypoint for converting a model."""
3246
return convert_to_peft_model
3347

3448
@property
3549
def restore(self) -> RestoreEntrypoint:
50+
"""The mode's entrypoint for restoring a model."""
3651
return restore_peft_model
3752

3853
@property
3954
def update_for_save(self) -> UpdateEntrypoint:
55+
"""The mode's entrypoint for updating the model's state before saving."""
4056
return update_peft_metadata
4157

4258
@property
4359
def update_for_new_mode(self) -> UpdateEntrypoint:
4460
"""The mode's entrypoint for updating the models state before new mode."""
4561
return update_peft_metadata
4662

63+
4764
@PEFTModeRegistry.register_mode
4865
class ExportPEFTModeDescriptor(ModeDescriptor):
66+
"""Mode descriptor for exporting PEFT models."""
4967

5068
@property
5169
def name(self) -> str:
@@ -70,4 +88,4 @@ def convert(self) -> ConvertEntrypoint:
7088
@property
7189
def restore(self) -> RestoreEntrypoint:
7290
"""The mode's entrypoint for restoring a model."""
73-
return restore_export_peft_model
91+
return restore_export_peft_model
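
Note on how these descriptors are exercised: entering the "peft" mode dispatches to convert_to_peft_model, saving calls update_peft_metadata (now a no-op placeholder), and reloading a checkpoint calls restore_peft_model. A hedged usage sketch follows; the update_model re-export and the config fields below are assumptions, not shown in this diff.

import torch.nn as nn

import modelopt.torch.peft as mtpf  # assumption: the package re-exports update_model

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Hypothetical adapter config; the real PEFTConfig fields may differ.
lora_config = {"adapter_name": "default", "rank": 32}

# Entering the "peft" mode runs PEFTModeDescriptor.convert (convert_to_peft_model);
# restore_peft_model runs later when a saved modelopt state is reloaded.
model = mtpf.update_model(model, lora_config)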

0 commit comments
