Skip to content

Commit 9189815

Browse files
committed
update
1 parent ad6322a commit 9189815

File tree

11 files changed

+267
-151
lines changed

11 files changed

+267
-151
lines changed

src/diffusers/hooks/_helpers.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
from dataclasses import dataclass
17+
from typing import Any, Callable, Dict, Type
18+
19+
20+
@dataclass
class AttentionProcessorMetadata:
    """Metadata recorded for an attention-processor class.

    Attributes:
        skip_processor_output_fn: Callable invoked instead of the processor when
            the attention computation is skipped; it must produce an output with
            the same structure the real processor would return.
    """

    skip_processor_output_fn: Callable[[Any], Any]
23+
24+
25+
@dataclass
class TransformerBlockMetadata:
    """Metadata describing how a transformer block returns its outputs.

    Attributes:
        return_hidden_states_index: Index of `hidden_states` in the block's
            return tuple, or None if the block does not return it.
        return_encoder_hidden_states_index: Index of `encoder_hidden_states` in
            the block's return tuple, or None if the block does not return it.
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    # Set by TransformerBlockRegistry.register(); the block class whose
    # forward() signature is used to resolve positional arguments by name.
    _cls: Type = None
    # Lazy cache mapping parameter name -> positional index in forward()
    # (excluding `self`); built on the first lookup.
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Return the value passed for parameter `identifier` of `_cls.forward`.

        Looks in `kwargs` first, then resolves the parameter's positional index
        from the forward() signature and reads it from `args`.

        Raises:
            ValueError: If no model class is set, the parameter does not exist
                in the signature, or too few positional arguments were given.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            # Skip `self` so indices line up with the caller's positional args.
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())[1:]
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
        # Validate on every call (the original cached fast-path skipped these
        # checks and leaked raw KeyError/IndexError on repeat lookups).
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
        return args[index]
50+
51+
52+
class AttentionProcessorRegistry:
    """Process-wide mapping from attention-processor classes to their metadata."""

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate `metadata` with `model_class`, overwriting any prior entry."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Look up the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` was never registered.
        """
        cls._register()
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")

    @classmethod
    def _register(cls):
        """Run the deferred built-in registrations exactly once."""
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_attention_processors_metadata()
77+
78+
79+
class TransformerBlockRegistry:
    """Process-wide mapping from transformer-block classes to their metadata."""

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate `metadata` with `model_class` and record the class on the
        metadata so parameter lookups can inspect its forward() signature."""
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Look up the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` was never registered.
        """
        cls._register()
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")

    @classmethod
    def _register(cls):
        """Run the deferred built-in registrations exactly once."""
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_transformer_blocks_metadata()
105+
106+
107+
def _register_attention_processors_metadata():
    """Register metadata for the built-in attention processors.

    Imports are deferred to call time to avoid circular imports between this
    module and the model definitions.
    """
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor

    # Processor class -> function that mimics its output when attention is skipped.
    processor_skip_fns = {
        AttnProcessor2_0: _skip_proc_output_fn_Attention_AttnProcessor2_0,
        CogView4AttnProcessor: _skip_proc_output_fn_Attention_CogView4AttnProcessor,
    }
    for processor_cls, skip_fn in processor_skip_fns.items():
        AttentionProcessorRegistry.register(
            model_class=processor_cls,
            metadata=AttentionProcessorMetadata(skip_processor_output_fn=skip_fn),
        )
