
Commit fe1a9a4 (1 parent: ea76880)

disable peft input autocast

1 file changed: +52 −2 lines

src/diffusers/hooks/layerwise_casting.py

Lines changed: 52 additions & 2 deletions
@@ -17,14 +17,28 @@
 
 import torch
 
-from ..utils import get_logger
+from ..utils import get_logger, is_peft_available
 from .hooks import HookRegistry, ModelHook
 
 
+if is_peft_available():
+    from peft.helpers import disable_lora_input_dtype_casting
+    from peft.tuners.loha.layer import LoHaLayer
+    from peft.tuners.lokr.layer import LoKrLayer
+    from peft.tuners.lora.layer import LoraLayer
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    PEFT_ADAPTER_LAYERS_NAMES = ()
+    for layer_cls in (LoHaLayer, LoKrLayer, LoraLayer):
+        PEFT_ADAPTER_LAYERS_NAMES += tuple(layer_cls.adapter_layer_names)
+
+
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # fmt: off
+_LAYERWISE_CASTING_HOOK = "layerwise_casting"
+_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
 SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -70,6 +84,32 @@ def post_forward(self, module: torch.nn.Module, output):
         return output
 
 
+class PeftInputAutocastDisableHook(ModelHook):
+    r"""
+    A hook that disables the casting of inputs to the module weight dtype for the forward pass. By default, PEFT casts
+    the inputs to the weight dtype of the module, which can lead to precision loss.
+
+    The reasons for needing this are:
+        - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
+          inputs will be casted to the, possibly lower precision, storage dtype. Reference:
+          https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
+        - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
+          that the inputs are casted to the computation dtype correctly always. However, there are two goals we are
+          hoping to achieve:
+              1. Making forward implementations independent of device/dtype casting operations as much as possible.
+              2. Performing inference without losing information from casting to different precisions. With the current
+                 PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
+                 with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are casted to
+                 torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
+                 forward pass with PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
+                 LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
+    """
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        with disable_lora_input_dtype_casting(module, disable=True):
+            return self.fn_ref.original_forward(*args, **kwargs)
+
+
 def apply_layerwise_casting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
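
As a brief aside on the docstring's point about lossy casting: a bfloat16 → float8_e4m3fn → bfloat16 round trip quantizes the values and cannot recover them, which is exactly what the hook above avoids for LoRA inputs. This snippet is illustrative only and not part of the commit; it assumes a PyTorch build with float8_e4m3fn support.

# Round-trip illustration of the precision loss described in the docstring.
import torch

x = torch.randn(4, dtype=torch.bfloat16)
roundtrip = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
print(x)                             # original bf16 values
print(roundtrip)                     # visibly quantized after the fp8 detour
print((x - roundtrip).abs().max())   # non-zero round-trip error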
@@ -134,6 +174,7 @@ def apply_layerwise_casting(
         skip_modules_classes,
         non_blocking,
     )
+    _disable_peft_input_autocast(module)
 
 
 def _apply_layerwise_casting(
@@ -188,4 +229,13 @@ def apply_layerwise_casting_hook(
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
-    registry.register_hook(hook, "layerwise_casting")
+    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)
+
+
+def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
+    for submodule in module.modules():
+        if not isinstance(submodule, BaseTunerLayer):
+            continue
+        registry = HookRegistry.check_if_exists_or_initialize(submodule)
+        hook = PeftInputAutocastDisableHook()
+        registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
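
For context, a minimal usage sketch of the behavior this commit enables. It is not part of the diff and assumes a diffusers build containing this change plus a PEFT version exposing `disable_lora_input_dtype_casting`; the toy module and LoRA config below are illustrative. Once adapters are injected, `apply_layerwise_casting` also registers `PeftInputAutocastDisableHook` on every `BaseTunerLayer` submodule, so LoRA inputs stay in the compute dtype instead of being downcast to the fp8 storage dtype.

import torch
from peft import LoraConfig, inject_adapter_in_model
from diffusers.hooks import apply_layerwise_casting


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(64, 64)
        self.proj_out = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj_out(torch.nn.functional.silu(self.proj_in(x)))


model = TinyBlock().to(torch.bfloat16)

# Inject LoRA adapters; the wrapped linears become peft BaseTunerLayer instances.
lora_config = LoraConfig(r=4, target_modules=["proj_in", "proj_out"])
model = inject_adapter_in_model(lora_config, model)

# Weights are stored in fp8 between forwards; the layerwise casting hook upcasts
# them to bf16 for computation, and the new PEFT hook keeps the inputs in bf16 too.
apply_layerwise_casting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)

with torch.no_grad():
    out = model(torch.randn(1, 64, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16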
