2 changes: 2 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -70,6 +70,8 @@ Float32Op str2op(const std::string& name) {
return Float32Op::RNN;
else if (name == "matmul")
return Float32Op::MATMUL;
else if (name == "math_sdp")
return Float32Op::MATH_SDP;
TORCH_CHECK(false, "Unknown op: ", name);
}

33 changes: 32 additions & 1 deletion aten/src/ATen/Context.h
@@ -46,7 +46,7 @@ enum class CuBLASReductionOption : uint8_t {
DisallowReducedPrecisionDisallowSplitK = 2,
};
enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL, MATH_SDP };
enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };

TORCH_API Float32Backend str2backend(const std::string& name);
@@ -512,6 +512,7 @@ class TORCH_API Context {
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
? Float32Precision::NONE
: Float32Precision::TF32},
{{Float32Backend::CUDA, Float32Op::MATH_SDP}, Float32Precision::NONE},
};

Allocator* prev_allocator_ptr_{nullptr};
@@ -684,6 +685,36 @@ struct TORCH_API NoTF32Guard {
bool changed = false;
};

template <Float32Backend target_backend, Float32Op target_op>
struct Fp32PrecisonGuard {
Fp32PrecisonGuard(const Float32Precision new_precision) {
if (new_precision == Float32Precision::NONE) {
return;
}
saved_precision =
globalContext().float32Precision(target_backend, target_op);
changed = (new_precision != saved_precision);
if (changed) {
globalContext().setFloat32Precision(
target_backend, target_op, new_precision);
}
}
Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete;
Fp32PrecisonGuard(const Fp32PrecisonGuard&) = delete;
Fp32PrecisonGuard& operator=(const Fp32PrecisonGuard&) = delete;
Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete;
~Fp32PrecisonGuard() {
if (changed) {
globalContext().setFloat32Precision(
target_backend, target_op, saved_precision);
}
}

private:
Float32Precision saved_precision;
bool changed = false;
};

struct TORCH_API ROCmBackwardPassGuard {
ROCmBackwardPassGuard();
ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
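For readers approaching this from the Python side, the save/override/restore behavior of the guard above can be mirrored with the `fp32_precision` attribute added later in this PR. This is an illustrative sketch of the same pattern, not the C++ guard itself, and it assumes the precision strings mirror the `Float32Precision` enum ("none", "ieee", "tf32", "bf16"):

```python
import contextlib

import torch


@contextlib.contextmanager
def fp32_precision_guard(new_precision: str):
    """Python analogue of Fp32PrecisonGuard for the CUDA math_sdp op."""
    # A "none" request leaves the global setting untouched, like the C++ guard.
    if new_precision == "none":
        yield
        return
    saved = torch.backends.cuda.math_sdp.fp32_precision
    changed = saved != new_precision
    if changed:
        torch.backends.cuda.math_sdp.fp32_precision = new_precision
    try:
        yield
    finally:
        # Restore the previous precision only if it was actually changed.
        if changed:
            torch.backends.cuda.math_sdp.fp32_precision = saved
```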
5 changes: 5 additions & 0 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -868,6 +868,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
? value.to(at::kFloat)
: value;
auto attn_mask = attn_mask_;
const auto math_sdp_precision = at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATH_SDP);
// Temporarily override the matmul precision with the value from cuda.math_sdp.
// IEEE should be used when the fp32 math backend serves as the golden reference.
at::Fp32PrecisonGuard<at::Float32Backend::CUDA, at::Float32Op::MATMUL> fp32guard(math_sdp_precision);

// Naive, composite implementation defined here.

// Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
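The practical effect of the guard is that the math reference path on CUDA now honors `cuda.math_sdp` rather than the global matmul setting. A hedged usage sketch (the `fp32_precision` attribute comes from this PR; `sdpa_kernel` and `SDPBackend` are the existing backend-selection API):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))

# Ask the math reference path to accumulate in strict IEEE fp32 so its
# output can serve as a golden reference for the fused kernels.
torch.backends.cuda.math_sdp.fp32_precision = "ieee"
with sdpa_kernel(SDPBackend.MATH):
    golden = F.scaled_dot_product_attention(q, k, v)
```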
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -2140,6 +2140,7 @@
"PropModule",
# torch.backends.cuda
"cuBLASModule",
"MathSDPModule",
"cuFFTPlanCache",
"cuFFTPlanCacheAttrContextProp",
"cuFFTPlanCacheManager",
12 changes: 12 additions & 0 deletions test/test_transformers.py
@@ -52,6 +52,7 @@
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
tf32_on_and_off,
tf32_enabled,
math_sdp_precision,
)

if TEST_FAIRSEQ:
@@ -126,6 +127,12 @@ def _check_equal(
_check_equal(gold, ref, tst, fudge_factor, tensor_name)
return

if golden.is_cuda and golden.dtype == torch.float32:
assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", (
"Testing script error: FP32 golden tensor must be calculated with IEEE"
" precision. Add @math_sdp_precision('ieee') to related tests to fix it."
)

# Compute max error of the test and reference outputs relative to golden
test_error = (golden - test).abs().max()
ref_error = (golden - reference).abs().max()
@@ -3383,6 +3390,7 @@ def test_mem_eff_backwards_determinism(self, device):
)
@parametrize("scale", [None, "l1"])
@tf32_enabled()
@math_sdp_precision("ieee")
def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str):
@@ -3498,6 +3506,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
)
@parametrize("scale", [None, "l1"])
@tf32_enabled()
@math_sdp_precision("ieee")
def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
seq_len_k: int, head_dim: int, is_causal: bool,
dropout_p: float, dtype: torch.dtype,
@@ -3611,6 +3620,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
@tf32_enabled()
@math_sdp_precision("ieee")
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str, enable_gqa: bool, n_heads: list[int]):
@@ -3756,6 +3766,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
@parametrize("scale", [None, "l1"])
@parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
@tf32_enabled()
@math_sdp_precision("ieee")
def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
seq_len_q: int, seq_len_k: int,
head_dim: int,
@@ -4070,6 +4081,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device):
@parametrize("dtype", [torch.float16])
@parametrize("scale", [None, "l1"])
@parametrize("is_causal", [True, False])
@math_sdp_precision("ieee")
def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
head_dim: int, dropout_p: float, dtype: torch.dtype,
scale: str, is_causal: bool):
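The decorator composes with the fp32 golden-reference check in `_check_equal` above. A condensed, hypothetical version of that pattern (the helper name, fudge factor, and choice of fused backend are illustrative, not taken from the test file):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.testing._internal.common_cuda import math_sdp_precision


@math_sdp_precision("ieee")
def check_fused_vs_math(q16, k16, v16, fudge_factor=4.0):
    with sdpa_kernel(SDPBackend.MATH):
        # Golden: fp32 inputs through the math path, accumulated in IEEE fp32.
        golden = F.scaled_dot_product_attention(
            q16.float(), k16.float(), v16.float())
        # Reference: the same math path at the low-precision test dtype.
        reference = F.scaled_dot_product_attention(q16, k16, v16)
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        # Test: a fused kernel at the low-precision test dtype.
        test = F.scaled_dot_product_attention(q16, k16, v16)
    # Accept the fused result if its error vs. golden stays within a fudge
    # factor of the reference error, mirroring _check_equal.
    test_error = (golden - test.float()).abs().max()
    ref_error = (golden - reference.float()).abs().max()
    assert test_error <= fudge_factor * ref_error
```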
14 changes: 14 additions & 0 deletions torch/backends/cuda/__init__.py
@@ -12,6 +12,7 @@
"cuFFTPlanCache",
"cuFFTPlanCacheManager",
"cuBLASModule",
"MathSDPModule",
"preferred_linalg_library",
"preferred_blas_library",
"preferred_rocm_fa_library",
@@ -206,6 +207,18 @@ def __setattr__(self, name, value):
raise AttributeError("Unknown attribute " + name)


class MathSDPModule:
def __getattr__(self, name):
if name == "fp32_precision":
return torch._C._get_fp32_precision_getter("cuda", "math_sdp")
raise AttributeError("Unknown attribute " + name)

def __setattr__(self, name, value):
if name == "fp32_precision":
return torch._C._set_fp32_precision_setter("cuda", "math_sdp", value)
raise AttributeError("Unknown attribute " + name)


_LinalgBackends = {
"default": torch._C._LinalgBackend.Default,
"cusolver": torch._C._LinalgBackend.Cusolver,
@@ -591,3 +604,4 @@ def sdp_kernel(

cufft_plan_cache = cuFFTPlanCacheManager()
matmul = cuBLASModule()
math_sdp = MathSDPModule()
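With the module instantiated, the setting is exposed as an ordinary attribute next to `torch.backends.cuda.matmul`. A minimal sketch; the accepted strings are assumed to mirror the `Float32Precision` enum:

```python
import torch

# Inspect the current precision used by the CUDA math SDP reference path.
print(torch.backends.cuda.math_sdp.fp32_precision)

# Require strict IEEE fp32 accumulation, e.g. for golden-reference testing.
torch.backends.cuda.math_sdp.fp32_precision = "ieee"

# Or allow TF32 when bitwise-faithful fp32 accumulation is not required.
torch.backends.cuda.math_sdp.fp32_precision = "tf32"
```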
9 changes: 9 additions & 0 deletions torch/testing/_internal/common_cuda.py
@@ -221,6 +221,15 @@ def tf32_enabled():
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


@contextlib.contextmanager
def math_sdp_precision(target_precision: str):
saved_precision = torch.backends.cuda.math_sdp.fp32_precision
try:
torch.backends.cuda.math_sdp.fp32_precision = target_precision
yield
finally:
torch.backends.cuda.math_sdp.fp32_precision = saved_precision

# This is a wrapper that wraps a test to run this test twice, one with
# allow_tf32=True, another with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the
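Because the helper is an ordinary contextlib context manager, it also works as a with-block inside a test body rather than as a decorator; a short sketch:

```python
import torch.nn.functional as F
from torch.testing._internal.common_cuda import math_sdp_precision


def build_fp32_golden(q, k, v):
    # Compute the golden reference under strict IEEE fp32, then restore the
    # previous process-wide math_sdp setting when the block exits.
    with math_sdp_precision("ieee"):
        return F.scaled_dot_product_attention(q.float(), k.float(), v.float())
```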