diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 3310abfb41d54..793e97e9d9c88 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -70,6 +70,8 @@ Float32Op str2op(const std::string& name) {
     return Float32Op::RNN;
   else if (name == "matmul")
     return Float32Op::MATMUL;
+  else if (name == "math_sdp")
+    return Float32Op::MATH_SDP;
   TORCH_CHECK(false, "Unknown op: ", name);
 }
 
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index a4a26b5671e59..ac965cf6c0f37 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -46,7 +46,7 @@ enum class CuBLASReductionOption : uint8_t {
   DisallowReducedPrecisionDisallowSplitK = 2,
 };
 enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
-enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
+enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL, MATH_SDP };
 enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
 
 TORCH_API Float32Backend str2backend(const std::string& name);
@@ -512,6 +512,7 @@ class TORCH_API Context {
        float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
            ? Float32Precision::NONE
            : Float32Precision::TF32},
+      {{Float32Backend::CUDA, Float32Op::MATH_SDP}, Float32Precision::NONE},
   };
 
   Allocator* prev_allocator_ptr_{nullptr};
@@ -684,6 +685,36 @@ struct TORCH_API NoTF32Guard {
   bool changed = false;
 };
 
+template <Float32Backend target_backend, Float32Op target_op>
+struct Fp32PrecisonGuard {
+  Fp32PrecisonGuard(const Float32Precision new_precision) {
+    if (new_precision == Float32Precision::NONE) {
+      return;
+    }
+    saved_precision =
+        globalContext().float32Precision(target_backend, target_op);
+    changed = (new_precision != saved_precision);
+    if (changed) {
+      globalContext().setFloat32Precision(
+          target_backend, target_op, new_precision);
+    }
+  }
+  Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete;
+  Fp32PrecisonGuard(const Fp32PrecisonGuard&) = delete;
+  Fp32PrecisonGuard& operator=(const Fp32PrecisonGuard&) = delete;
+  Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete;
+  ~Fp32PrecisonGuard() {
+    if (changed) {
+      globalContext().setFloat32Precision(
+          target_backend, target_op, saved_precision);
+    }
+  }
+
+ private:
+  Float32Precision saved_precision;
+  bool changed = false;
+};
+
 struct TORCH_API ROCmBackwardPassGuard {
   ROCmBackwardPassGuard();
   ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 7aad4309924d4..4e4d89b2a41d7 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -868,6 +868,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
           ? value.to(at::kFloat)
           : value;
   auto attn_mask = attn_mask_;
+  const auto math_sdp_precision = at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATH_SDP);
+  // Temporarily override the matmul precision with the value from cuda.math_sdp.
+  // IEEE should be used when the fp32 math backend serves as the golden reference.
+  at::Fp32PrecisonGuard<at::Float32Backend::CUDA, at::Float32Op::MATMUL> fp32guard(math_sdp_precision);
+
   // Naive, composite implementation defined here.
 
   // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b5a04df3e090b..7b733c1f86513 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -2140,6 +2140,7 @@
     "PropModule",
     # torch.backends.cuda
     "cuBLASModule",
+    "MathSDPModule",
     "cuFFTPlanCache",
     "cuFFTPlanCacheAttrContextProp",
     "cuFFTPlanCacheManager",
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 4dea431246999..548a896298122 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -52,6 +52,7 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
+    math_sdp_precision,
 )
 
 if TEST_FAIRSEQ:
@@ -126,6 +127,12 @@ def _check_equal(
             _check_equal(gold, ref, tst, fudge_factor, tensor_name)
         return
 
+    if golden.is_cuda and golden.dtype == torch.float32:
+        assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", (
+            "Testing script error: FP32 golden tensor must be calculated with IEEE"
+            " precision. Add @math_sdp_precision('ieee') to related tests to fix it."
+        )
+
     # Compute error between golden
     test_error = (golden - test).abs().max()
     ref_error = (golden - reference).abs().max()
@@ -3383,6 +3390,7 @@ def test_mem_eff_backwards_determinism(self, device):
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float,
                                                        dtype: torch.dtype, scale: str):
@@ -3498,6 +3506,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3611,6 +3620,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("enable_gqa", [True, False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: list[int]):
@@ -3756,6 +3766,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -4070,6 +4081,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("is_causal", [True, False])
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int,
                                                             max_seq_len_kv: int, head_dim: int, dropout_p: float,
                                                             dtype: torch.dtype, scale: str, is_causal: bool):
diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py
index d62c2b05a1ea1..53d175f842990 100644
--- a/torch/backends/cuda/__init__.py
+++ b/torch/backends/cuda/__init__.py
@@ -12,6 +12,7 @@
     "cuFFTPlanCache",
     "cuFFTPlanCacheManager",
     "cuBLASModule",
+    "MathSDPModule",
     "preferred_linalg_library",
     "preferred_blas_library",
"preferred_rocm_fa_library", @@ -206,6 +207,18 @@ def __setattr__(self, name, value): raise AttributeError("Unknown attribute " + name) +class MathSDPModule: + def __getattr__(self, name): + if name == "fp32_precision": + return torch._C._get_fp32_precision_getter("cuda", "math_sdp") + raise AttributeError("Unknown attribute " + name) + + def __setattr__(self, name, value): + if name == "fp32_precision": + return torch._C._set_fp32_precision_setter("cuda", "math_sdp", value) + raise AttributeError("Unknown attribute " + name) + + _LinalgBackends = { "default": torch._C._LinalgBackend.Default, "cusolver": torch._C._LinalgBackend.Cusolver, @@ -591,3 +604,4 @@ def sdp_kernel( cufft_plan_cache = cuFFTPlanCacheManager() matmul = cuBLASModule() +math_sdp = MathSDPModule() diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 74dfe0c56c232..5eff4a8e1dc52 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -221,6 +221,15 @@ def tf32_enabled(): torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul +@contextlib.contextmanager +def math_sdp_precision(target_precision: str): + saved_precision = torch.backends.cuda.math_sdp.fp32_precision + try: + torch.backends.cuda.math_sdp.fp32_precision = target_precision + yield + finally: + torch.backends.cuda.math_sdp.fp32_precision = saved_precision + # This is a wrapper that wraps a test to run this test twice, one with # allow_tf32=True, another with allow_tf32=False. When running with # allow_tf32=True, it will use reduced precision as specified by the