     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
+    math_sdp_precision,
 )

 if TEST_FAIRSEQ:
@@ -3413,6 +3414,7 @@ def test_mem_eff_backwards_determinism(self, device):
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
@@ -3528,6 +3530,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3641,6 +3644,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("enable_gqa", [True, False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: list[int]):
@@ -3786,6 +3790,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -4100,6 +4105,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("is_causal", [True, False])
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
                                                             head_dim: int, dropout_p: float, dtype: torch.dtype,
                                                             scale: str, is_causal: bool):
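Note: the math_sdp_precision decorator applied above is imported from the shared test utilities and is not defined in this hunk. The sketch below is only a hypothetical illustration of the idea, assuming that passing "ieee" means running the math (reference) SDPA path with strict IEEE fp32 matmuls (TF32 disabled) for the duration of a test; the real helper's name, signature, and mechanism may differ.

```python
# Hypothetical sketch only -- the real math_sdp_precision helper lives in the
# PyTorch test utilities and is not shown in this diff. Assumption: "ieee"
# pins fp32 matmuls to strict IEEE precision (TF32 off) while the test runs,
# then restores the previous backend flags.
import functools

import torch


def math_sdp_precision_sketch(precision: str):
    assert precision in ("ieee", "tf32")

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            old_matmul = torch.backends.cuda.matmul.allow_tf32
            old_cudnn = torch.backends.cudnn.allow_tf32
            try:
                allow_tf32 = precision == "tf32"
                torch.backends.cuda.matmul.allow_tf32 = allow_tf32
                torch.backends.cudnn.allow_tf32 = allow_tf32
                return fn(*args, **kwargs)
            finally:
                # Restore whatever the surrounding test configuration had set.
                torch.backends.cuda.matmul.allow_tf32 = old_matmul
                torch.backends.cudnn.allow_tf32 = old_cudnn
        return wrapper

    return decorator
```

Presumably the point of stacking it under @tf32_enabled() in these tests is to keep the math reference computation at full IEEE fp32 accuracy while the fused kernels under test run with their usual precision.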