Skip to content

Commit fb66167

Browse files
committed
fixes
1 parent 76afc6a commit fb66167

File tree

12 files changed

+28
-19
lines changed

12 files changed

+28
-19
lines changed

src/diffusers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
_import_structure = {
3030
"configuration_utils": ["ConfigMixin"],
31+
"hooks": [],
3132
"loaders": ["FromOriginalModelMixin"],
3233
"models": [],
3334
"pipelines": [],
@@ -77,6 +78,7 @@
7778
else:
7879
_import_structure["hooks"].extend(
7980
[
81+
"HookRegistry",
8082
"PyramidAttentionBroadcastConfig",
8183
"apply_pyramid_attention_broadcast",
8284
]
@@ -592,7 +594,7 @@
592594
except OptionalDependencyNotAvailable:
593595
from .utils.dummy_pt_objects import * # noqa F403
594596
else:
595-
from .hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
597+
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
596598
from .models import (
597599
AllegroTransformer3DModel,
598600
AsymmetricAutoencoderKL,

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33

44
if is_torch_available():
5+
from .hooks import HookRegistry
56
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

src/diffusers/hooks/hooks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ def remove_hook(self, name: str) -> None:
147147
del self.hooks[name]
148148
self._hook_order.remove(name)
149149

150+
def reset_stateful_hooks(self):
151+
for hook_name in self._hook_order:
152+
hook = self.hooks[hook_name]
153+
if hook._is_stateful:
154+
hook.reset_state(self._module_ref)
155+
150156
@classmethod
151157
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
152158
if not hasattr(module, "_diffusers_hook"):

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from transformers import T5EncoderModel, T5Tokenizer
2525

2626
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27+
from ...hooks import HookRegistry
2728
from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro
2829
from ...models.embeddings import get_3d_rotary_pos_embed_allegro
29-
from ...models.hooks import reset_stateful_hooks
3030
from ...pipelines.pipeline_utils import DiffusionPipeline
3131
from ...schedulers import KarrasDiffusionSchedulers
3232
from ...utils import (
@@ -948,7 +948,7 @@ def __call__(
948948

949949
# Offload all models
950950
self.maybe_free_model_hooks()
951-
reset_stateful_hooks(self.transformer, recurse=True)
951+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
952952

953953
if not return_dict:
954954
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
from transformers import T5EncoderModel, T5Tokenizer
2222

2323
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
24+
from ...hooks import HookRegistry
2425
from ...loaders import CogVideoXLoraLoaderMixin
2526
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2627
from ...models.embeddings import get_3d_rotary_pos_embed
27-
from ...models.hooks import reset_stateful_hooks
2828
from ...pipelines.pipeline_utils import DiffusionPipeline
2929
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
3030
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -778,7 +778,7 @@ def __call__(
778778

779779
# Offload all models
780780
self.maybe_free_model_hooks()
781-
reset_stateful_hooks(self.transformer, recurse=True)
781+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
782782

783783
if not return_dict:
784784
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from transformers import T5EncoderModel, T5Tokenizer
2323

2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25+
from ...hooks import HookRegistry
2526
from ...loaders import CogVideoXLoraLoaderMixin
2627
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2728
from ...models.embeddings import get_3d_rotary_pos_embed
28-
from ...models.hooks import reset_stateful_hooks
2929
from ...pipelines.pipeline_utils import DiffusionPipeline
3030
from ...schedulers import KarrasDiffusionSchedulers
3131
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -831,7 +831,7 @@ def __call__(
831831

832832
# Offload all models
833833
self.maybe_free_model_hooks()
834-
reset_stateful_hooks(self.transformer, recurse=True)
834+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
835835

836836
if not return_dict:
837837
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
from transformers import T5EncoderModel, T5Tokenizer
2323

2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25+
from ...hooks import HookRegistry
2526
from ...image_processor import PipelineImageInput
2627
from ...loaders import CogVideoXLoraLoaderMixin
2728
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2829
from ...models.embeddings import get_3d_rotary_pos_embed
29-
from ...models.hooks import reset_stateful_hooks
3030
from ...pipelines.pipeline_utils import DiffusionPipeline
3131
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
3232
from ...utils import (
@@ -892,7 +892,7 @@ def __call__(
892892

893893
# Offload all models
894894
self.maybe_free_model_hooks()
895-
reset_stateful_hooks(self.transformer, recurse=True)
895+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
896896

897897
if not return_dict:
898898
return (video,)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from transformers import T5EncoderModel, T5Tokenizer
2323

2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25+
from ...hooks import HookRegistry
2526
from ...loaders import CogVideoXLoraLoaderMixin
2627
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
2728
from ...models.embeddings import get_3d_rotary_pos_embed
28-
from ...models.hooks import reset_stateful_hooks
2929
from ...pipelines.pipeline_utils import DiffusionPipeline
3030
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
3131
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -857,7 +857,7 @@ def __call__(
857857

858858
# Offload all models
859859
self.maybe_free_model_hooks()
860-
reset_stateful_hooks(self.transformer, recurse=True)
860+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
861861

862862
if not return_dict:
863863
return (video,)

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
T5TokenizerFast,
2727
)
2828

29+
from ...hooks import HookRegistry
2930
from ...image_processor import PipelineImageInput, VaeImageProcessor
3031
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
3132
from ...models import AutoencoderKL, FluxTransformer2DModel
32-
from ...models.hooks import reset_stateful_hooks
3333
from ...schedulers import FlowMatchEulerDiscreteScheduler
3434
from ...utils import (
3535
USE_PEFT_BACKEND,
@@ -971,7 +971,7 @@ def __call__(
971971

972972
# Offload all models
973973
self.maybe_free_model_hooks()
974-
reset_stateful_hooks(self.transformer, recurse=True)
974+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
975975

976976
if not return_dict:
977977
return (image,)

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
23+
from ...hooks import HookRegistry
2324
from ...loaders import HunyuanVideoLoraLoaderMixin
2425
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
25-
from ...models.hooks import reset_stateful_hooks
2626
from ...schedulers import FlowMatchEulerDiscreteScheduler
2727
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2828
from ...utils.torch_utils import randn_tensor
@@ -692,7 +692,7 @@ def __call__(
692692

693693
# Offload all models
694694
self.maybe_free_model_hooks()
695-
reset_stateful_hooks(self.transformer, recurse=True)
695+
HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks()
696696

697697
if not return_dict:
698698
return (video,)

0 commit comments

Comments
 (0)