
Commit 64e6c9c

deinitialize should raise an error
1 parent: bde103c

1 file changed: +8 −0

src/diffusers/hooks/layerwise_upcasting.py

```diff
@@ -53,6 +53,14 @@ def initialize_hook(self, module: torch.nn.Module):
         module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
         return module
 
+    def deinitalize_hook(self, module: torch.nn.Module):
+        raise NotImplementedError(
+            "LayerwiseUpcastingHook does not support deinitialization. A model once enabled with layerwise "
+            "upcasting will have cast its weights to a lower-precision dtype for storage. Casting back to the "
+            "original dtype will lead to precision loss, which might impact the model's generation quality. "
+            "The model should be re-initialized and loaded in the original dtype."
+        )
+
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
         return args, kwargs
```
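For context, the hook's job (per the surrounding diff) is to keep a module's weights in a low-precision `storage_dtype` and upcast them to `compute_dtype` just before each forward pass. Below is a minimal, self-contained sketch of that pattern; `TinyUpcastHook` is a hypothetical stand-in, not the diffusers `LayerwiseUpcastingHook`, and it omits `non_blocking` handling and anything else not shown in the diff:

```python
import torch

class TinyUpcastHook:
    """Hypothetical stand-in illustrating the pattern; not the diffusers API."""

    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype):
        self.storage_dtype = storage_dtype
        self.compute_dtype = compute_dtype

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Cast the weights down once; they stay in storage_dtype between calls.
        return module.to(dtype=self.storage_dtype)

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # Upcast just-in-time so the forward pass runs in compute_dtype.
        module.to(dtype=self.compute_dtype)
        return args, kwargs

hook = TinyUpcastHook(storage_dtype=torch.float16, compute_dtype=torch.float32)
layer = hook.initialize_hook(torch.nn.Linear(8, 8))
args, kwargs = hook.pre_forward(layer, torch.randn(2, 8))
out = layer(*args, **kwargs)  # forward pass executed in float32
```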

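The reason `deinitalize_hook` raises rather than casting the weights back is demonstrable in a few lines: a downcast-then-upcast round trip cannot recover the mantissa bits dropped by the storage dtype. A self-contained check, with float16 storage assumed purely for illustration:

```python
import torch

weight = torch.randn(4, 4, dtype=torch.float32)

stored = weight.to(torch.float16)      # downcast for storage
recovered = stored.to(torch.float32)   # attempted "deinitialization"

print(torch.equal(weight, recovered))    # False: information was lost
print((weight - recovered).abs().max())  # nonzero rounding error
```

Hence the error message's guidance: re-initialize the model and load it in the original dtype rather than relying on a lossy cast back.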