Skip to content

Commit 62b5b8d

Browse files
committed
update
1 parent d974401 commit 62b5b8d

File tree

10 files changed

+101
-42
lines changed

10 files changed

+101
-42
lines changed

src/diffusers/models/hooks.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
2222
class ModelHook:
2323
r"""
24-
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
25-
with PyTorch existing hooks is that they get passed along the kwargs.
24+
A hook that contains callbacks to be executed just before and after the forward method of a model.
2625
"""
2726

27+
_is_stateful = False
28+
2829
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
2930
r"""
3031
Hook that is executed when a model is initialized.
@@ -78,6 +79,10 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
7879
"""
7980
return module
8081

82+
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
    r"""
    Reset any state this hook has accumulated.

    The base implementation holds no state. A hook that sets `_is_stateful = True` must override this method;
    reaching this implementation with the flag set is reported as an error instead of silently doing nothing.

    Args:
        module (`torch.nn.Module`):
            The module this hook is attached to.

    Returns:
        `torch.nn.Module`: The module that was passed in, unchanged.

    Raises:
        NotImplementedError: If `_is_stateful` is truthy and the subclass did not override `reset_state`.
    """
    if self._is_stateful:
        raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
    # Hand the module back, matching the other hook callbacks such as `detach_hook`, which return the module.
    return module
85+
8186

8287
class SequentialHook(ModelHook):
8388
r"""A hook that can contain several hooks and iterates through them at each event."""
@@ -105,8 +110,13 @@ def detach_hook(self, module):
105110
module = hook.detach_hook(module)
106111
return module
107112

113+
def reset_state(self, module):
    r"""
    Reset the state of every stateful hook held in `self.hooks`.

    Hooks whose `_is_stateful` flag is falsy are skipped; the remaining hooks have their own `reset_state`
    called with `module`, in the order they appear in `self.hooks`.

    Args:
        module (`torch.nn.Module`):
            The module the contained hooks are attached to.

    Returns:
        `torch.nn.Module`: The module that was passed in.
    """
    for hook in self.hooks:
        if hook._is_stateful:
            # The return value of the child hook is deliberately ignored so that hooks which return
            # nothing from `reset_state` keep working.
            hook.reset_state(module)
    # Return the module, as `detach_hook` does.
    return module
117+
108118

