
Commit f90f80e

[feature] add optional attention mask to native flash attention and update backend handling in SkyReelsV2AttnProcessor
1 parent 5c6ce3c commit f90f80e

File tree

2 files changed (+8 / -2 lines)

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 1 deletion
@@ -873,6 +873,7 @@ def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
+   attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
@@ -884,7 +885,7 @@ def _native_flash_attention(
        query=query,
        key=key,
        value=value,
-       attn_mask=None,  # not supported
+       attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
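
The mask is simply threaded through to PyTorch's SDPA call. As a rough, standalone illustration (not the diffusers source), the sketch below shows a flash-backed wrapper with the same signature after this change, assuming a recent PyTorch where torch.nn.attention.sdpa_kernel is available; the function name native_flash_attention_sketch is hypothetical.

from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def native_flash_attention_sketch(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,  # newly forwarded instead of a hard-coded None
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    # Restrict SDPA to the flash kernel, mirroring the "native flash" dispatch path.
    # Flash attention only supports a limited set of mask layouts, so a dense
    # attn_mask may still be rejected depending on device, dtype, and torch version.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

In other words, the change makes the argument available to callers; it does not by itself guarantee that every mask will be accepted by the flash kernel.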

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 6 additions & 1 deletion
@@ -123,6 +123,12 @@ def apply_rotary_emb(
        query = apply_rotary_emb(query, *rotary_emb)
        key = apply_rotary_emb(key, *rotary_emb)

+       if self._attention_backend == "_native_flash-flash_varlen":
+           if not self.is_cross_attention:
+               self._attention_backend = "_native_flash"
+           else:
+               self._attention_backend = "flash_varlen"
+
        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
@@ -153,7 +159,6 @@ def apply_rotary_emb(
            is_causal=False,
            backend=self._attention_backend,
        )
-
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
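
Read in isolation, the new branch resolves a combined backend string per attention type. The sketch below restates the same rule as a pure helper; the name resolve_attention_backend is hypothetical, while the backend strings match the diff.

def resolve_attention_backend(attention_backend: str, is_cross_attention: bool) -> str:
    # Split the combined setting per attention type: self-attention keeps the
    # native flash path, cross-attention uses the variable-length flash kernel.
    if attention_backend == "_native_flash-flash_varlen":
        return "flash_varlen" if is_cross_attention else "_native_flash"
    return attention_backend


# Example usage (hypothetical values):
assert resolve_attention_backend("_native_flash-flash_varlen", is_cross_attention=False) == "_native_flash"
assert resolve_attention_backend("_native_flash-flash_varlen", is_cross_attention=True) == "flash_varlen"
assert resolve_attention_backend("flash", is_cross_attention=True) == "flash"

Note that the diff assigns the resolved value back onto self._attention_backend, so the combined "_native_flash-flash_varlen" setting is overwritten after the first call on a given processor; since is_cross_attention is typically fixed per attention layer this yields the same value every time, whereas the pure-function form above leaves the original setting untouched.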
