Add torch.backends.cuda.math_sdp.fp32_precision #2848
+75
−1
Overview
This PR adds a new float32 precision API, torch.backends.cuda.math_sdp.fp32_precision, to configure the FP32 precision behavior of SDPBackend.MATH.
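For illustration, a minimal usage sketch of the new knob, assuming it accepts the same string values ("ieee", "tf32") as PyTorch's other fp32_precision settings:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Force IEEE FP32 accumulation inside the math (reference) SDPA backend.
# The accepted values are assumed to mirror other fp32_precision knobs.
torch.backends.cuda.math_sdp.fp32_precision = "ieee"

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)
```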
Rationale
The test/test_transformers.py test suite derives its numerical tolerance by comparing output tensors computed at the same precision ("reference") and at higher precision ("golden"), both produced by SDPBackend.MATH. However, the golden output is calculated with TF32 rather than FP32, which is actually less accurate than the FA/ME backends when they use IEEE FP32 rather than TF32 for accumulation.
This loss of precision causes false negatives in SDPA tests such as
TestSDPACudaOnlyCUDA.test_flash_attention_vs_math_ref_grads_batch_size_8_seq_len_q_143_seq_len_k_4_head_dim_203_is_causal_False_dropout_p_0_22_float16_scale_l1_enable_gqa_True_n_heads1_cuda_float16,
at least on the ROCm platform. The false negative disappears after forcing higher_precision_dtype = torch.float64.
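For context, a hypothetical sketch of the tolerance-derivation idea described above (derive_tolerance is an illustrative helper, not the actual code in test_transformers.py); it shows why a degraded golden baseline can make correct fused outputs fail:

```python
import torch

def derive_tolerance(golden: torch.Tensor, reference: torch.Tensor,
                     fudge_factor: float = 4.0) -> float:
    # Illustrative only: the allowed error for a fused backend is scaled from
    # how far the same-precision math output drifts from the higher-precision
    # "golden" math output. If the golden run silently uses TF32, it is no
    # longer a trustworthy baseline, and correct fused outputs can exceed the
    # derived tolerance, producing false negatives.
    ref_error = (golden - reference.to(golden.dtype)).abs().max().item()
    return max(ref_error * fudge_factor, torch.finfo(reference.dtype).eps)
```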
Major Changes
To restore the precision of the golden output, a new API, torch.backends.cuda.math_sdp.fp32_precision, is introduced. It allows the "matmul" precision used by SDPBackend.MATH to be configured, and a new decorator, @math_sdp_precision("ieee"), is added to all tests that use check_out_and_grad (see the sketch after this section). Finally, an assert is added to the innermost function _check_equal as a sanity check, ensuring math_sdp has the correct precision configured for torch.float32 golden tensors.
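A rough sketch of what such a decorator could look like (the actual implementation is in this PR's diff; this is only an approximation that saves, overrides, and restores the setting around a test):

```python
import functools
import torch

def math_sdp_precision(precision: str):
    # Temporarily set the math-SDP fp32 precision for the duration of a test,
    # then restore the previous value even if the test raises.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prev = torch.backends.cuda.math_sdp.fp32_precision
            torch.backends.cuda.math_sdp.fp32_precision = precision
            try:
                return fn(*args, **kwargs)
            finally:
                torch.backends.cuda.math_sdp.fp32_precision = prev
        return wrapper
    return decorator
```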
Known Issues
The backward pass honors the configuration in effect when backward() is called, regardless of the configuration in effect when the graph was created (see the sketch below).
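A minimal sketch illustrating the issue, assuming the usage shown in the Overview:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# The precision active when backward() runs is what the math backend's
# backward uses, even though the forward graph was built under a
# different setting.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
           for _ in range(3))

torch.backends.cuda.math_sdp.fp32_precision = "ieee"
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)  # forward under "ieee"

torch.backends.cuda.math_sdp.fp32_precision = "tf32"
out.sum().backward()  # backward runs under "tf32"
```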