109-
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
119+
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:
110120
r"""
111121
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
112122
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
@@ -199,3 +209,21 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
199209
remove_hook_from_module(child, recurse)
200210

201211
return module
212+
213+
214+
def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
    """
    Resets the state of all stateful hooks attached to a module.

    A hook is reset when it is stored on the module as `_diffusers_hook` and either declares itself stateful
    (`_is_stateful` is truthy) or is a `SequentialHook`, which in turn resets the stateful hooks it contains.
    Modules without a `_diffusers_hook` attribute are left untouched.

    Args:
        module (`torch.nn.Module`):
            The module to reset the stateful hooks from.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether to also reset the stateful hooks of all submodules, descending through `module.children()`.

    Returns:
        `torch.nn.Module`: The same module, returned for parity with `add_hook_to_module` and
        `remove_hook_from_module`.
    """
    # Look the hook up once; a missing (or `None`) hook simply means there is nothing to reset here.
    hook = getattr(module, "_diffusers_hook", None)
    if hook is not None and (hook._is_stateful or isinstance(hook, SequentialHook)):
        hook.reset_state(module)

    if recurse:
        for child in module.children():
            reset_stateful_hooks(child, recurse)

    return module

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ...loaders import CogVideoXLoraLoaderMixin
2525
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2626
from ...models.embeddings import get_3d_rotary_pos_embed
27+
from ...models.hooks import reset_stateful_hooks
2728
from ...pipelines.pipeline_utils import DiffusionPipeline
2829
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
2930
from ...utils import logging, replace_example_docstring
@@ -769,6 +770,7 @@ def __call__(
769770

770771
# Offload all models
771772
self.maybe_free_model_hooks()
773+
reset_stateful_hooks(self.transformer, recurse=True)
772774

773775
if not return_dict:
774776
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ...loaders import CogVideoXLoraLoaderMixin
2626
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2727
from ...models.embeddings import get_3d_rotary_pos_embed
28+
from ...models.hooks import reset_stateful_hooks
2829
from ...pipelines.pipeline_utils import DiffusionPipeline
2930
from ...schedulers import KarrasDiffusionSchedulers
3031
from ...utils import logging, replace_example_docstring
@@ -822,6 +823,7 @@ def __call__(
822823

823824
# Offload all models
824825
self.maybe_free_model_hooks()
826+
reset_stateful_hooks(self.transformer, recurse=True)
825827

826828
if not return_dict:
827829
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ...loaders import CogVideoXLoraLoaderMixin
2727
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2828
from ...models.embeddings import get_3d_rotary_pos_embed
29+
from ...models.hooks import reset_stateful_hooks
2930
from ...pipelines.pipeline_utils import DiffusionPipeline
3031
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
3132
from ...utils import (
@@ -882,6 +883,7 @@ def __call__(
882883

883884
# Offload all models
884885
self.maybe_free_model_hooks()
886+
reset_stateful_hooks(self.transformer, recurse=True)
885887

886888
if not return_dict:
887889
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ...loaders import CogVideoXLoraLoaderMixin
2626
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2727
from ...models.embeddings import get_3d_rotary_pos_embed
28+
from ...models.hooks import reset_stateful_hooks
2829
from ...pipelines.pipeline_utils import DiffusionPipeline
2930
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
3031
from ...utils import logging, replace_example_docstring
@@ -848,6 +849,7 @@ def __call__(
848849

849850
# Offload all models
850851
self.maybe_free_model_hooks()
852+
reset_stateful_hooks(self.transformer, recurse=True)
851853

852854
if not return_dict:
853855
return (video,)

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
from ...image_processor import PipelineImageInput, VaeImageProcessor
3030
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31-
from ...models.autoencoders import AutoencoderKL
32-
from ...models.transformers import FluxTransformer2DModel
31+
from ...models import AutoencoderKL, FluxTransformer2DModel
32+
from ...models.hooks import reset_stateful_hooks
3333
from ...schedulers import FlowMatchEulerDiscreteScheduler
3434
from ...utils import (
3535
USE_PEFT_BACKEND,
@@ -953,6 +953,7 @@ def __call__(
953953

954954
# Offload all models
955955
self.maybe_free_model_hooks()
956+
reset_stateful_hooks(self.transformer, recurse=True)
956957

957958
if not return_dict:
958959
return (image,)

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2323
from ...loaders import HunyuanVideoLoraLoaderMixin
2424
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
25+
from ...models.hooks import reset_stateful_hooks
2526
from ...schedulers import FlowMatchEulerDiscreteScheduler
2627
from ...utils import logging, replace_example_docstring
2728
from ...utils.torch_utils import randn_tensor
@@ -573,6 +574,7 @@ def __call__(
573574

574575
self._guidance_scale = guidance_scale
575576
self._attention_kwargs = attention_kwargs
577+
self._current_timestep = None
576578
self._interrupt = False
577579

578580
device = self._execution_device
@@ -640,6 +642,7 @@ def __call__(
640642
if self.interrupt:
641643
continue
642644

645+
self._current_timestep = t
643646
latent_model_input = latents.to(transformer_dtype)
644647
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
645648
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -671,6 +674,8 @@ def __call__(
671674
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
672675
progress_bar.update()
673676

677+
self._current_timestep = None
678+
674679
if not output_type == "latent":
675680
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
676681
video = self.vae.decode(latents, return_dict=False)[0]
@@ -680,6 +685,7 @@ def __call__(
680685

681686
# Offload all models
682687
self.maybe_free_model_hooks()
688+
reset_stateful_hooks(self.transformer, recurse=True)
683689

684690
if not return_dict:
685691
return (video,)

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2727
from ...models import AutoencoderKL, LatteTransformer3DModel
28+
from ...models.hooks import reset_stateful_hooks
2829
from ...pipelines.pipeline_utils import DiffusionPipeline
2930
from ...schedulers import KarrasDiffusionSchedulers
3031
from ...utils import (
@@ -848,6 +849,7 @@ def __call__(
848849

849850
# Offload all models
850851
self.maybe_free_model_hooks()
852+
reset_stateful_hooks(self.transformer, recurse=True)
851853

852854
if not return_dict:
853855
return (video,)

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2323
from ...loaders import Mochi1LoraLoaderMixin
24-
from ...models.autoencoders import AutoencoderKL
25-
from ...models.transformers import MochiTransformer3DModel
24+
from ...models import AutoencoderKLHunyuanVideo, MochiTransformer3DModel
25+
from ...models.hooks import reset_stateful_hooks
2626
from ...schedulers import FlowMatchEulerDiscreteScheduler
2727
from ...utils import (
2828
is_torch_xla_available,
@@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
184184
def __init__(
185185
self,
186186
scheduler: FlowMatchEulerDiscreteScheduler,
187-
vae: AutoencoderKL,
187+
vae: AutoencoderKLHunyuanVideo,
188188
text_encoder: T5EncoderModel,
189189
tokenizer: T5TokenizerFast,
190190
transformer: MochiTransformer3DModel,
@@ -604,6 +604,7 @@ def __call__(
604604

605605
self._guidance_scale = guidance_scale
606606
self._attention_kwargs = attention_kwargs
607+
self._current_timestep = None
607608
self._interrupt = False
608609

609610
# 2. Define call parameters
@@ -673,6 +674,9 @@ def __call__(
673674
if self.interrupt:
674675
continue
675676

677+
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
678+
# to make sure we're using the correct non-reversed timestep values.
679+
self._current_timestep = 1000 - t
676680
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
677681
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
678682
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
@@ -718,6 +722,8 @@ def __call__(
718722
if XLA_AVAILABLE:
719723
xm.mark_step()
720724

725+
self._current_timestep = None
726+
721727
if output_type == "latent":
722728
video = latents
723729
else:
@@ -741,6 +747,7 @@ def __call__(
741747

742748
# Offload all models
743749
self.maybe_free_model_hooks()
750+
reset_stateful_hooks(self.transformer, recurse=True)
744751

745752
if not return_dict:
746753
return (video,)

0 commit comments

Comments
 (0)