
Commit 903514f

make style
1 parent af51f5d commit 903514f

File tree

1 file changed: +18 -12 lines changed


src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 18 additions & 12 deletions
@@ -40,30 +40,34 @@ class PyramidAttentionBroadcastConfig:
 
     Args:
         spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
-            The number of blocks to skip in the spatial attention layer. If `None`, the spatial attention layer
-            computations will not be skipped.
+            The number of times a specific spatial attention broadcast is skipped before computing the attention states
+            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+            old attention states will be re-used) before computing the new attention states again.
         temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
-            The number of blocks to skip in the temporal attention layer. If `None`, the temporal attention layer
-            computations will not be skipped.
+            The number of times a specific temporal attention broadcast is skipped before computing the attention
+            states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
+            (i.e., old attention states will be re-used) before computing the new attention states again.
         cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
-            The number of blocks to skip in the cross-attention layer. If `None`, the cross-attention layer computations
-            will not be skipped.
+            The number of times a specific cross-attention broadcast is skipped before computing the attention states
+            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+            old attention states will be re-used) before computing the new attention states again.
         spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
-            The range of timesteps to skip in the spatial attention layer. The attention computations will be skipped
-            if the current timestep is within the specified range.
+            The range of timesteps to skip in the spatial attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
         temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
-            The range of timesteps to skip in the temporal attention layer. The attention computations will be skipped
-            if the current timestep is within the specified range.
+            The range of timesteps to skip in the temporal attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
-            The range of timesteps to skip in the cross-attention layer. The attention computations will be skipped if
-            the current timestep is within the specified range.
+            The range of timesteps to skip in the cross-attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
         spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
         temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
         cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
+
     spatial_attention_block_skip_range: Optional[int] = None
     temporal_attention_block_skip_range: Optional[int] = None
     cross_attention_block_skip_range: Optional[int] = None
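
The reworded docstrings above pin down the `N - 1` semantics of the `*_block_skip_range` fields. As a quick illustration, not part of this commit, and assuming the module path follows the file location shown above and that the config accepts these fields as keyword arguments, the following sketch builds a config and prints the resulting compute/re-use cadence:

from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastConfig

# With N = 2, attention states are computed once and then re-used N - 1 = 1 time
# before being recomputed, for timesteps inside (100, 800).
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
)

# Illustrative cadence only (an assumption, not the file's actual skip logic):
# recompute whenever the iteration counter is a multiple of N, otherwise
# broadcast (re-use) the previously cached attention states.
N = config.spatial_attention_block_skip_range
for iteration in range(6):
    action = "compute attention" if iteration % N == 0 else "re-use cached states"
    print(f"iteration {iteration}: {action}")
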
@@ -86,6 +90,7 @@ class PyramidAttentionBroadcastState:
             The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
             called before starting a new inference forward pass for PAB to work correctly.
     """
+
     def __init__(self) -> None:
         self.iteration = 0
 
@@ -101,6 +106,7 @@ class nnModulePAB(Protocol):
         _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`):
             The state of Pyramid Attention Broadcast.
     """
+
     _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState
 
 