
Commit bc41800

ENH Allow disabling input dtype casting for LoRA (huggingface#2353)
Provides the disable_input_dtype_casting context manager to prevent the input dtype from being cast during the forward call of a PEFT layer. Normally, the dtypes of the weight and the input need to match, which is why the input dtype is cast. However, in certain circumstances this is handled by forward hooks, e.g. when using layerwise casting in diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to disable it is given. Right now, this only supports LoRA; LoKr and LoHa don't cast the input dtype anyway, so the PEFT methods most relevant for diffusers are covered.
1 parent fac55ff commit bc41800
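
For orientation, a minimal usage sketch (not part of the commit; the model and adapter identifiers below are placeholders) of how the new context manager is intended to be used around a forward pass:

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel
    from peft.helpers import disable_input_dtype_casting

    # Placeholder checkpoints; any LoRA-adapted model works the same way.
    base = AutoModelForCausalLM.from_pretrained("my-base-model")
    model = PeftModel.from_pretrained(base, "my-lora-adapter")

    input_ids = torch.randint(0, base.config.vocab_size, (1, 8))
    with disable_input_dtype_casting(model):
        # Inside the context, LoRA layers skip casting the hidden states to the
        # weight dtype; any required casting is left to external forward hooks
        # (e.g. layerwise casting set up by diffusers).
        output = model(input_ids)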

14 files changed: +181 additions, -30 deletions


docs/source/package_reference/helpers.md

Lines changed: 6 additions & 1 deletion
@@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
 ## Temporarily Rescaling Adapter Scale in LoraLayer Modules
 
 [[autodoc]] helpers.rescale_adapter_scale
-    - all
+    - all
+
+## Context manager to disable input dtype casting in the `forward` method of LoRA layers
+
+[[autodoc]] helpers.disable_input_dtype_casting
+    - all

src/peft/helpers.py

Lines changed: 42 additions & 1 deletion
@@ -18,8 +18,10 @@
 from functools import update_wrapper
 from types import MethodType
 
+from torch import nn
+
 from .peft_model import PeftConfig, PeftModel
-from .tuners.lora.layer import LoraLayer
+from .tuners.lora import LoraLayer
 
 
 def update_forward_signature(model: PeftModel) -> None:
@@ -209,3 +211,42 @@ def rescale_adapter_scale(model, multiplier):
     # restore original scaling values after exiting the context
     for module, scaling in original_scaling.items():
         module.scaling = scaling
+
+
+@contextmanager
+def disable_input_dtype_casting(model: nn.Module, active: bool = True):
+    """
+    Context manager that disables input dtype casting to the dtype of the weight.
+
+    Currently specifically works for LoRA.
+
+    Parameters:
+        model (nn.Module):
+            The model containing PEFT modules whose input dtype casting is to be adjusted.
+        active (bool):
+            Whether the context manager is active (default) or inactive.
+
+    """
+    # Additional info: Normally, the dtype of the weight and input need to match, which is why the dtype is cast.
+    # However, in certain circumstances, this is handled by forward hooks, e.g. when using layerwise casting in
+    # diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to
+    # disable it is given.
+    if not active:
+        yield
+        return
+
+    original_values = {}
+    for name, module in model.named_modules():
+        if not isinstance(module, LoraLayer):
+            continue
+        original_values[name] = module.cast_input_dtype_enabled
+        module.cast_input_dtype_enabled = False
+
+    try:
+        yield
+    finally:
+        for name, module in model.named_modules():
+            if not isinstance(module, LoraLayer):
+                continue
+            if name in original_values:
+                module.cast_input_dtype_enabled = original_values[name]
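
The tuner diffs that follow all route through a shared self._cast_input_dtype(x, dtype) helper instead of casting x inline. That helper is added to the LoRA layer base class in src/peft/tuners/lora/layer.py, which is not included in this excerpt; the sketch below is an assumption about its behavior, not the actual implementation, showing how such a method would honor the cast_input_dtype_enabled flag toggled above:

    import torch

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Sketch only: skip the cast when disable_input_dtype_casting has turned
        # the flag off, or when the input already has the requested dtype;
        # otherwise cast to the compute dtype as the old inline code did.
        if not getattr(self, "cast_input_dtype_enabled", True):
            return x
        if x.dtype == dtype:
            return x
        return x.to(dtype)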

src/peft/tuners/adalora/bnb.py

Lines changed: 1 addition & 3 deletions
@@ -129,9 +129,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.dtype)
 
             output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
             if requires_conversion:

src/peft/tuners/adalora/gptq.py

Lines changed: 1 addition & 2 deletions
@@ -55,8 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                if x.dtype != torch.float32:
-                    x = x.float()
+                x = self._cast_input_dtype(x, torch.float32)
 
             output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
             # TODO: here, the dtype conversion is applied on the *whole expression*,

src/peft/tuners/adalora/layer.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 scaling = self.scaling[active_adapter]
                 ranknum = self.ranknum[active_adapter] + 1e-5
 
-                x = x.to(lora_A.dtype)
+                x = self._cast_input_dtype(x, lora_A.dtype)
                 result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
 
         return result

src/peft/tuners/lora/aqlm.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:

src/peft/tuners/lora/awq.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:

src/peft/tuners/lora/bnb.py

Lines changed: 4 additions & 8 deletions
@@ -204,9 +204,7 @@ def _mixed_batch_forward(
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.weight.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
@@ -243,9 +241,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 requires_conversion = not torch.is_autocast_enabled()
                 if requires_conversion:
                     expected_dtype = result.dtype
-                    compute_dtype = lora_A.weight.dtype
-                    if x.dtype != compute_dtype:
-                        x = x.to(compute_dtype)
+                    x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
                     output = lora_B(lora_A(dropout(x))) * scaling
@@ -470,7 +466,7 @@ def _mixed_batch_forward(
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
@@ -514,7 +510,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 requires_conversion = not torch.is_autocast_enabled()
                 if requires_conversion:
                     expected_dtype = result.dtype
-                    x = x.to(lora_A.weight.dtype)
+                    x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
                     output = lora_B(lora_A(dropout(x))) * scaling

src/peft/tuners/lora/eetq.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:

src/peft/tuners/lora/gptq.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:

0 commit comments