@@ -40,30 +40,34 @@ class PyramidAttentionBroadcastConfig:
 
     Args:
         spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
-            The number of blocks to skip in the spatial attention layer. If `None`, the spatial attention layer
-            computations will not be skipped.
+            The number of times a specific spatial attention broadcast is skipped before computing the attention states
+            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+            old attention states will be re-used) before computing the new attention states again.
         temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
-            The number of blocks to skip in the temporal attention layer. If `None`, the temporal attention layer
-            computations will not be skipped.
+            The number of times a specific temporal attention broadcast is skipped before computing the attention
+            states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
+            (i.e., old attention states will be re-used) before computing the new attention states again.
         cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
-            The number of blocks to skip in the cross-attention layer. If `None`, the cross-attention layer computations
-            will not be skipped.
+            The number of times a specific cross-attention broadcast is skipped before computing the attention states
+            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+            old attention states will be re-used) before computing the new attention states again.
         spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
-            The range of timesteps to skip in the spatial attention layer. The attention computations will be skipped
-            if the current timestep is within the specified range.
+            The range of timesteps to skip in the spatial attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
         temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
-            The range of timesteps to skip in the temporal attention layer. The attention computations will be skipped
-            if the current timestep is within the specified range.
+            The range of timesteps to skip in the temporal attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
-            The range of timesteps to skip in the cross-attention layer. The attention computations will be skipped if
-            the current timestep is within the specified range.
+            The range of timesteps to skip in the cross-attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
         spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
         temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
         cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
+
     spatial_attention_block_skip_range: Optional[int] = None
     temporal_attention_block_skip_range: Optional[int] = None
     cross_attention_block_skip_range: Optional[int] = None
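
As a minimal sketch of how this config might be constructed, using only the fields and defaults documented in the docstring above: the import path and the cadence values (2 and 6) are assumptions for illustration, and how the config is then attached to a model is not part of this diff.

```python
# Assumed import path; the exact module layout may differ at this commit.
from diffusers.hooks import PyramidAttentionBroadcastConfig

# Recompute spatial attention every 2nd step and cross-attention every 6th step,
# but only while the current timestep falls inside (100, 800); outside that window
# attention is always recomputed. Field names and defaults come from the docstring above.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    cross_attention_block_skip_range=6,
    spatial_attention_timestep_skip_range=(100, 800),
    cross_attention_timestep_skip_range=(100, 800),
)
```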
@@ -86,6 +90,7 @@ class PyramidAttentionBroadcastState:
             The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
             called before starting a new inference forward pass for PAB to work correctly.
     """
+
     def __init__(self) -> None:
         self.iteration = 0
 
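
The `N - 1` re-use rule from the config docstring, combined with the `iteration` counter this state keeps, can be sketched as a simple modulo check. This is an illustrative reading of the documented behaviour, not the hook's actual implementation.

```python
def should_recompute(iteration: int, skip_range: int) -> bool:
    # With skip_range = N, attention is recomputed at iterations 0, N, 2N, ...
    # and the cached states are re-used for the N - 1 iterations in between.
    return iteration % skip_range == 0

# Iterations 0..5 with N = 3: recompute, re-use, re-use, recompute, re-use, re-use.
assert [should_recompute(i, 3) for i in range(6)] == [True, False, False, True, False, False]
```

This is also why `reset_state` must run before each new inference pass: a stale `iteration` value would shift the recompute/re-use cadence.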
@@ -101,6 +106,7 @@ class nnModulePAB(Protocol):
         _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`):
             The state of Pyramid Attention Broadcast.
     """
+
     _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState
 
 
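
Since `nnModulePAB` is a `typing.Protocol`, conformance is structural: any `torch.nn.Module` that ends up carrying a `_pyramid_attention_broadcast_state` attribute (in practice, one the PAB hook has been attached to) satisfies it without inheriting from it. A hypothetical sketch:

```python
import torch

class DummyAttention(torch.nn.Module):
    """Hypothetical module standing in for a real attention block."""

    def __init__(self) -> None:
        super().__init__()
        # In practice this attribute is attached by the PAB hook, not set manually.
        self._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState()

# Accepted by a static type checker because the attribute matches the protocol.
module: nnModulePAB = DummyAttention()
```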