Commit c76e1cc

update

1 parent 315e357 commit c76e1cc

File tree: 3 files changed, +18 −9 lines changed

src/diffusers/hooks/first_block_cache.py

Lines changed: 3 additions & 2 deletions

@@ -18,6 +18,7 @@
 import torch

 from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
 from .hooks import BaseMarkedState, HookRegistry, ModelHook
@@ -71,7 +72,7 @@ def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
         self._metadata = None

     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(module.__class__)
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -147,7 +148,7 @@ def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
         self._metadata = None

     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(module.__class__)
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
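
Why the unwrap is needed: torch.compile() returns a torch._dynamo.eval_frame.OptimizedModule that wraps the original module, so a class-keyed lookup like TransformerBlockRegistry.get(module.__class__) would receive the wrapper class rather than the class the block was registered under. A minimal sketch of the mismatch (not part of the commit; torch.nn.Linear is a hypothetical stand-in for a registered transformer block, torch >= 2.0 assumed):

    import torch

    block = torch.nn.Linear(8, 8)      # hypothetical stand-in for a transformer block
    compiled = torch.compile(block)    # wraps `block` in an OptimizedModule

    print(type(compiled).__name__)                       # OptimizedModule
    print(type(compiled._orig_mod) is torch.nn.Linear)   # True
    # A registry keyed by the block class only matches after unwrapping:
    # unwrap_module(compiled).__class__ is torch.nn.Linear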

src/diffusers/hooks/hooks.py

Lines changed: 10 additions & 7 deletions

@@ -18,6 +18,7 @@
 import torch

 from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module


 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -47,7 +48,7 @@ def get_current_state(self) -> "BaseMarkedState":
             self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
         return self._state_cache[self._mark_name]

-    def mark_batch(self, name: str) -> None:
+    def mark_state(self, name: str) -> None:
         self._mark_name = name

     def reset(self, *args, **kwargs) -> None:
@@ -59,7 +60,7 @@ def reset(self, *args, **kwargs) -> None:
     def __getattribute__(self, name):
         if name in (
             "get_current_state",
-            "mark_batch",
+            "mark_state",
             "reset",
             "_init_args",
             "_init_kwargs",
@@ -74,7 +75,7 @@ def __getattribute__(self, name):
     def __setattr__(self, name, value):
         if name in (
             "get_current_state",
-            "mark_batch",
+            "mark_state",
             "reset",
             "_init_args",
             "_init_kwargs",
@@ -164,11 +165,11 @@ def reset_state(self, module: torch.nn.Module):
         return module

     def _mark_state(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
             if isinstance(attr, BaseMarkedState):
-                attr.mark_batch(name)
+                attr.mark_state(name)
         return module


@@ -283,9 +284,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None:
                 hook.reset_state(self._module_ref)

         if recurse:
-            for module_name, module in self._module_ref.named_modules():
+            for module_name, module in unwrap_module(self._module_ref).named_modules():
                 if module_name == "":
                     continue
+                module = unwrap_module(module)
                 if hasattr(module, "_diffusers_hook"):
                     module._diffusers_hook.reset_stateful_hooks(recurse=False)

@@ -301,9 +303,10 @@ def _mark_state(self, name: str) -> None:
             if hook._is_stateful:
                 hook._mark_state(self._module_ref, name)

-        for module_name, module in self._module_ref.named_modules():
+        for module_name, module in unwrap_module(self._module_ref).named_modules():
             if module_name == "":
                 continue
+            module = unwrap_module(module)
             if hasattr(module, "_diffusers_hook"):
                 module._diffusers_hook._mark_state(name)
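
Besides the mark_batch -> mark_state rename, the recursion in reset_stateful_hooks and _mark_state now unwraps both the root reference and every child it visits. This matters when torch.compile is applied per-block rather than to the whole model, because compiled children show up in named_modules() as wrappers. A short sketch (not from the commit; Container is a hypothetical module, torch >= 2.0 assumed):

    import torch

    class Container(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # compile a single block ("regional" compilation)
            self.block = torch.compile(torch.nn.Linear(4, 4))

    for name, child in Container().named_modules():
        print(name or "<root>", type(child).__name__)
    # <root> Container
    # block OptimizedModule        <- wrapper seen during recursion
    # block._orig_mod Linear       <- the original, uncompiled block

Unwrapping each yielded child keeps the hasattr(module, "_diffusers_hook") check and the recursion pointed at the original blocks rather than their compiled wrappers.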

src/diffusers/utils/torch_utils.py

Lines changed: 5 additions & 0 deletions

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


+def unwrap_module(module):
+    """Unwraps a module if it was compiled with torch.compile()"""
+    return module._orig_mod if is_compiled_module(module) else module
+
+
 def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
     """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
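
A quick usage sketch for the new helper (not part of the commit). It relies on torch.compile wrapping the module in an OptimizedModule that keeps the original on its _orig_mod attribute:

    import torch
    from diffusers.utils.torch_utils import unwrap_module

    model = torch.nn.Linear(4, 4)
    compiled = torch.compile(model)

    assert unwrap_module(compiled) is model   # unwraps the OptimizedModule wrapper
    assert unwrap_module(model) is model      # non-compiled modules pass through unchanged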
