
Commit 32e2d1d

[ROCm] align the softmax aux shape with NVTE upstream (#371)
1 parent: 87fece2

1 file changed: +4 −1 lines changed


transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 4 additions & 1 deletion
@@ -357,7 +357,10 @@ def abstract(
             softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
             softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
         elif backend == NVTE_Fused_Attn_Backend.NVTE_CK:
-            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
+            if config.qkv_layout.is_thd():
+                softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
+            else:
+                softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
             softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
         else:
             raise ValueError(f"Unsupported {backend=}")
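
For readers skimming the diff: below is a minimal, self-contained sketch of the shape selection this commit introduces for the CK backend's softmax aux buffer. The `QKVLayout` enum, its `is_thd` helper, and the example dimensions here are illustrative assumptions, not Transformer Engine's actual classes; only the branching logic mirrors the patched code.

```python
# Illustrative sketch only; names mirror the diff but the enum and values
# are assumptions, not the library's real API.
from enum import Enum, auto


class QKVLayout(Enum):
    BSHD = auto()  # dense layout: batch, seqlen, heads, head_dim
    THD = auto()   # ragged/packed layout: total_tokens, heads, head_dim

    def is_thd(self) -> bool:
        return self is QKVLayout.THD


def ck_softmax_aux_shape(batch_shape, attn_heads, q_max_seqlen, layout):
    # Mirrors the patched branch: for THD (ragged) layouts the sequence
    # dimension precedes the heads dimension, matching NVTE upstream.
    if layout.is_thd():
        return (*batch_shape, q_max_seqlen, attn_heads, 1)
    return (*batch_shape, attn_heads, q_max_seqlen, 1)


# Example: batch of 2, 16 heads, max query seqlen 512
print(ck_softmax_aux_shape((2,), 16, 512, QKVLayout.THD))   # (2, 512, 16, 1)
print(ck_softmax_aux_shape((2,), 16, 512, QKVLayout.BSHD))  # (2, 16, 512, 1)
```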
