Commit 019efb0

Update disable/enable logic
Signed-off-by: Jingyu Xin <[email protected]>
1 parent: ca9698d

File tree: 3 files changed, +37 −77 lines

modelopt/torch/peft/conversion.py

Lines changed: 4 additions & 6 deletions
@@ -185,13 +185,11 @@ def add_adapter(model, config: PEFTConfig):
                 continue
             else:
                 raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
-            if adapter_setting.enable:  # type: ignore[union-attr]
-                module.update_layer_lora(
-                    adapter_name,
-                    adapter_setting,
-                )
+            module.update_layer_lora(
+                adapter_name,
+                adapter_setting,
+            )

-    _update_peft_metadata_in_state(model)
     return model

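Note on the conversion.py change: add_adapter previously skipped update_layer_lora when adapter_setting.enable was false, and it also called _update_peft_metadata_in_state before returning. After this commit the adapter is always registered and the enable flag travels with the adapter itself (see layer.py below). A rough sketch of the new behaviour, using toy names rather than the real modelopt classes:

    # Illustrative sketch only; ToyModule and the dict-based setting stand in
    # for the real LoRAModule / PEFTConfig types shown in the diff.
    class ToyModule:
        def __init__(self):
            self._lora_adapters = {}

        def update_layer_lora(self, adapter_name, setting):
            # Registration no longer depends on the enable flag; it is stored instead.
            self._lora_adapters[adapter_name] = {"enable": setting.get("enable", True)}


    module = ToyModule()
    module.update_layer_lora("default", {"enable": False})
    assert "default" in module._lora_adapters                    # still registered
    assert module._lora_adapters["default"]["enable"] is False   # but marked disabled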

modelopt/torch/peft/lora/layer.py

Lines changed: 23 additions & 67 deletions
@@ -31,7 +31,6 @@ class LoRAModule(DynamicModule):
     def _setup(self) -> None:
         """Initialize LoRA-specific attributes."""
         self._lora_adapters: dict[str, dict[str, Any]] = {}
-        self._active_adapters: set = set()

     @property
     def adapter_names(self) -> set:
@@ -43,39 +42,14 @@ def active_adapters(self) -> set:
         """Return the set of currently active adapter names."""
         return self._active_adapters.copy()

-    def activate_adapter(self, adapter_name: str) -> None:
-        """Activate a specific adapter.
-
-        Args:
-            adapter_name: Name of the adapter to activate
-
-        Raises:
-            ValueError: If adapter_name is not registered
-        """
-        if adapter_name not in self._lora_adapters:
-            raise ValueError(
-                f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}"
-            )
-        self._active_adapters.add(adapter_name)
-
-    def deactivate_adapter(self, adapter_name: str) -> None:
-        """Deactivate a specific adapter.
-
-        Args:
-            adapter_name: Name of the adapter to deactivate
-        """
-        self._active_adapters.discard(adapter_name)
-
-    def activate_all_adapters(self) -> None:
-        """Activate all registered adapters."""
-        self._active_adapters = self.adapter_names.copy()
-
-    def deactivate_all_adapters(self) -> None:
-        """Deactivate all adapters."""
-        self._active_adapters.clear()
-
     def _register_adapter(
-        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float = 1.0
+        self,
+        adapter_name: str,
+        lora_a: nn.Module,
+        lora_b: nn.Module,
+        rank: int,
+        scale: float = 1.0,
+        enable: bool = True,
     ) -> None:
         """Register a new LoRA adapter with explicit rank tracking.

@@ -86,21 +60,18 @@ def _register_adapter(
             rank: Rank of the LoRA decomposition
             scale: Scale factor for the LoRA output
         """
-        # Add as submodules for proper parameter registration
         self.add_module(f"lora_a_{adapter_name}", lora_a)
         self.add_module(f"lora_b_{adapter_name}", lora_b)

         # Store in adapter dictionary with explicit rank
         self._lora_adapters[adapter_name] = {
             "lora_a": lora_a,
             "lora_b": lora_b,
-            "rank": rank,  # Store rank explicitly for reliability
+            "rank": rank,
             "scale": scale,
+            "enable": enable,
         }

-        # Automatically activate new adapters
-        self.activate_adapter(adapter_name)
-
     @abstractmethod
     def update_layer_lora(
         self,
@@ -156,14 +127,11 @@ def get_peft_state(self) -> dict[str, Any]:

             adapters_config[adapter_name] = {
                 "rank": rank,
-                "is_active": adapter_name in self._active_adapters,
-                "lora_a_type": type(lora_a).__name__,
-                "lora_b_type": type(lora_b).__name__,
+                "enable": adapter_modules.get("enable", True),
                 "scale": adapter_modules.get("scale", 1.0),
             }

         modelopt_state["adapters"] = adapters_config
-        modelopt_state["active_adapters"] = list(self._active_adapters)

         return modelopt_state


@@ -246,41 +214,29 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
         Returns:
             Output from the base layer plus active LoRA adaptations
         """
-        # Call the base layer's forward method
         output = super().forward(x, *args, **kwargs)

-        # Handle different output types from base layer
         if isinstance(output, tuple):
-            # If output is a tuple, assume first element is the main result
             result = output[0]
             other_outputs = output[1:]
         else:
-            # If output is a single tensor
             result = output
             other_outputs = ()

-        # Apply active LoRA adapters
-        if self._active_adapters and self._lora_adapters:
-            for adapter_name in self._active_adapters:
-                if adapter_name in self._lora_adapters:
-                    adapter = self._lora_adapters[adapter_name]
-                    # LoRA computation: result = result + B(A(x))
-                    lora_a = adapter["lora_a"]
-                    lora_b = adapter["lora_b"]
-
-                    # Handle different forward signatures
-                    lora_a_output = lora_a(x)
-                    if isinstance(lora_a_output, tuple):
-                        lora_a_output = lora_a_output[0]
-
-                    lora_b_output = lora_b(lora_a_output)
-                    if isinstance(lora_b_output, tuple):
-                        lora_b_output = lora_b_output[0]
-
-                    scale = adapter.get("scale", 1.0)
-                    result = result + scale * lora_b_output
+        for adapter_name in self._lora_adapters:
+            adapter = self._lora_adapters[adapter_name]
+            if adapter["enable"]:
+                lora_a = adapter["lora_a"]
+                lora_b = adapter["lora_b"]
+                lora_a_output = lora_a(x)
+                if isinstance(lora_a_output, tuple):
+                    lora_a_output = lora_a_output[0]
+                lora_b_output = lora_b(lora_a_output)
+                if isinstance(lora_b_output, tuple):
+                    lora_b_output = lora_b_output[0]
+                scale = adapter["scale"]
+                result = result + scale * lora_b_output

-        # Return output in the same format as the base layer
         if other_outputs:
             return (result, *other_outputs)
         else:
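The forward path above now walks every registered adapter and applies only those whose stored "enable" flag is true, instead of consulting a separate _active_adapters set. A minimal self-contained sketch of that pattern, assuming plain nn.Linear adapters (the real base layers can return tuples, which the diff handles but this sketch omits); this is not the modelopt LoRAModule API:

    # Minimal sketch of the "register always, gate at forward" pattern.
    import torch
    import torch.nn as nn


    class ToyLoRALinear(nn.Linear):
        def __init__(self, in_features: int, out_features: int):
            super().__init__(in_features, out_features)
            self._lora_adapters: dict[str, dict] = {}

        def add_adapter(self, name: str, rank: int = 8, scale: float = 1.0, enable: bool = True):
            # The real code also calls self.add_module so adapter weights are tracked.
            self._lora_adapters[name] = {
                "lora_a": nn.Linear(self.in_features, rank, bias=False),
                "lora_b": nn.Linear(rank, self.out_features, bias=False),
                "scale": scale,
                "enable": enable,
            }

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            result = super().forward(x)
            # Every registered adapter is visited; only enabled ones contribute.
            for adapter in self._lora_adapters.values():
                if adapter["enable"]:
                    result = result + adapter["scale"] * adapter["lora_b"](adapter["lora_a"](x))
            return result


    layer = ToyLoRALinear(16, 16)
    layer.add_adapter("on", enable=True)
    layer.add_adapter("off", enable=False)
    y = layer(torch.randn(2, 16))  # only the "on" adapter changes the output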

modelopt/torch/peft/lora/tp_layer.py

Lines changed: 10 additions & 4 deletions
@@ -46,7 +46,13 @@ def _get_init_methods(self, lora_a_init, lora_b_init) -> tuple[Callable, Callabl
         return lora_a_init, lora_b_init

     def _register_adapter_with_device(
-        self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float
+        self,
+        adapter_name: str,
+        lora_a: nn.Module,
+        lora_b: nn.Module,
+        rank: int,
+        scale: float,
+        enable: bool,
     ) -> None:
         """Register LoRA adapter modules and ensure correct device placement.

@@ -78,7 +84,7 @@ def _register_adapter_with_device(
         lora_a = lora_a.to(dtype)
         lora_b = lora_b.to(dtype)

-        super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale)
+        super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale, enable)


 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
@@ -120,7 +126,7 @@ def update_layer_lora(
         )

         self._register_adapter_with_device(
-            adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale
+            adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
         )

@@ -163,7 +169,7 @@ def update_layer_lora(
         )

         self._register_adapter_with_device(
-            adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale
+            adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
         )

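In tp_layer.py the new enable argument is simply threaded from attr_config through _register_adapter_with_device into the base _register_adapter. A simplified sketch of that flow, with the Megatron/device handling elided and a dataclass standing in for the real attribute config (signatures here are illustrative, not the real ones):

    from dataclasses import dataclass


    @dataclass
    class AttrConfig:
        rank: int = 8
        scale: float = 1.0
        enable: bool = True


    registry: dict[str, dict] = {}


    def _register_adapter(name: str, rank: int, scale: float, enable: bool) -> None:
        registry[name] = {"rank": rank, "scale": scale, "enable": enable}


    def _register_adapter_with_device(name: str, rank: int, scale: float, enable: bool) -> None:
        # dtype/device placement elided; the flag is just forwarded
        _register_adapter(name, rank, scale, enable)


    def update_layer_lora(name: str, attr_config: AttrConfig) -> None:
        _register_adapter_with_device(name, attr_config.rank, attr_config.scale, attr_config.enable)


    update_layer_lora("default", AttrConfig(enable=False))
    assert registry["default"]["enable"] is False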
