Commit fb8ad79

add test; reference PeftInputAutocastDisableHook in peft docs

1 parent 020e374 commit fb8ad79

3 files changed: +68 -3 lines changed

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 4 additions & 0 deletions
@@ -221,3 +221,7 @@ pipe.delete_adapters("toy")
 pipe.get_active_adapters()
 ["pixel"]
 ```
+
+## PeftInputAutocastDisableHook
+
+[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook

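For context, a minimal sketch (not part of this commit) of the scenario the new doc section documents: a pipeline with a LoRA loaded and layerwise casting enabled on the denoiser, the combination in which `PeftInputAutocastDisableHook` becomes relevant. The SDXL checkpoint and pixel-art LoRA ids below are illustrative assumptions; substitute your own.

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint and LoRA ids, chosen only for illustration.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("nerijs/pixel-art-xl", adapter_name="pixel")

# Store the denoiser weights in float8 and upcast to the compute dtype for each
# forward pass. PeftInputAutocastDisableHook exists to stop PEFT (>= 0.14.0)
# from downcasting the LoRA layer inputs to the float8 weight dtype here.
pipe.unet.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

image = pipe("pixel art of a corgi astronaut", num_inference_steps=30).images[0]
```
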
src/diffusers/hooks/layerwise_casting.py

Lines changed: 3 additions & 3 deletions
@@ -79,8 +79,8 @@ def post_forward(self, module: torch.nn.Module, output):
 
 class PeftInputAutocastDisableHook(ModelHook):
     r"""
-    A hook that disables the casting of inputs to the module weight dtype for the forward pass. By default, PEFT casts
-    the inputs to the weight dtype of the module, which can lead to precision loss.
+    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
+    casts the inputs to the weight dtype of the module, which can lead to precision loss.
 
     The reasons for needing this are:
     - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
@@ -94,7 +94,7 @@ class PeftInputAutocastDisableHook(ModelHook):
       PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
      with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are casted to
      torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
-      forward pass with PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
+      forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
      LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
     """
 

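The docstring's precision-loss argument can be made concrete with a small standalone sketch (not part of this commit): round-tripping a bfloat16 activation through `torch.float8_e4m3fn`, as a naive input cast followed by a `send_to_dtype`-style upcast would do, does not recover the original values.

```python
import torch

# Emulate the lossy path described above: downcast an input to the float8
# storage dtype, then upcast back to the bfloat16 compute dtype.
x = torch.randn(4, 64, dtype=torch.bfloat16)
roundtripped = x.to(torch.float8_e4m3fn).to(torch.bfloat16)

# The maximum absolute error is non-zero: the downcast discarded information,
# which is why the hook keeps LoRA inputs in the compute dtype instead.
print((x - roundtripped).abs().max())
```
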
tests/lora/utils.py

Lines changed: 61 additions & 0 deletions
@@ -2157,3 +2157,64 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
 
         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+    @require_peft_version_greater("0.14.0")
+    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+        r"""
+        A test that checks that layerwise casting works correctly with PEFT layers and that the forward pass does
+        not fail. This is different from `test_layerwise_casting_inference_denoiser`, which disables the application
+        of layerwise cast hooks on the PEFT layers (relevant logic in
+        `models.modeling_utils.ModelMixin.enable_layerwise_casting`). In this test, we enable layerwise casting on
+        the PEFT layers as well. If run with PEFT version < 0.14.0, this test will fail with the following error:
+
+        ```
+        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+        ```
+
+        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+        """
+
+        from diffusers.hooks.layerwise_casting import (
+            _PEFT_AUTOCAST_DISABLE_HOOK,
+            DEFAULT_SKIP_MODULES_PATTERN,
+            SUPPORTED_PYTORCH_LAYERS,
+            apply_layerwise_casting,
+        )
+
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = torch.float32
+
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
+        apply_layerwise_casting(
+            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+        )
+
+        # This will also check if the PEFT layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+        for name, module in denoiser.named_modules():
+            if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                continue
+            dtype_to_check = storage_dtype
+            if any(re.search(pattern, name) for pattern in patterns_to_check):
+                dtype_to_check = compute_dtype
+            if getattr(module, "weight", None) is not None:
+                self.assertEqual(module.weight.dtype, dtype_to_check)
+            if getattr(module, "bias", None) is not None:
+                self.assertEqual(module.bias.dtype, dtype_to_check)
+            if isinstance(module, BaseTunerLayer):
+                self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+                self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]