126+
127+
128+
def _register_transformer_blocks_metadata():
    """Register metadata for the built-in transformer block classes.

    Imports are deferred to call time to avoid circular imports between this
    module and the model definitions.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # Block class -> (return_hidden_states_index, return_encoder_hidden_states_index).
    # None means the block does not return that tensor.
    block_output_indices = {
        BasicTransformerBlock: (0, None),
        CogVideoXBlock: (0, 1),
        CogView4TransformerBlock: (0, 1),
        FluxTransformerBlock: (1, 0),
        FluxSingleTransformerBlock: (1, 0),
        HunyuanVideoTransformerBlock: (0, 1),
        HunyuanVideoSingleTransformerBlock: (0, 1),
        HunyuanVideoTokenReplaceTransformerBlock: (0, 1),
        HunyuanVideoTokenReplaceSingleTransformerBlock: (0, 1),
        LTXVideoTransformerBlock: (0, None),
        MochiTransformerBlock: (0, 1),
        WanTransformerBlock: (0, None),
    }
    for block_cls, (hidden_idx, encoder_idx) in block_output_indices.items():
        TransformerBlockRegistry.register(
            model_class=block_cls,
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=hidden_idx,
                return_encoder_hidden_states_index=encoder_idx,
            ),
        )
242+
243+
244+
# fmt: off
245+
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
246+
hidden_states = kwargs.get("hidden_states", None)
247+
if hidden_states is None and len(args) > 0:
248+
hidden_states = args[0]
249+
return hidden_states
250+
251+
252+
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
253+
hidden_states = kwargs.get("hidden_states", None)
254+
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
255+
if hidden_states is None and len(args) > 0:
256+
hidden_states = args[0]
257+
if encoder_hidden_states is None and len(args) > 1:
258+
encoder_hidden_states = args[1]
259+
return hidden_states, encoder_hidden_states
260+
261+
262+
# Bind each processor-specific skip-function name to the shared generic
# implementation matching that processor's return signature.
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
# fmt: on

src/diffusers/hooks/first_block_cache.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..utils import get_logger
2121
from ..utils.torch_utils import unwrap_module
2222
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
23+
from ._helpers import TransformerBlockRegistry
2324
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
2425

2526

@@ -72,12 +73,7 @@ def __init__(self, state_manager: StateManager, threshold: float):
7273

7374
def initialize_hook(self, module):
7475
unwrapped_module = unwrap_module(module)
75-
if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
76-
raise ValueError(
77-
f"Module {unwrapped_module} does not have any registered metadata. "
78-
"Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
79-
)
80-
self._metadata = unwrapped_module._diffusers_transformer_block_metadata
76+
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
8177
return module
8278

8379
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -156,12 +152,7 @@ def __init__(self, state_manager: StateManager, is_tail: bool = False):
156152

157153
def initialize_hook(self, module):
158154
unwrapped_module = unwrap_module(module)
159-
if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
160-
raise ValueError(
161-
f"Module {unwrapped_module} does not have any registered metadata. "
162-
"Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
163-
)
164-
self._metadata = unwrapped_module._diffusers_transformer_block_metadata
155+
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
165156
return module
166157

167158
def new_forward(self, module: torch.nn.Module, *args, **kwargs):

src/diffusers/models/attention.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
2323
from .attention_processor import Attention, JointAttnProcessor2_0
2424
from .embeddings import SinusoidalPositionalEmbedding
25-
from .metadata import TransformerBlockMetadata, register_transformer_block
2625
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
2726

2827

@@ -259,12 +258,6 @@ def forward(
259258

260259

261260
@maybe_allow_in_graph
262-
@register_transformer_block(
263-
metadata=TransformerBlockMetadata(
264-
return_hidden_states_index=0,
265-
return_encoder_hidden_states_index=None,
266-
)
267-
)
268261
class BasicTransformerBlock(nn.Module):
269262
r"""
270263
A basic Transformer block.

src/diffusers/models/metadata.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
2727
from ..cache_utils import CacheMixin
2828
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
29-
from ..metadata import TransformerBlockMetadata, register_transformer_block
3029
from ..modeling_outputs import Transformer2DModelOutput
3130
from ..modeling_utils import ModelMixin
3231
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -36,12 +35,6 @@
3635

3736

3837
@maybe_allow_in_graph
39-
@register_transformer_block(
40-
metadata=TransformerBlockMetadata(
41-
return_hidden_states_index=0,
42-
return_encoder_hidden_states_index=1,
43-
)
44-
)
4538
class CogVideoXBlock(nn.Module):
4639
r"""
4740
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

0 commit comments

Comments
 (0)