From 6999b0151b90e451fba27ba4b6ff043db2c71dfc Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 20:17:56 +0000 Subject: [PATCH 01/11] Add torch.backends.cuda.math_sdp.fp32_precision --- aten/src/ATen/Context.cpp | 2 ++ aten/src/ATen/Context.h | 3 ++- torch/backends/cuda/__init__.py | 12 ++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 3310abfb41d54..793e97e9d9c88 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -70,6 +70,8 @@ Float32Op str2op(const std::string& name) { return Float32Op::RNN; else if (name == "matmul") return Float32Op::MATMUL; + else if (name == "math_sdp") + return Float32Op::MATH_SDP; TORCH_CHECK(false, "Unknown op: ", name); } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a4a26b5671e59..f5a89b6a51d60 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -46,7 +46,7 @@ enum class CuBLASReductionOption : uint8_t { DisallowReducedPrecisionDisallowSplitK = 2, }; enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN }; -enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL }; +enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL, MATH_SDP }; enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 }; TORCH_API Float32Backend str2backend(const std::string& name); @@ -512,6 +512,7 @@ class TORCH_API Context { float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST ? Float32Precision::NONE : Float32Precision::TF32}, + {{Float32Backend::CUDA, Float32Op::MATH_SDP}, Float32Precision::NONE}, }; Allocator* prev_allocator_ptr_{nullptr}; diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index d62c2b05a1ea1..7aad461dbd682 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -205,6 +205,17 @@ def __setattr__(self, name, value): return torch._C._set_fp32_precision_setter("cuda", "matmul", value) raise AttributeError("Unknown attribute " + name) +class MathSDPModule: + def __getattr__(self, name): + if name == "fp32_precision": + return torch._C._get_fp32_precision_getter("cuda", "math_sdp") + raise AttributeError("Unknown attribute " + name) + + def __setattr__(self, name, value): + if name == "fp32_precision": + return torch._C._set_fp32_precision_setter("cuda", "math_sdp", value) + raise AttributeError("Unknown attribute " + name) + _LinalgBackends = { "default": torch._C._LinalgBackend.Default, @@ -591,3 +602,4 @@ def sdp_kernel( cufft_plan_cache = cuFFTPlanCacheManager() matmul = cuBLASModule() +math_sdp = MathSDPModule() From b365da97710378abb317115cc6cba95ed620654d Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 20:45:52 +0000 Subject: [PATCH 02/11] Make torch.backends.cuda.math_sdp.fp32_precision effective for math_sdp --- aten/src/ATen/Context.h | 26 +++++++++++++++++++ .../ATen/native/transformers/attention.cpp | 5 ++++ 2 files changed, 31 insertions(+) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index f5a89b6a51d60..f769c2fb672f3 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -685,6 +685,32 @@ struct TORCH_API NoTF32Guard { bool changed = false; }; +template +struct Fp32PrecisonGuard { + Fp32PrecisonGuard(const Float32Precision new_precision) { + if (new_precision == Float32Precision::NONE) { + return ; + } + saved_precision = float32Precision(target_backend, target_op); + changed = (new_precision == saved_precision); + if (changed) { + setFloat32Precision(target_backend, 
target_op, new_precision); + } + } + Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete; + Fp32PrecisonGuard(const Fp32PrecisonGuard&) = delete; + Fp32PrecisonGuard& operator=(const Fp32PrecisonGuard&) = delete; + Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete; + ~Fp32PrecisonGuard() { + if (changed) { + setFloat32Precision(target_backend, target_op, saved_precision); + } + } + private: + Float32Precision saved_precision; + bool changed = false; +}; + struct TORCH_API ROCmBackwardPassGuard { ROCmBackwardPassGuard(); ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete; diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 7aad4309924d4..4e4d89b2a41d7 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -868,6 +868,11 @@ std::tuple _scaled_dot_product_attention_math( ? value.to(at::kFloat) : value; auto attn_mask = attn_mask_; + const auto math_sdp_precision = at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATH_SDP); + // Temporarily override matmul precision with value from cuda.math_sdp + // IEEE should be used when use fp32+math backend as golden reference. + at::Fp32PrecisonGuard fp32guard(math_sdp_precision); + // Naive, composite implementation defined here. // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for From cca52039f2de4d31d0c165acc0019cee349e8f96 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 20:55:06 +0000 Subject: [PATCH 03/11] torch/testing: add ctx mananger math_sdp_precision --- torch/testing/_internal/common_cuda.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 74dfe0c56c232..5eff4a8e1dc52 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -221,6 +221,15 @@ def tf32_enabled(): torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul +@contextlib.contextmanager +def math_sdp_precision(target_precision: str): + saved_precision = torch.backends.cuda.math_sdp.fp32_precision + try: + torch.backends.cuda.math_sdp.fp32_precision = target_precision + yield + finally: + torch.backends.cuda.math_sdp.fp32_precision = saved_precision + # This is a wrapper that wraps a test to run this test twice, one with # allow_tf32=True, another with allow_tf32=False. 
When running with # allow_tf32=True, it will use reduced precision as specified by the From b98605ed4d9b4e06db4bc5d94e15477c1e38a4d4 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 20:56:15 +0000 Subject: [PATCH 04/11] test/test_transformers: decorate all tests that uses fp32 math as golden to use ieee rather than tf32 --- test/test_transformers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_transformers.py b/test/test_transformers.py index 4dea431246999..230137d44ea9b 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -52,6 +52,7 @@ PLATFORM_SUPPORTS_CUDNN_ATTENTION, tf32_on_and_off, tf32_enabled, + math_sdp_precision, ) if TEST_FAIRSEQ: @@ -3383,6 +3384,7 @@ def test_mem_eff_backwards_determinism(self, device): ) @parametrize("scale", [None, "l1"]) @tf32_enabled() + @math_sdp_precision("ieee") def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, scale: str): @@ -3498,6 +3500,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, ) @parametrize("scale", [None, "l1"]) @tf32_enabled() + @math_sdp_precision("ieee") def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, @@ -3611,6 +3614,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @parametrize("enable_gqa", [True, False]) @parametrize("n_heads", [[16, 8], [10, 2]]) @tf32_enabled() + @math_sdp_precision("ieee") def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, scale: str, enable_gqa: bool, n_heads: list[int]): @@ -3756,6 +3760,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le @parametrize("scale", [None, "l1"]) @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) @tf32_enabled() + @math_sdp_precision("ieee") def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, @@ -4070,6 +4075,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device): @parametrize("dtype", [torch.float16]) @parametrize("scale", [None, "l1"]) @parametrize("is_causal", [True, False]) + @math_sdp_precision("ieee") def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int, head_dim: int, dropout_p: float, dtype: torch.dtype, scale: str, is_causal: bool): From b9eb99e9f0b9aeb44cbb15454a8a1703bf7fd9a7 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 21:37:05 +0000 Subject: [PATCH 05/11] fix build error --- aten/src/ATen/Context.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index f769c2fb672f3..06e8f789b9f8f 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -691,10 +691,10 @@ struct Fp32PrecisonGuard { if (new_precision == Float32Precision::NONE) { return ; } - saved_precision = float32Precision(target_backend, target_op); + saved_precision = globalContext().float32Precision(target_backend, target_op); changed = (new_precision == saved_precision); if (changed) { - setFloat32Precision(target_backend, target_op, new_precision); + 
globalContext().setFloat32Precision(target_backend, target_op, new_precision); } } Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete; @@ -703,7 +703,7 @@ struct Fp32PrecisonGuard { Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete; ~Fp32PrecisonGuard() { if (changed) { - setFloat32Precision(target_backend, target_op, saved_precision); + globalContext().setFloat32Precision(target_backend, target_op, saved_precision); } } private: From 9fd54bba33416322c9703107fd30ae17cfc24cea Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 22:23:07 +0000 Subject: [PATCH 06/11] test/test_transformers: sanity check of golden tensor --- test/test_transformers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_transformers.py b/test/test_transformers.py index 230137d44ea9b..7f56a372eedc9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -127,6 +127,12 @@ def _check_equal( _check_equal(gold, ref, tst, fudge_factor, tensor_name) return + if golden.is_cuda and golden.dtype == torch.float32: + assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", ( + "Testing script error: FP32 golden tensor must be calculated with IEEE" + " precision" + ) + # Compute error between golden test_error = (golden - test).abs().max() ref_error = (golden - reference).abs().max() From ec5378b1add8f9f26120f3daa55ec69e507c38a5 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 22:23:44 +0000 Subject: [PATCH 07/11] fix Fp32PrecisonGuard --- aten/src/ATen/Context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 06e8f789b9f8f..0343980ca21b5 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -692,7 +692,7 @@ struct Fp32PrecisonGuard { return ; } saved_precision = globalContext().float32Precision(target_backend, target_op); - changed = (new_precision == saved_precision); + changed = (new_precision != saved_precision); if (changed) { globalContext().setFloat32Precision(target_backend, target_op, new_precision); } From d9851f9c211f1741ecfca55601766aea135ad304 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 22:29:26 +0000 Subject: [PATCH 08/11] more documentation --- test/test_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transformers.py b/test/test_transformers.py index 7f56a372eedc9..548a896298122 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -130,7 +130,7 @@ def _check_equal( if golden.is_cuda and golden.dtype == torch.float32: assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", ( "Testing script error: FP32 golden tensor must be calculated with IEEE" - " precision" + " precision. Add @math_sdp_precision('ieee') to related tests to fix it." 
) # Compute error between golden From f1fbfb6a324c9faefb86149ad2a8aeeae0c88088 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 6 Nov 2025 00:08:04 +0000 Subject: [PATCH 09/11] fix lint --- aten/src/ATen/Context.h | 14 +++++++++----- torch/backends/cuda/__init__.py | 1 + 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 0343980ca21b5..ac965cf6c0f37 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -685,16 +685,18 @@ struct TORCH_API NoTF32Guard { bool changed = false; }; -template +template struct Fp32PrecisonGuard { Fp32PrecisonGuard(const Float32Precision new_precision) { if (new_precision == Float32Precision::NONE) { - return ; + return; } - saved_precision = globalContext().float32Precision(target_backend, target_op); + saved_precision = + globalContext().float32Precision(target_backend, target_op); changed = (new_precision != saved_precision); if (changed) { - globalContext().setFloat32Precision(target_backend, target_op, new_precision); + globalContext().setFloat32Precision( + target_backend, target_op, new_precision); } } Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete; @@ -703,9 +705,11 @@ struct Fp32PrecisonGuard { Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete; ~Fp32PrecisonGuard() { if (changed) { - globalContext().setFloat32Precision(target_backend, target_op, saved_precision); + globalContext().setFloat32Precision( + target_backend, target_op, saved_precision); } } + private: Float32Precision saved_precision; bool changed = false; diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 7aad461dbd682..b9a4139c9d3ed 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -205,6 +205,7 @@ def __setattr__(self, name, value): return torch._C._set_fp32_precision_setter("cuda", "matmul", value) raise AttributeError("Unknown attribute " + name) + class MathSDPModule: def __getattr__(self, name): if name == "fp32_precision": From b27caaba15c43257f8170853c3835876718ca5cb Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 2 Dec 2025 16:31:04 +0000 Subject: [PATCH 10/11] Follow the practice of cuBLASModule and ignore MathSDPModule in docs generation --- docs/source/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index b5a04df3e090b..7b733c1f86513 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2140,6 +2140,7 @@ "PropModule", # torch.backends.cuda "cuBLASModule", + "MathSDPModule", "cuFFTPlanCache", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCacheManager", From b5744a1c7d6873109d3f1b7f17e4165a53d60ba6 Mon Sep 17 00:00:00 2001 From: Anatoliy Litvinenko Date: Tue, 2 Dec 2025 13:39:43 -0600 Subject: [PATCH 11/11] Follow the practice of public cuBLASModule to include to __all__ for MathSDPModule --- torch/backends/cuda/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index b9a4139c9d3ed..53d175f842990 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -12,6 +12,7 @@ "cuFFTPlanCache", "cuFFTPlanCacheManager", "cuBLASModule", + "MathSDPModule", "preferred_linalg_library", "preferred_blas_library", "preferred_rocm_fa_library",
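
Usage sketch (illustrative, not part of the patch series): the Python below shows how the
new knob and the test helper introduced in this series are meant to fit together, assuming
all eleven patches are applied. torch.backends.cuda.math_sdp.fp32_precision is added in
PATCH 01 and the math_sdp_precision context manager in PATCH 03; the tensor shapes and the
use of sdpa_kernel to force the math backend are assumptions made for illustration, not
taken from the series.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.testing._internal.common_cuda import math_sdp_precision

    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Global setting: force the math SDP reference path to full IEEE FP32 so that a
    # "golden" reference is not silently computed with TF32 matmuls.
    torch.backends.cuda.math_sdp.fp32_precision = "ieee"
    with sdpa_kernel([SDPBackend.MATH]):
        golden = F.scaled_dot_product_attention(q, k, v)

    # Scoped setting: the test-only context manager from PATCH 03 restores the previous
    # value on exit, mirroring the @math_sdp_precision("ieee") decorator applied to the
    # tests in PATCH 04.
    with math_sdp_precision("ieee"):
        assert torch.backends.cuda.math_sdp.fp32_precision == "ieee"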