File tree Expand file tree Collapse file tree 1 file changed +16
-1
lines changed
Expand file tree Collapse file tree 1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -1386,7 +1386,22 @@ def _templated_context_parallel_attention(
13861386 raise ValueError ("GQA is not yet supported for templated attention." )
13871387
13881388 # When ring and ulysses degrees are both > 1, dispatch to unified attention; otherwise fall through to the single-strategy paths
1389- if _parallel_config .context_parallel_config .ring_degree > 1 :
1389+ if _parallel_config .context_parallel_config .ring_degree > 1 and _parallel_config .context_parallel_config .ulysses_degree > 1 :
1390+ return TemplatedUnifiedAttention (
1391+ query ,
1392+ key ,
1393+ value ,
1394+ attn_mask ,
1395+ dropout_p ,
1396+ is_causal ,
1397+ scale ,
1398+ enable_gqa ,
1399+ return_lse ,
1400+ forward_op ,
1401+ backward_op ,
1402+ _parallel_config ,
1403+ )
1404+ elif _parallel_config .context_parallel_config .ring_degree > 1 :
13901405 return TemplatedRingAttention .apply (
13911406 query ,
13921407 key ,
You can’t perform that action at this time.
0 commit comments