[CPU][FP8] Support FP8 SDPA for CPU backend #2689
base: main
Conversation
Helpful Links: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2689
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 85c0ad7 with merge base 6e9bf26. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@CaoE @jianan-gu Please help review, thanks~
template <typename scalar_t, typename mask_t,
          int64_t q_split_size, int64_t kv_split_size>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::Float8_e4m3fn>, void>
fp8_sdpa_fused_kernel_impl(
Is it possible to merge the int8 and fp8 implementations?
We cannot merge these two implementations, as fp8 uses flash attention while int8 does not.
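For context, the two paths are already kept apart at compile time by SFINAE on the scalar type, as in the diff above. A minimal sketch of that dispatch pattern (parameter lists and kernel bodies elided; illustrative only, not the actual kernel code):

```cpp
#include <ATen/ATen.h>
#include <type_traits>

// Illustrative only: the fp8 overload (flash-attention style) and the
// int8 overload (non-flash) are selected by the scalar type, so the two
// implementations can stay separate.
template <typename scalar_t>
inline std::enable_if_t<std::is_same_v<scalar_t, at::Float8_e4m3fn>, void>
sdpa_kernel_impl_sketch(/* q, k, v, out, ... */) {
  // fp8 path: flash-attention style blocking over the KV sequence.
}

template <typename scalar_t>
inline std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
sdpa_kernel_impl_sketch(/* q, k, v, out, ... */) {
  // int8 (u8) path: non-flash implementation.
}
```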
@@ -157,7 +157,7 @@ def _check_common(
)
@config.patch({"freezing": True})
def _test_sdpa_int8_rewriter(self):
from torch.export import export_for_training
from torch.export import export
If this test covers fp8, we'd better rename it.
Also the file name.
Thanks and modified.
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
_store(out + i, tmp4, size - i);
}
val = vec_tmp_sum.reduce_add();
Does this function need a NaN guard?
We would not get an extremely large value here, because this is part of the safe softmax, where the max value has already been subtracted.
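For reference, a minimal scalar sketch of the safe-softmax argument above: because the row max is subtracted before exponentiating, every term is at most 1, so the sum reduction stays finite (illustrative only, not the vectorized kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Safe softmax: subtracting the row max keeps each exp() argument <= 0,
// so every term lies in (0, 1] and the running sum cannot overflow.
std::vector<float> safe_softmax(const std::vector<float>& x) {
  const float max_val = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - max_val);  // bounded by 1
    sum += out[i];                      // analogous to reduce_add() above
  }
  for (auto& v : out) {
    v /= sum;  // sum >= 1 (the max term contributes exp(0) = 1)
  }
  return out;
}
```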
torchao/csrc/cpu/quantized_sdpa.cpp (Outdated)
return output.transpose(1, 2);
} else {
#endif // CPU_CAPABILITY_AVX512
std::cout << "int8_sdpa_math_kernel" << std::endl;
Looks like a leftover debug print. Remove this.
Thanks and modified.
torchao/csrc/cpu/quantized_sdpa.cpp (Outdated)
#ifdef CPU_CAPABILITY_AVX512
if (at::native::cpublas::could_pack(dtype)) {
at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
std::cout << "int8_sdpa_fused_kernel" << std::endl;
Looks like a leftover debug print as well.
Thanks and modified.
torchao/csrc/cpu/quantized_sdpa.cpp (Outdated)
return output.transpose(1, 2);
} else {
#endif // CPU_CAPABILITY_AVX512 && CPUBLAS_BRGEMM_F8F8F32
std::cout << "fp8_sdpa_math_kernel" << std::endl;
Same as above.
Thanks and modified.
return sdpa_int8_math_kernel(query, key, value,
if (dtype == at::ScalarType::Byte) {
#ifdef CPU_CAPABILITY_AVX512
if (at::native::cpublas::could_pack(dtype)) {
Do we always need packing on supported platforms? Are there any cases where packing is slower than the plain format?
Thanks for your suggestion!
For the cases we care about, for example the landing-zone models, we have confirmed that packing is better.
For the general case, we can tune need_pack in the future.
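A rough sketch of what such a tuning hook could look like; need_pack_heuristic and its threshold are hypothetical and not part of this PR, which currently gates only on could_pack:

```cpp
#include <cstdint>

// Hypothetical heuristic (not in this PR): skip packing for small problem
// sizes where the packing overhead may outweigh the GEMM benefit.
inline bool need_pack_heuristic(int64_t q_len, int64_t kv_len, int64_t head_dim) {
  // Placeholder threshold; a real value would be tuned per platform.
  return q_len * kv_len * head_dim >= (1 << 16);
}

// Usage sketch, mirroring the existing gate:
//   if (at::native::cpublas::could_pack(dtype) &&
//       need_pack_heuristic(q_len, kv_len, head_dim)) {
//     ... fused (packed) kernel ...
//   } else {
//     ... plain-format math kernel ...
//   }
```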
torchao/csrc/cpu/quantized_sdpa.cpp (Outdated)
// CPUBLAS_BRGEMM_F8F8F32 is defined if FP8 BRGEMM is supported in PyTorch CPUBlas.
if (at::native::cpublas::could_pack(dtype)) {
at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
std::cout << "fp8_sdpa_fused_kernel" << std::endl;
Same as above.
Thanks and modified.
@@ -1834,6 +2424,43 @@ at::Tensor sdpa_int8_math_kernel(
return output;
}

at::Tensor fp8_sdpa_math_kernel(
Are there any tests for the ref implementation?
The ref path is taken when CPUBLAS_BRGEMM_F8F8F32 is not set; this macro is defined in PyTorch when FP8 BRGEMM is supported in CPUBlas. So whether the fused attention kernel or the math ref kernel is used depends on the PyTorch version. I have locally validated both cases.
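A condensed sketch of the dispatch described here, using the macro and kernel names from the diffs above (signatures simplified; illustrative only, not the actual code):

```cpp
// Which fp8 kernel runs depends on whether the PyTorch build defines
// CPUBLAS_BRGEMM_F8F8F32, i.e. whether CPUBlas provides FP8 BRGEMM.
at::Tensor fp8_sdpa_dispatch_sketch(const at::Tensor& query /* , ... */) {
#if defined(CPU_CAPABILITY_AVX512) && defined(CPUBLAS_BRGEMM_F8F8F32)
  // Newer PyTorch: fused flash-attention style kernel.
  return fp8_sdpa_fused_kernel(query /* , ... */);
#else
  // Older PyTorch: reference math implementation.
  return fp8_sdpa_math_kernel(query /* , ... */);
#endif
}
```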
Description: Add FP8 SDPA support for the CPU backend via qscaled_dot_product, and extend the op to the FP8 dtype.