|
33 | 33 | Timesteps, |
34 | 34 | get_1d_rotary_pos_embed, |
35 | 35 | ) |
36 | | -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry |
| 36 | +from ..metadata import TransformerBlockMetadata, register_transformer_block |
37 | 37 | from ..modeling_outputs import Transformer2DModelOutput |
38 | 38 | from ..modeling_utils import ModelMixin |
39 | 39 | from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm |
@@ -311,7 +311,7 @@ def forward( |
311 | 311 | return conditioning, token_replace_emb |
312 | 312 |
|
313 | 313 |
|
314 | | -@TransformerBlockRegistry.register( |
| 314 | +@register_transformer_block( |
315 | 315 | metadata=TransformerBlockMetadata( |
316 | 316 | return_hidden_states_index=0, |
317 | 317 | return_encoder_hidden_states_index=None, |
@@ -496,7 +496,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
496 | 496 | return freqs_cos, freqs_sin |
497 | 497 |
|
498 | 498 |
|
499 | | -@TransformerBlockRegistry.register( |
| 499 | +@register_transformer_block( |
500 | 500 | metadata=TransformerBlockMetadata( |
501 | 501 | return_hidden_states_index=0, |
502 | 502 | return_encoder_hidden_states_index=1, |
@@ -578,7 +578,7 @@ def forward( |
578 | 578 | return hidden_states, encoder_hidden_states |
579 | 579 |
|
580 | 580 |
|
581 | | -@TransformerBlockRegistry.register( |
| 581 | +@register_transformer_block( |
582 | 582 | metadata=TransformerBlockMetadata( |
583 | 583 | return_hidden_states_index=0, |
584 | 584 | return_encoder_hidden_states_index=1, |
@@ -663,7 +663,7 @@ def forward( |
663 | 663 | return hidden_states, encoder_hidden_states |
664 | 664 |
|
665 | 665 |
|
666 | | -@TransformerBlockRegistry.register( |
| 666 | +@register_transformer_block( |
667 | 667 | metadata=TransformerBlockMetadata( |
668 | 668 | return_hidden_states_index=0, |
669 | 669 | return_encoder_hidden_states_index=1, |
@@ -749,7 +749,7 @@ def forward( |
749 | 749 | return hidden_states, encoder_hidden_states |
750 | 750 |
|
751 | 751 |
|
752 | | -@TransformerBlockRegistry.register( |
| 752 | +@register_transformer_block( |
753 | 753 | metadata=TransformerBlockMetadata( |
754 | 754 | return_hidden_states_index=0, |
755 | 755 | return_encoder_hidden_states_index=1, |
|
0 commit comments