@@ -73,7 +73,7 @@ def deactivate_all_adapters(self) -> None:
         self._active_adapters.clear()

     def _register_adapter(
-        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int
+        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float = 1.0
     ) -> None:
         """Register a new LoRA adapter with explicit rank tracking.

@@ -82,6 +82,7 @@ def _register_adapter(
             lora_a: LoRA A module (down-projection)
             lora_b: LoRA B module (up-projection)
             rank: Rank of the LoRA decomposition
+            scale: Scale factor for the LoRA output
         """
         # Add as submodules for proper parameter registration
         self.add_module(f"lora_a_{adapter_name}", lora_a)
@@ -92,13 +93,14 @@ def _register_adapter(
             "lora_a": lora_a,
             "lora_b": lora_b,
             "rank": rank,  # Store rank explicitly for reliability
+            "scale": scale,
         }

         # Automatically activate new adapters
         self.activate_adapter(adapter_name)

     @abstractmethod
-    def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None:
+    def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.0) -> None:
         """Create and register a new LoRA adapter.

         This method must be implemented by subclasses to create the appropriate
@@ -107,6 +109,7 @@ def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None:
         Args:
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition (default: 64)
+            scale: Scale factor for the LoRA output (default: 1.0)
         """
         raise NotImplementedError("Subclasses must implement update_layer_lora")

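
For context on what the new scale argument typically represents: in the original LoRA formulation the low-rank update is weighted by a factor, commonly alpha / rank, before it is added to the base output. A minimal sketch of that convention using plain PyTorch; the diff itself just threads a float through and does not force this particular mapping.

import torch
import torch.nn as nn

# Standard LoRA update for a linear layer: y = W x + scale * B(A(x)),
# where scale is often chosen as alpha / rank.
rank, alpha = 16, 32
scale = alpha / rank  # 2.0

x = torch.randn(4, 64)
base = nn.Linear(64, 64)
lora_a = nn.Linear(64, rank, bias=False)   # down-projection (A)
lora_b = nn.Linear(rank, 64, bias=False)   # up-projection (B)
nn.init.zeros_(lora_b.weight)              # zero-init B so the update starts as a no-op

y = base(x) + scale * lora_b(lora_a(x))
assert torch.allclose(y, base(x))          # holds only while B is still zero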
@@ -148,14 +151,12 @@ def get_peft_state(self) -> dict[str, Any]:
                 "is_active": adapter_name in self._active_adapters,
                 "lora_a_type": type(lora_a).__name__,
                 "lora_b_type": type(lora_b).__name__,
+                "scale": adapter_modules.get("scale", 1.0),
             }

         modelopt_state["adapters"] = adapters_config
         modelopt_state["active_adapters"] = list(self._active_adapters)

-        # Store the base module type for validation
-        modelopt_state["base_module_type"] = type(self).__name__
-
         return modelopt_state

     def get_extra_state(self) -> dict[str, Any]:
@@ -177,6 +178,36 @@ def get_extra_state(self) -> dict[str, Any]:

         return {"modelopt_peft_state": peft_state}

+    def set_from_peft_state(self, peft_state: dict[str, Any]) -> None:
+        """Restore LoRA adapters from saved PEFT state.
+
+        This method recreates LoRA adapters based on their saved configuration.
+        Note: This only restores the adapter structure, not the weights.
+
+        Args:
+            peft_state: Dictionary containing adapter configurations
+        """
+        adapters_config = peft_state.get("adapters", {})
+
+        # Clear existing adapters first
+        self._lora_adapters.clear()
+        self._active_adapters.clear()
+
+        # Recreate each adapter based on saved configuration
+        for adapter_name, config in adapters_config.items():
+            rank = config.get("rank")
+            scale = config.get("scale", 1.0)
+
+            if rank is not None:
+                # Create the adapter with saved configuration
+                self.update_layer_lora(adapter_name, rank=rank, scale=scale)
+
+                # Set activation state
+                if config.get("is_active", False):
+                    self.activate_adapter(adapter_name)
+                else:
+                    self.deactivate_adapter(adapter_name)
+
     def set_extra_state(self, state: dict[str, Any]) -> None:
         """Restore extra state for distributed checkpointing.

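
A rough illustration of the state these methods now round-trip, with scale carried per adapter. The dictionary below is made-up sample data in the shape that get_peft_state() builds, and the loop mirrors what set_from_peft_state() does with it; the actual update_layer_lora call is left commented out because it needs a concrete layer subclass.

saved_state = {
    "adapters": {
        "default": {"rank": 64, "scale": 1.0, "is_active": True,
                    "lora_a_type": "Linear", "lora_b_type": "Linear"},
        "draft": {"rank": 16, "scale": 0.5, "is_active": False,
                  "lora_a_type": "Linear", "lora_b_type": "Linear"},
    },
    "active_adapters": ["default"],
}

for adapter_name, config in saved_state["adapters"].items():
    rank = config.get("rank")
    scale = config.get("scale", 1.0)
    if rank is not None:
        # module.update_layer_lora(adapter_name, rank=rank, scale=scale)  # restores structure only, not weights
        state = "active" if config.get("is_active", False) else "inactive"
        print(f"restore {adapter_name}: rank={rank}, scale={scale}, {state}")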
@@ -245,7 +276,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
             if isinstance(lora_b_output, tuple):
                 lora_b_output = lora_b_output[0]

-            result = result + lora_b_output
+            scale = adapter.get("scale", 1.0)
+            result = result + scale * lora_b_output

         # Return output in the same format as the base layer
         if other_outputs:
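
Putting the pieces together, a self-contained toy version of the behavior this change adds: the scale recorded at registration time multiplies the LoRA branch in forward, matching result = result + scale * lora_b_output above. The class name and structure below are illustrative assumptions, not the module from this repository.

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Illustrative stand-in for a LoRA-wrapped linear layer with per-adapter scale."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self._lora_adapters: dict[str, dict] = {}
        self._active_adapters: set[str] = set()

    def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.0) -> None:
        lora_a = nn.Linear(self.base.in_features, rank, bias=False)   # down-projection
        lora_b = nn.Linear(rank, self.base.out_features, bias=False)  # up-projection
        nn.init.zeros_(lora_b.weight)  # zero-init so a fresh adapter starts as a no-op
        self.add_module(f"lora_a_{adapter_name}", lora_a)
        self.add_module(f"lora_b_{adapter_name}", lora_b)
        self._lora_adapters[adapter_name] = {
            "lora_a": lora_a, "lora_b": lora_b, "rank": rank, "scale": scale,
        }
        self._active_adapters.add(adapter_name)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base(x)
        for name in self._active_adapters:
            adapter = self._lora_adapters[name]
            scale = adapter.get("scale", 1.0)
            result = result + scale * adapter["lora_b"](adapter["lora_a"](x))
        return result


layer = ToyLoRALinear(16, 16)
layer.update_layer_lora("default", rank=4, scale=0.5)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])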