
Commit dd675e8

refactor: streamline attention backend assignment in SkyReelsV2AttnProcessor
1 parent f90f80e

1 file changed

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 6 additions & 7 deletions
@@ -123,11 +123,10 @@ def apply_rotary_emb(
         query = apply_rotary_emb(query, *rotary_emb)
         key = apply_rotary_emb(key, *rotary_emb)
 
-        if self._attention_backend == "_native_flash-flash_varlen":
-            if not self.is_cross_attention:
-                self._attention_backend = "_native_flash"
-            else:
-                self._attention_backend = "flash_varlen"
+        if not self.is_cross_attention:
+            attention_backend = "_native_flash"
+        else:
+            attention_backend = "flash_varlen"
 
         # I2V task
         hidden_states_img = None
@@ -145,7 +144,7 @@ def apply_rotary_emb(
             attn_mask=None,
             dropout_p=0.0,
             is_causal=False,
-            backend=self._attention_backend,
+            backend=attention_backend,
         )
         hidden_states_img = hidden_states_img.flatten(2, 3)
         hidden_states_img = hidden_states_img.type_as(query)
@@ -157,7 +156,7 @@ def apply_rotary_emb(
             attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=False,
-            backend=self._attention_backend,
+            backend=attention_backend,
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
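
For context, the diff replaces a stateful sentinel check that wrote the resolved backend back into self._attention_backend with a per-call local variable, so choosing between the self- and cross-attention paths no longer mutates the processor. Below is a minimal sketch of that pattern; run_attention is a hypothetical stand-in for the real attention dispatch in diffusers, not its actual API.

    # Minimal sketch of the refactored pattern. run_attention() is a
    # hypothetical stand-in for the real attention dispatch call.

    def run_attention(query, key, value, backend):
        # Stand-in dispatcher: just reports which backend would be used.
        print(f"dispatching with backend={backend!r}")
        return query


    class SketchProcessor:
        def __init__(self, is_cross_attention=False):
            self.is_cross_attention = is_cross_attention

        def __call__(self, query, key, value):
            # After the refactor: the backend is resolved into a local
            # variable on every call, and processor state stays untouched.
            if not self.is_cross_attention:
                attention_backend = "_native_flash"
            else:
                attention_backend = "flash_varlen"
            return run_attention(query, key, value, backend=attention_backend)


    proc = SketchProcessor(is_cross_attention=True)
    proc(query=None, key=None, value=None)  # dispatching with backend='flash_varlen'

Before the change, the equivalent branch only fired when self._attention_backend matched a sentinel string and then assigned the resolved backend back onto self, so the first call through the processor altered its persistent state; resolving into a local keeps the per-call choice separate from configuration.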
