@@ -31,7 +31,6 @@ class LoRAModule(DynamicModule):
     def _setup(self) -> None:
         """Initialize LoRA-specific attributes."""
         self._lora_adapters: dict[str, dict[str, Any]] = {}
-        self._active_adapters: set = set()
 
     @property
     def adapter_names(self) -> set:
@@ -43,39 +42,14 @@ def active_adapters(self) -> set:
         """Return the set of currently active adapter names."""
         return self._active_adapters.copy()
 
-    def activate_adapter(self, adapter_name: str) -> None:
-        """Activate a specific adapter.
-
-        Args:
-            adapter_name: Name of the adapter to activate
-
-        Raises:
-            ValueError: If adapter_name is not registered
-        """
-        if adapter_name not in self._lora_adapters:
-            raise ValueError(
-                f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}"
-            )
-        self._active_adapters.add(adapter_name)
-
-    def deactivate_adapter(self, adapter_name: str) -> None:
-        """Deactivate a specific adapter.
-
-        Args:
-            adapter_name: Name of the adapter to deactivate
-        """
-        self._active_adapters.discard(adapter_name)
-
-    def activate_all_adapters(self) -> None:
-        """Activate all registered adapters."""
-        self._active_adapters = self.adapter_names.copy()
-
-    def deactivate_all_adapters(self) -> None:
-        """Deactivate all adapters."""
-        self._active_adapters.clear()
-
     def _register_adapter(
-        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float = 1.0
+        self,
+        adapter_name: str,
+        lora_a: nn.Module,
+        lora_b: nn.Module,
+        rank: int,
+        scale: float = 1.0,
+        enable: bool = True,
     ) -> None:
         """Register a new LoRA adapter with explicit rank tracking.
 
@@ -86,21 +60,18 @@ def _register_adapter(
             rank: Rank of the LoRA decomposition
             scale: Scale factor for the LoRA output
         """
-        # Add as submodules for proper parameter registration
         self.add_module(f"lora_a_{adapter_name}", lora_a)
         self.add_module(f"lora_b_{adapter_name}", lora_b)
 
         # Store in adapter dictionary with explicit rank
         self._lora_adapters[adapter_name] = {
             "lora_a": lora_a,
             "lora_b": lora_b,
-            "rank": rank,  # Store rank explicitly for reliability
+            "rank": rank,
             "scale": scale,
+            "enable": enable,
         }
 
-        # Automatically activate new adapters
-        self.activate_adapter(adapter_name)
-
     @abstractmethod
     def update_layer_lora(
         self,
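
Note (not part of the diff): a minimal sketch of the adapter entry that `_register_adapter` now stores, with the new per-adapter `enable` flag replacing the removed module-level `_active_adapters` set. The concrete layer sizes and the direct flag toggle at the end are assumptions; the diff does not show a replacement for the removed activate/deactivate helpers.

# Hypothetical illustration only; sizes and names are made up.
import torch.nn as nn

lora_a = nn.Linear(16, 4, bias=False)  # down-projection, rank 4
lora_b = nn.Linear(4, 16, bias=False)  # up-projection back to 16 features

# Shape of a self._lora_adapters entry after
# _register_adapter("default", lora_a, lora_b, rank=4, scale=1.0, enable=True):
_lora_adapters = {
    "default": {
        "lora_a": lora_a,
        "lora_b": lora_b,
        "rank": 4,
        "scale": 1.0,
        "enable": True,  # replaces membership in the removed _active_adapters set
    }
}

# Presumably an adapter is now toggled by flipping this flag directly
# (assumption; no replacement helper appears in this diff):
_lora_adapters["default"]["enable"] = False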
@@ -156,14 +127,11 @@ def get_peft_state(self) -> dict[str, Any]:
 
         adapters_config[adapter_name] = {
             "rank": rank,
-            "is_active": adapter_name in self._active_adapters,
-            "lora_a_type": type(lora_a).__name__,
-            "lora_b_type": type(lora_b).__name__,
+            "enable": adapter_modules.get("enable", True),
             "scale": adapter_modules.get("scale", 1.0),
         }
 
         modelopt_state["adapters"] = adapters_config
-        modelopt_state["active_adapters"] = list(self._active_adapters)
 
         return modelopt_state
 
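
Note (not part of the diff): a sketch of the per-adapter entries `get_peft_state` emits after this change, with `enable` in place of the old `is_active`, `lora_a_type`, `lora_b_type`, and `active_adapters` fields. The values below are illustrative.

# Illustrative structure only; values are made up.
modelopt_state = {
    "adapters": {
        "default": {
            "rank": 4,        # stored explicitly at registration time
            "enable": True,   # per-adapter flag serialized with the state
            "scale": 1.0,
        },
    },
    # ...other modelopt_state fields are untouched by this hunk
}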
@@ -246,41 +214,29 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
         Returns:
             Output from the base layer plus active LoRA adaptations
         """
-        # Call the base layer's forward method
         output = super().forward(x, *args, **kwargs)
 
-        # Handle different output types from base layer
         if isinstance(output, tuple):
-            # If output is a tuple, assume first element is the main result
             result = output[0]
             other_outputs = output[1:]
         else:
-            # If output is a single tensor
             result = output
             other_outputs = ()
 
-        # Apply active LoRA adapters
-        if self._active_adapters and self._lora_adapters:
-            for adapter_name in self._active_adapters:
-                if adapter_name in self._lora_adapters:
-                    adapter = self._lora_adapters[adapter_name]
-                    # LoRA computation: result = result + B(A(x))
-                    lora_a = adapter["lora_a"]
-                    lora_b = adapter["lora_b"]
-
-                    # Handle different forward signatures
-                    lora_a_output = lora_a(x)
-                    if isinstance(lora_a_output, tuple):
-                        lora_a_output = lora_a_output[0]
-
-                    lora_b_output = lora_b(lora_a_output)
-                    if isinstance(lora_b_output, tuple):
-                        lora_b_output = lora_b_output[0]
-
-                    scale = adapter.get("scale", 1.0)
-                    result = result + scale * lora_b_output
+        for adapter_name in self._lora_adapters:
+            adapter = self._lora_adapters[adapter_name]
+            if adapter["enable"]:
+                lora_a = adapter["lora_a"]
+                lora_b = adapter["lora_b"]
+                lora_a_output = lora_a(x)
+                if isinstance(lora_a_output, tuple):
+                    lora_a_output = lora_a_output[0]
+                lora_b_output = lora_b(lora_a_output)
+                if isinstance(lora_b_output, tuple):
+                    lora_b_output = lora_b_output[0]
+                scale = adapter["scale"]
+                result = result + scale * lora_b_output
 
-        # Return output in the same format as the base layer
         if other_outputs:
             return (result, *other_outputs)
         else:
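
Note (not part of the diff): a standalone sketch of the updated forward logic, output = base(x) + scale * lora_b(lora_a(x)) for every adapter whose "enable" flag is True. The plain nn.Linear base layer and adapter sizes are stand-ins, not the library's API.

# Self-contained illustration with made-up layer sizes.
import torch
import torch.nn as nn

base = nn.Linear(16, 16)
adapters = {
    "default": {
        "lora_a": nn.Linear(16, 4, bias=False),
        "lora_b": nn.Linear(4, 16, bias=False),
        "scale": 1.0,
        "enable": True,
    }
}

x = torch.randn(2, 16)
result = base(x)
for adapter in adapters.values():
    if adapter["enable"]:  # disabled adapters are simply skipped
        result = result + adapter["scale"] * adapter["lora_b"](adapter["lora_a"](x))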