forked from pytorch/pytorch
Open
Add torch.backends.cuda.math_sdp.fp32_precision #2844
anatoliylitv wants to merge 11 commits into rocm7.1_internal_testing from anatoliylitv/rocm_7_1_internal_testing_Add-torch.backends.cuda.math_sdp.fp32_precision
Conversation
Jenkins build for f1fbfb6a324c9faefb86149ad2a8aeeae0c88088 commit finished as FAILURE
Jenkins build for b27caaba15c43257f8170853c3835876718ca5cb commit finished as FAILURE
Jenkins build for b5744a1c7d6873109d3f1b7f17e4165a53d60ba6 commit finished as FAILURE
Overview
This PR adds a new float32 precision API, torch.backends.cuda.math_sdp.fp32_precision, to configure the fp32 precision behavior of SDPBackend.MATH.
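For illustration, a minimal sketch of how the new setting might be used, assuming fp32_precision is exposed as a plain settable string attribute and accepts the same values ("ieee", "tf32") as the existing fp32 precision controls:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Proposed knob from this PR: force the math backend to do its fp32 matmuls
# with IEEE accumulation instead of TF32 (accepted values are an assumption).
torch.backends.cuda.math_sdp.fp32_precision = "ieee"

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)

# Restrict SDPA to the math backend so the new setting is what governs precision.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)
```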
Rationale
The test/test_transformers.py test suite derives its numerical tolerances by comparing output tensors computed at the test's own precision ("reference") and at a higher precision ("golden"), both produced by SDPBackend.MATH. However, the golden output is currently computed with TF32 rather than FP32, which makes it less accurate than the FA/ME backends when they use IEEE rather than TF32 for their accumulation.
This loss of precision causes false negatives in SDPA tests such as
TestSDPACudaOnlyCUDA.test_flash_attention_vs_math_ref_grads_batch_size_8_seq_len_q_143_seq_len_k_4_head_dim_203_is_causal_False_dropout_p_0_22_float16_scale_l1_enable_gqa_True_n_heads1_cuda_float16,
at least on the ROCm platform. The false negative disappears after forcing
higher_precision_dtype = torch.float64
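Conceptually, the golden/reference comparison works along these lines (an illustrative sketch, not the actual test/test_transformers.py code; the helper name and signature here are assumptions):

```python
import torch
import torch.nn.functional as F

def golden_and_reference(q, k, v, higher_precision_dtype=torch.float32):
    # "Golden": the same math-backend computation at a higher precision.
    # If fp32 math silently falls back to TF32, this golden tensor is less
    # accurate than intended, and tolerances derived from it misjudge the
    # more accurate IEEE-accumulating FA/ME outputs, producing false negatives.
    golden = F.scaled_dot_product_attention(
        q.to(higher_precision_dtype),
        k.to(higher_precision_dtype),
        v.to(higher_precision_dtype),
    )
    # "Reference": the same computation at the test's working precision.
    reference = F.scaled_dot_product_attention(q, k, v)
    return golden, reference

# Tolerances are derived from the deviation between reference and golden,
# so the golden tensor must genuinely be computed at the higher precision.
```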
Major Changes
To restore the precision of the golden output, a new API, torch.backends.cuda.math_sdp.fp32_precision, is introduced, which allows configuration of the "matmul" precision used by SDPBackend.MATH, and a new decorator, @math_sdp_precision("ieee"), is added to all tests that use check_out_and_grad. Finally, an assert is added to the innermost function _check_equal as a sanity check to ensure math_sdp has the right precision configured for torch.float32 golden tensors.
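A minimal sketch of what such a decorator and sanity check could look like (the decorator and function names come from the description above; the bodies are assumptions, not the PR's actual implementation):

```python
import functools
import torch

def math_sdp_precision(precision):
    """Run a test with torch.backends.cuda.math_sdp.fp32_precision set to
    `precision`, restoring the previous value afterwards (illustrative sketch)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prev = torch.backends.cuda.math_sdp.fp32_precision
            torch.backends.cuda.math_sdp.fp32_precision = precision
            try:
                return fn(*args, **kwargs)
            finally:
                torch.backends.cuda.math_sdp.fp32_precision = prev
        return wrapper
    return decorator

# Sanity check of the kind added to _check_equal: an fp32 golden tensor is only
# meaningful if the math backend is actually configured for IEEE accumulation.
def _assert_golden_precision(golden):
    if golden.dtype == torch.float32:
        assert torch.backends.cuda.math_sdp.fp32_precision == "ieee"
```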
Known Issues
The backward pass honors the precision configuration in effect when backward() is called, regardless of the configuration that was in effect when the graph was created.
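A hedged illustration of that behavior (assuming "ieee" and "tf32" are the accepted values of the new setting):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 64, 32, device="cuda", requires_grad=True)
k = torch.randn(2, 8, 64, 32, device="cuda", requires_grad=True)
v = torch.randn(2, 8, 64, 32, device="cuda", requires_grad=True)

# Forward pass recorded while the math backend is set to IEEE...
torch.backends.cuda.math_sdp.fp32_precision = "ieee"
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)

# ...but if the setting changes before backward(), the backward matmuls follow
# the *current* setting ("tf32" here), not the one used to build the graph.
torch.backends.cuda.math_sdp.fp32_precision = "tf32"
out.sum().backward()
```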