Commit 5618a7d (parent: 4a21cf8)
1 file changed: 16 additions, 1 deletion

src/diffusers/models/attention_dispatch.py
@@ -1386,7 +1386,22 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")
 
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
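Functionally, the change makes the combined ring + Ulysses ("unified") case take precedence over plain ring attention. Below is a minimal, self-contained sketch of the resulting branch order; the Ulysses-only branch and the "local" fallback are assumptions about the rest of the dispatch function, which lies outside this hunk:

def select_context_parallel_backend(ring_degree: int, ulysses_degree: int) -> str:
    # Mirrors the branch order introduced by this commit: the combined
    # check must run first, otherwise `ring_degree > 1` alone would
    # shadow the unified case and ulysses_degree would be ignored.
    if ring_degree > 1 and ulysses_degree > 1:
        return "unified"  # TemplatedUnifiedAttention (added in this commit)
    elif ring_degree > 1:
        return "ring"  # TemplatedRingAttention
    elif ulysses_degree > 1:
        return "ulysses"  # assumed: a Ulysses-only branch follows in the full function
    return "local"  # assumed: no context parallelism

assert select_context_parallel_backend(2, 2) == "unified"  # newly handled case
assert select_context_parallel_backend(2, 1) == "ring"     # behavior unchanged

Before this commit, a configuration with both ring_degree > 1 and ulysses_degree > 1 would fall through to the ring-only branch, which is the gap the TODO in the context line refers to.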
