Commit 9a732f0

update

1 parent d9e7372

4 files changed, +304 -42 lines

src/diffusers/models/hooks.py

Lines changed: 10 additions & 41 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Dict, Tuple
 
 import torch
 

@@ -28,6 +28,7 @@ class ModelHook:
     def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -37,6 +38,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
@@ -53,6 +55,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass been executed just before this event.
@@ -66,15 +69,13 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
+
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
         """
         return module
 
-    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
-        return module
-
 
 class SequentialHook(ModelHook):
     r"""A hook that can contain several hooks and iterates through them at each event."""
@@ -102,52 +103,19 @@ def detach_hook(self, module):
             module = hook.detach_hook(module)
         return module
 
-    def reset_state(self, module):
-        for hook in self.hooks:
-            module = hook.reset_state(module)
-        return module
-
-
-class FasterCacheHook(ModelHook):
-    def __init__(
-        self,
-        skip_callback: Callable[[torch.nn.Module], bool],
-    ) -> None:
-        super().__init__()
-
-        self.skip_callback = skip_callback
-
-        self.cache = None
-        self._iteration = 0
-
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
-        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
-
-        if self.cache is not None and self.skip_callback(module):
-            output = self.cache
-        else:
-            output = module._old_forward(*args, **kwargs)
-
-        return module._diffusers_hook.post_forward(module, output)
-
-    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
-        self.cache = output
-        return output
-
-    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
-        self.cache = None
-        self._iteration = 0
-        return module
-
 
 def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
     r"""
     Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
     this behavior and restore the original `forward` method, use `remove_hook_from_module`.
+
     <Tip warning={true}>
+
     If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
     together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
+
     </Tip>
+
     Args:
         module (`torch.nn.Module`):
             The module to attach a hook to.
@@ -198,6 +166,7 @@ def new_forward(module, *args, **kwargs):
 def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
     """
     Removes any hook attached to a module via `add_hook_to_module`.
+
     Args:
         module (`torch.nn.Module`):
             The module to attach a hook to.
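For orientation, here is a minimal sketch of how the hook API kept by this commit is used, assuming the `diffusers.models.hooks` module path from this branch; `CountingHook` is a hypothetical example, not part of the library:

import torch

from diffusers.models.hooks import ModelHook, add_hook_to_module, remove_hook_from_module


class CountingHook(ModelHook):
    """Counts how many times the wrapped module's forward pass runs."""

    def __init__(self) -> None:
        super().__init__()
        self.calls = 0

    def pre_forward(self, module, *args, **kwargs):
        # Runs just before the module's original forward; must return the
        # (possibly modified) args and kwargs.
        self.calls += 1
        return args, kwargs


layer = torch.nn.Linear(4, 4)
hook = CountingHook()
add_hook_to_module(layer, hook)  # rewrites layer.forward to run the hook
layer(torch.randn(1, 4))
print(hook.calls)  # 1
remove_hook_from_module(layer)  # restores the original forward

Per the docstring above, passing `append=True` on a second `add_hook_to_module` call chains both hooks into a `SequentialHook` instead of replacing the first.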

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 4 additions & 0 deletions
@@ -622,6 +622,7 @@ def __call__(
         )
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Default call parameters
@@ -700,6 +701,7 @@ def __call__(
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -755,6 +757,8 @@
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             # Discard any padding frames that were added for CogVideoX 1.5
             latents = latents[:, additional_frames:]
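For context, the new attribute lets external code observe which timestep the pipeline is currently denoising. A hedged sketch, assuming the standard `callback_on_step_end` hook that CogVideoXPipeline already exposes; the checkpoint name and prompt are placeholders:

import torch

from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")


def log_timestep(pipeline, step, timestep, callback_kwargs):
    # Inside the denoising loop, _current_timestep mirrors `t`; it is reset
    # to None once the loop finishes.
    print(f"step {step}: _current_timestep = {pipeline._current_timestep}")
    return callback_kwargs


video = pipe(
    "a panda playing a guitar by a waterfall",
    num_inference_steps=50,
    callback_on_step_end=log_timestep,
).frames[0]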
