Commit b8317da

remove central registry based on review

1 parent a5fe2bd

10 files changed: 42 additions & 41 deletions

src/diffusers/hooks/first_block_cache.py

Lines changed: 14 additions & 3 deletions

@@ -17,7 +17,6 @@

 import torch

-from ..models.metadata import TransformerBlockRegistry
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
@@ -72,7 +71,13 @@ def __init__(self, state_manager: StateManager, threshold: float):
         self._metadata = None

     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        unwrapped_module = unwrap_module(module)
+        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
+            raise ValueError(
+                f"Module {unwrapped_module} does not have any registered metadata. "
+                "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
+            )
+        self._metadata = unwrapped_module._diffusers_transformer_block_metadata
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -150,7 +155,13 @@ def __init__(self, state_manager: StateManager, is_tail: bool = False):
         self._metadata = None

     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        unwrapped_module = unwrap_module(module)
+        if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
+            raise ValueError(
+                f"Module {unwrapped_module} does not have any registered metadata. "
+                "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
+            )
+        self._metadata = unwrapped_module._diffusers_transformer_block_metadata
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
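Note: the consumer side of the change is shown above. The hook no longer queries a central dict keyed on the block's class; it checks the (unwrapped) module for an attribute that `register_transformer_block` (see src/diffusers/models/metadata.py below) stamps onto the class. A minimal self-contained sketch of that lookup, with the attribute set by hand so it runs without diffusers; the toy block is hypothetical, while the attribute name and error message come from the diff:

import torch

class ToyBlock(torch.nn.Module):
    # Hypothetical block; in diffusers this class attribute is set by the
    # `register_transformer_block` decorator rather than by hand.
    _diffusers_transformer_block_metadata = {"return_hidden_states_index": 0}

def lookup_metadata(module: torch.nn.Module):
    # Mirrors the guard added to `initialize_hook` above (minus
    # `unwrap_module`, which strips wrappers around the real module).
    if not hasattr(module, "_diffusers_transformer_block_metadata"):
        raise ValueError(
            f"Module {module} does not have any registered metadata. "
            "Make sure to register the metadata using "
            "`diffusers.models.metadata.register_transformer_block`."
        )
    return module._diffusers_transformer_block_metadata

print(lookup_metadata(ToyBlock()))       # {'return_hidden_states_index': 0}
# lookup_metadata(torch.nn.Identity())   # would raise ValueError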

src/diffusers/models/attention.py

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from .metadata import TransformerBlockMetadata, register_transformer_block
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX


@@ -259,7 +259,7 @@ def forward(


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
src/diffusers/models/metadata.py

Lines changed: 7 additions & 17 deletions

@@ -44,20 +44,10 @@ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None)
         return args[index]


-class TransformerBlockRegistry:
-    _registry = {}
-
-    @classmethod
-    def register(cls, metadata: TransformerBlockMetadata):
-        def inner(model_class: Type):
-            metadata._cls = model_class
-            cls._registry[model_class] = metadata
-            return model_class
-
-        return inner
-
-    @classmethod
-    def get(cls, model_class: Type) -> TransformerBlockMetadata:
-        if model_class not in cls._registry:
-            raise ValueError(f"Model class {model_class} not registered.")
-        return cls._registry[model_class]
+def register_transformer_block(metadata: TransformerBlockMetadata):
+    def inner(model_class: Type):
+        metadata._cls = model_class
+        model_class._diffusers_transformer_block_metadata = metadata
+        return model_class
+
+    return inner
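Note: registration is now a plain function decorator and lookup is ordinary attribute access. A minimal sketch of the new round trip, using a simplified two-field stand-in for `TransformerBlockMetadata` (the real class also carries `_get_parameter_from_args_kwargs`); the decorator body and attribute name match the diff, the toy block is illustrative:

from dataclasses import dataclass
from typing import Optional, Type

@dataclass
class TransformerBlockMetadata:
    # Simplified stand-in for the real metadata class in this file.
    return_hidden_states_index: int = 0
    return_encoder_hidden_states_index: Optional[int] = None
    _cls: Optional[Type] = None

def register_transformer_block(metadata: TransformerBlockMetadata):
    # Same body as above: stash the metadata on the class itself.
    def inner(model_class: Type):
        metadata._cls = model_class
        model_class._diffusers_transformer_block_metadata = metadata
        return model_class

    return inner

@register_transformer_block(
    metadata=TransformerBlockMetadata(
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=None,
    )
)
class ToyTransformerBlock:  # hypothetical block, for illustration only
    pass

meta = ToyTransformerBlock._diffusers_transformer_block_metadata
assert meta._cls is ToyTransformerBlock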

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -36,7 +36,7 @@


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
@@ -456,7 +456,7 @@ def __call__(


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 3 additions & 3 deletions

@@ -34,7 +34,7 @@
 )
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -44,7 +44,7 @@


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=1,
         return_encoder_hidden_states_index=0,
@@ -116,7 +116,7 @@ def forward(


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=1,
         return_encoder_hidden_states_index=0,

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 6 deletions

@@ -33,7 +33,7 @@
     Timesteps,
     get_1d_rotary_pos_embed,
 )
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -311,7 +311,7 @@ def forward(
         return conditioning, token_replace_emb


-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
@@ -496,7 +496,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return freqs_cos, freqs_sin


-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -578,7 +578,7 @@ def forward(
         return hidden_states, encoder_hidden_states


-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -663,7 +663,7 @@ def forward(
         return hidden_states, encoder_hidden_states


-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,
@@ -749,7 +749,7 @@ def forward(
         return hidden_states, encoder_hidden_states


-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -197,7 +197,7 @@ def forward(


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -117,7 +117,7 @@ def forward(


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     metadata=TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=1,

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
-from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry
+from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -222,7 +222,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 @maybe_allow_in_graph
-@TransformerBlockRegistry.register(
+@register_transformer_block(
     TransformerBlockMetadata(
         return_hidden_states_index=0,
         return_encoder_hidden_states_index=None,
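Note (an observation about the design, not something the diff states): moving from a dict registry to a class attribute changes lookup semantics slightly. The old `TransformerBlockRegistry.get` raised unless the exact class had been registered, whereas a class attribute is inherited, so a subclass of a registered block now passes the `hasattr` check and silently shares its parent's metadata:

class ParentBlock:
    pass

# What `register_transformer_block` now does, reduced to its essence:
ParentBlock._diffusers_transformer_block_metadata = "parent metadata"

class ChildBlock(ParentBlock):
    # Hypothetical subclass that was never decorated itself.
    pass

# Old registry: `ChildBlock not in _registry` -> ValueError.
# New lookup: the attribute is inherited from ParentBlock.
assert hasattr(ChildBlock, "_diffusers_transformer_block_metadata")
print(ChildBlock._diffusers_transformer_block_metadata)  # parent metadata

Whether that is desirable depends on whether subclasses always keep the parent's return-index layout; for the blocks touched here they appear to, since each block is decorated individually.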
