Commit 37266c1
FIX Revert optimization for LoRA scaling == 1 (#2416)
PR #2404 introduced an optimization for LoRA for the case that scaling == 1 (see #2404 (comment)). Unfortunately, this leads to recompilation when the model is compiled, as witnessed by the failing CI run here: https://github.com/huggingface/peft/actions/runs/13755365121/job/38461837691#step:6:157. For now, let's revert the optimization. If concrete numbers show that the optimization makes a significant difference, we can start thinking about how to optimize this code path in a compile-friendly way.
1 parent 8edaae9 commit 37266c1
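
As context for the recompilation issue, here is a minimal sketch (illustrative only; the lora_like function below is hypothetical and not part of PEFT or the linked CI job) of why a Python-level branch on scaling is hostile to torch.compile: Dynamo guards on the concrete value it branched on, so a scaling value that flips the scaling == 1 test invalidates the guard and forces a fresh compilation.

import torch

def lora_like(x: torch.Tensor, scaling: float) -> torch.Tensor:
    # Branching on a plain Python float makes Dynamo guard on its concrete value.
    if scaling == 1:
        return x * 2            # "optimized" path that skips the no-op multiply
    return x * 2 * scaling      # general path

torch._logging.set_logs(recompiles=True)  # log a reason whenever Dynamo recompiles
compiled = torch.compile(lora_like)

compiled(torch.randn(8), 1.0)  # first call compiles the scaling == 1 branch
compiled(torch.randn(8), 0.5)  # guard on scaling fails -> recompilation

With the revert, the multiplication by scaling is unconditional, so the control flow no longer depends on the value of scaling; whether the value itself still gets specialized into the graph is a separate question, but the branch-induced recompilation goes away.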

1 file changed: +2 -12 lines changed


src/peft/tuners/lora/layer.py

Lines changed: 2 additions & 12 deletions
@@ -489,13 +489,7 @@ def _mixed_batch_forward(
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
             sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
-
-            # Loras such as EoRA will always be scaling == 1 so we can skip the no-op math
-            if scaling == 1:
-                lora_output = lora_B(lora_A(dropout(sub_batch)))
-            else:
-                lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
-
+            lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
             result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)
 
         return result
@@ -730,11 +724,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    # Loras such as EoRA will always be scaling == 1 so we can skip the no-op math
-                    if scaling == 1:
-                        result = result + lora_B(lora_A(dropout(x)))
-                    else:
-                        result = result + lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
                     if isinstance(dropout, nn.Identity) or not self.training:
                         base_result = result
