tests/pytorch/attention/test_attention.py (9 additions, 0 deletions)
@@ -419,6 +419,15 @@ def test_dpa_softmax(dtype, model_configs, model):
)


+@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax_thd(dtype, model_configs, model):
+    """Test DotProductAttention module with different softmax types and qkv_format = thd"""
+    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)


model_configs_mla = {
# test: ModelConfig(b, sq, hq, dqk)
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
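The new THD variant reuses the existing softmax model configs and only switches the layout to `thd_thd_thd`, gated on cuDNN 9.18.0. To run just this case locally, a minimal sketch using pytest's programmatic entry point; the file path and `-k` pattern are taken from the hunk above, so adjust them if your checkout differs:

```python
# Minimal sketch: run only the new THD softmax test via pytest's public API.
# The path below assumes the repository layout shown in this diff.
import pytest

ret = pytest.main([
    "-v",
    "tests/pytorch/attention/test_attention.py",
    "-k", "test_dpa_softmax_thd",
])
print("pytest exit code:", ret)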
@@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp(
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"

enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "Context parallelism does not support MLA with {cp_comm_type=}!"
], f"Context parallelism does not support MLA with {cp_comm_type=}!"

if fp8 and fp8_meta is not None:
if fp8_meta["recipe"].fp8_dpa:
assert (
softmax_type == "vanilla"
), "Context parallelism does not support {softmax_type=} with FP8 attention!"
), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
assert (
softmax_type == "vanilla" or use_fused_attention
), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
assert (
softmax_type == "vanilla" or cp_comm_type == "a2a"
), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
assert (
softmax_type == "vanilla" or qkv_format != "thd"
), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
if get_cudnn_version() < (9, 18, 0):
assert softmax_type == "vanilla" or qkv_format != "thd", (
f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
" qkv_format = 'thd'!"
)

args = [
is_training,
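Two things are going on in the hunk above. First, the assertion messages gain an `f` prefix: without it, Python prints the literal text `{cp_comm_type=}` instead of interpolating the value, since the `=` specifier only works inside an f-string. Second, the blanket ban on non-vanilla softmax with `qkv_format = "thd"` under context parallelism now applies only when cuDNN is older than 9.18.0. A small self-contained sketch of both patterns; `_cudnn_version` and the sample values are illustrative stand-ins, not TransformerEngine's API:

```python
# Sketch only: the f-string "=" specifier and the version-gated assert pattern
# from the hunk above. _cudnn_version is a hypothetical stand-in for the real
# cuDNN version helper.
def _cudnn_version() -> tuple:
    return (9, 17, 1)  # pretend an older cuDNN build is installed

softmax_type = "learnable"  # any non-vanilla value, for illustration
qkv_format = "thd"

print(f"got {softmax_type=}")  # -> got softmax_type='learnable'
print("got {softmax_type=}")   # -> got {softmax_type=}  (the old, un-interpolated message)

# Tuple comparison is lexicographic, so (9, 17, 1) < (9, 18, 0) is True and the gate applies.
try:
    if _cudnn_version() < (9, 18, 0):
        assert softmax_type == "vanilla" or qkv_format != "thd", (
            f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=}"
            " with qkv_format = 'thd'!"
        )
except AssertionError as err:
    print(err)  # shows the fully interpolated message
```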
@@ -716,22 +716,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_unfused_attention = False
if qkv_format == "thd":
-logger.debug(
-    "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
-)
-use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
+if cudnn_version < (9, 18, 0):
+    logger.debug(
+        "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
+        " version < 9.18",
+        softmax_type,
+    )
+    use_fused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
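The final hunk (truncated in this capture) adjusts backend selection in the same direction: for `qkv_format = "thd"` with a non-vanilla softmax, the unfused path stays disabled, while FusedAttention is now disabled only when the installed cuDNN predates 9.18.0. A standalone sketch of that selection pattern follows; the function name, flags, and tuple-valued `cudnn_version` are simplifications of what the hunk suggests, not the library's actual routine:

```python
# Illustrative sketch of the version-gated backend filter suggested by the
# hunk above. All names here are stand-ins, not TransformerEngine's API.
import logging

logger = logging.getLogger("attention_backend_sketch")


def filter_backends_for_thd(softmax_type, qkv_format, cudnn_version,
                            use_fused_attention=True, use_unfused_attention=True):
    """Return (use_fused_attention, use_unfused_attention) after the THD checks."""
    if softmax_type != "vanilla" and qkv_format == "thd":
        # The unfused path is ruled out for this combination, as before.
        logger.debug(
            "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
            softmax_type,
        )
        use_unfused_attention = False
        # FusedAttention is only ruled out on cuDNN builds older than 9.18.0.
        if cudnn_version < (9, 18, 0):
            logger.debug(
                "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
                " version < 9.18",
                softmax_type,
            )
            use_fused_attention = False
    return use_fused_attention, use_unfused_attention


# With cuDNN 9.18.0 only the unfused path is dropped; with 9.17.x both are.
print(filter_backends_for_thd("learnable", "thd", (9, 18, 0)))  # (True, False)
print(filter_backends_for_thd("learnable", "thd", (9, 17, 1)))  # (False, False)
```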