"""Quantization conversion/restore utilities."""

-import fnmatch
-from collections.abc import Callable
-from contextlib import contextmanager
from typing import Any

import torch.nn as nn

from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager
-from modelopt.torch.opt.dynamic import _DMRegistryCls
from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
from modelopt.torch.utils import get_unwrapped_name

-from .config import (
-    PEFTConfig,
-    _QuantizeExportConfig,
-)
-from .lora.layer import LoRAModuleRegistry
+from .config import PEFTConfig, _QuantizeExportConfig
+from .lora.layer import LoRAModule, LoRAModuleRegistry

__all__ = [
    "replace_lora_module",
+    "update_peft_metadata_in_model",
]

@@ -48,46 +42,88 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
    # set_quantizer_by_cfg(model, config.get("quant_cfg", {}))

    metadata = {}
-    # update_quantize_metadata(model, config, metadata)
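+    # Record the current LoRA adapter state in metadata so it can be restored later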
+    update_peft_metadata(model, config, metadata)

    return model, metadata

+

def restore_peft_model(
    model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
) -> nn.Module:
-    #TODO: implemente the restore logic
-    pass
-
-
-
-def update_peft_metadata(
-    model: nn.Module, config: PEFTConfig, metadata: MetadataDict
-) -> None:
-    """Update the quantizer state in the metadata dict."""
-    pass
-
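+    """Restore a PEFT model: re-apply the LoRA conversion, then reload the saved adapter state."""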
+    convert_to_peft_model(model, config)
+    return restore_peft_state(model, metadata)
+
+
+def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict):
+    """Restore PEFT state from metadata or extra_state.
+
+    For backward compatibility, we check metadata first. For distributed
+    checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule
+    and will be restored automatically via set_extra_state() during load_state_dict().
+
+    Args:
+        model: Model with LoRA modules to restore
+        metadata: Metadata dictionary that may contain peft_state
+
+    Returns:
+        The model with restored PEFT state
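+
+    Example (sketch; assumes modelopt's standard save/restore entry points,
+    which re-apply the peft mode and invoke this restore path):
+        >>> import modelopt.torch.opt as mto
+        >>> model = mto.restore(model, "lora_checkpoint.pth")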
+ """
69
+ if "peft_state" not in metadata :
70
+ # For distributed checkpoints (NeMo-MCore), peft_state is stored
71
+ # in each LoRAModule's extra_state and will be restored via
72
+ # set_extra_state() during load_state_dict()
73
+ return model
74
+
75
+ # Legacy path: restore from metadata
76
+ peft_state_dict = metadata ["peft_state" ]
77
+ for name , module in model .named_modules ():
78
+ if isinstance (module , LoRAModule ):
79
+ unwrapped_name = get_unwrapped_name (name )
80
+ if unwrapped_name in peft_state_dict :
81
+ try :
82
+ module .set_from_peft_state (peft_state_dict [unwrapped_name ])
83
+ except Exception as e :
84
+ raise ApplyModeError (f"Failed to restore PEFT state for module { name } : { e } " )
85
+
86
+ return model
87
+
88
+
89
+def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None:
+    """Update the PEFT/LoRA state in the metadata dict."""
+    metadata["peft_state"] = peft_state(model)
+
+
+def peft_state(model: nn.Module) -> dict[str, Any]:
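+    """Return a dict mapping unwrapped module names to each LoRAModule's PEFT state."""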
+    return {
+        get_unwrapped_name(n): m.get_peft_state()
+        for n, m in model.named_modules()
+        if isinstance(m, LoRAModule)
+    }
+
+
+def replace_lora_module(
+    model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry
+):
+    """Recursively replace supported modules with LoRA modules."""
+    # Register custom plugins (e.g., for Megatron distributed checkpointing)
+    from .custom import register_custom_model_plugins_on_the_fly

-def replace_lora_module(model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry):
-    """Recursively replace the module with quantized module."""
-    #TODO: register the extra state for megatron-lm
+    register_custom_model_plugins_on_the_fly(model)

    if type(model) in registry:
        model = registry.convert(model)
    _replace_lora_module(model, version=version, registry=registry)

+
def export_peft_model(model: nn.Module, config):
    """Export the quantized model to a quantized model."""
    raise NotImplementedError("Exporting a quantized model is not supported yet.")


-def restore_export_peft_model(
-    model: nn.Module, config, metadata: MetadataDict
-):
+def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict):
    """Restores the quantized model from the given state dict."""
    raise NotImplementedError("Restoring a quantized & exported model is not supported yet.")


-def _replace_lora_module(model: nn.Module, version=None,registry=LoRAModuleRegistry):
+def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry):
    for name, child in model.named_children():
        if type(child) in registry:
            lora_module = registry.convert(child)
@@ -106,3 +142,30 @@ def restore_export_quantized_model(
) -> nn.Module:
    """Restores the quantized model from the given state dict."""
    raise NotImplementedError("Restoring a quantized & exported model is not supported yet.")
+
+
+def update_peft_metadata_in_model(model: nn.Module) -> None:
+    """Update the PEFT metadata in the model's ModeloptStateManager.
+
+    This function should be called after manually modifying LoRA adapters to ensure
+    the metadata stored in the ModeloptStateManager reflects the current state.
+
+    Args:
+        model: Model with LoRA modules whose metadata needs updating
+
+    Example:
+        >>> # After manually adding/modifying adapters
+        >>> for module in model.modules():
+        ...     if isinstance(module, LoRAModule):
+        ...         module.update_layer_lora("custom_adapter", rank=32)
+        >>> # Update metadata to reflect changes
+        >>> update_peft_metadata_in_model(model)
+    """
+    # Check if model has a ModeloptStateManager (has been converted with the peft mode)
+    if not ModeloptStateManager.is_converted(model):
+        return
+
+    # Get the state manager
+    manager = ModeloptStateManager(model)
+
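+    # NOTE: this reaches into private ModeloptStateManager attributes
+    # (_state, _last_metadata) to keep the stored metadata in sync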
+    # Update the metadata with the current PEFT state
+    if manager._state and manager._last_metadata is not None:
+        manager._last_metadata["peft_state"] = peft_state(model)