Commit 8654659

test/test_transformers: decorate all tests that use fp32 math as golden to use ieee rather than tf32
1 parent e7c76da commit 8654659

File tree

1 file changed: +6 -0 lines changed


test/test_transformers.py

Lines changed: 6 additions & 0 deletions
@@ -54,6 +54,7 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
+    math_sdp_precision,
 )
 
 if TEST_FAIRSEQ:
@@ -3413,6 +3414,7 @@ def test_mem_eff_backwards_determinism(self, device):
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
@@ -3528,6 +3530,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3641,6 +3644,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("enable_gqa", [True, False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: list[int]):
@@ -3786,6 +3790,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -4100,6 +4105,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("is_causal", [True, False])
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
                                                             head_dim: int, dropout_p: float, dtype: torch.dtype,
                                                             scale: str, is_causal: bool):
