@@ -21,17 +21,19 @@
 from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
 
 @dataclass
@@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
             The range of timesteps to skip in the cross-attention layer. The attention computations will be
             conditionally skipped if the current timestep is within the specified range.
-        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        spatial_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
-        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+        temporal_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
-        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        cross_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
 
@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
 
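For reference, the renamed `_common` constants remain the dataclass defaults, so the public usage pattern is unchanged by this diff. A minimal sketch of applying the config, following the documented diffusers API (the CogVideoX checkpoint and skip values here are illustrative choices, not part of this change):

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

# Any pipeline whose transformer uses the attention classes collected in
# _ATTENTION_CLASSES (Attention, MochiAttention) works; CogVideoX is one example.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,  # recompute spatial attention every 2nd step, reuse cache otherwise
    spatial_attention_timestep_skip_range=(100, 800),  # only skip inside this timestep window
    current_timestep_callback=lambda: pipe.current_timestep,
)

# Hooks are attached to submodules whose names match the
# *_TRANSFORMER_BLOCK_IDENTIFIERS defaults now imported from ._common.
apply_pyramid_attention_broadcast(pipe.transformer, config)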