
Conversation

Valentine233 (Collaborator)

Description:

  1. Reuse the schema of qscaled_dot_product and extend it to support the FP8 dtype (dispatch sketched below).
  2. Support both the fused attention kernel and the fallback math kernel for FP8 SDPA.
  3. Support pattern matching for FP8 SDPA.
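
For item 1, a minimal sketch of the dtype-based dispatch that reusing the qscaled_dot_product schema implies. Signatures here are heavily simplified assumptions (the real op also carries attention-mask, scale, and quantization arguments), and qscaled_dot_product_dispatch is an illustrative name, not the PR's actual entry point:

#include <ATen/ATen.h>

// Simplified forward declarations; the PR's actual kernels take many more
// arguments (mask, dropout, quantization scales, ...).
at::Tensor int8_sdpa_fused_kernel(const at::Tensor& q, const at::Tensor& k, const at::Tensor& v);
at::Tensor fp8_sdpa_fused_kernel(const at::Tensor& q, const at::Tensor& k, const at::Tensor& v);

// One entry point, routed by input dtype: u8 keeps the existing int8 path,
// while Float8_e4m3fn takes the new FP8 path added in this PR.
at::Tensor qscaled_dot_product_dispatch(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value) {
  const auto dtype = query.scalar_type();
  if (dtype == at::ScalarType::Byte) {
    return int8_sdpa_fused_kernel(query, key, value);
  }
  if (dtype == at::ScalarType::Float8_e4m3fn) {
    return fp8_sdpa_fused_kernel(query, key, value);
  }
  TORCH_CHECK(false, "qscaled_dot_product: unsupported dtype ", dtype);
}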


pytorch-bot bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2689

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 85c0ad7 with merge base 6e9bf26:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Aug 5, 2025
@Valentine233 marked this pull request as draft on August 5, 2025 07:13
@Valentine233 added the topic: not user facing label on Aug 5, 2025
Valentine233 (Collaborator, Author)

@CaoE @jianan-gu Please help review, thanks~

template <typename scalar_t, typename mask_t,
          int64_t q_split_size, int64_t kv_split_size>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::Float8_e4m3fn>, void>
fp8_sdpa_fused_kernel_impl(

Is it possible to merge the int8 and fp8 implementations?

Valentine233 (Collaborator, Author)

We cannot merge the two implementations, as fp8 uses flash attention while int8 does not.

@@ -157,7 +157,7 @@ def _check_common(
 )
 @config.patch({"freezing": True})
 def _test_sdpa_int8_rewriter(self):
-from torch.export import export_for_training
+from torch.export import export

If this test covers fp8, we'd better rename it.


Also the file name.

Valentine233 (Collaborator, Author)

Thanks, modified.

auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
_store(out + i, tmp4, size - i);
}
val = vec_tmp_sum.reduce_add();
CaoE commented Aug 22, 2025

Does this function need a NaN guard?

Valentine233 (Collaborator, Author)

We would not get an extremely large value here, because this is part of the safe softmax, where the row max has already been subtracted.
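
For context, a minimal scalar sketch of the safe-softmax property the reply relies on (illustrative only, not the PR's vectorized kernel code):

#include <algorithm>
#include <cmath>
#include <vector>

// Safe softmax: subtract the row max before exponentiating, so every exp()
// argument is <= 0 and every exp() result lies in (0, 1]. The running sum is
// therefore bounded by the row length and cannot overflow to Inf/NaN.
std::vector<float> safe_softmax(const std::vector<float>& row) {
  const float row_max = *std::max_element(row.begin(), row.end());
  std::vector<float> out(row.size());
  float sum = 0.f;
  for (size_t i = 0; i < row.size(); ++i) {
    out[i] = std::exp(row[i] - row_max);  // in (0, 1]
    sum += out[i];
  }
  for (float& v : out) {
    v /= sum;  // sum >= 1, since the max element contributes exp(0) = 1
  }
  return out;
}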

return output.transpose(1, 2);
} else {
#endif // CPU_CAPABILITY_AVX512
std::cout << "int8_sdpa_math_kernel" << std::endl;
CaoE commented Aug 22, 2025

Seems to be an oversight (leftover debug print). Remove this.

Valentine233 (Collaborator, Author)

Thanks, modified.

#ifdef CPU_CAPABILITY_AVX512
if (at::native::cpublas::could_pack(dtype)) {
at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
std::cout << "int8_sdpa_fused_kernel" << std::endl;

Seems to be an oversight (leftover debug print) as well.

Valentine233 (Collaborator, Author)

Thanks, modified.

return output.transpose(1, 2);
} else {
#endif // CPU_CAPABILITY_AVX512 && CPUBLAS_BRGEMM_F8F8F32
std::cout << "fp8_sdpa_math_kernel" << std::endl;

Same as above.

Valentine233 (Collaborator, Author)

Thanks, modified.

return sdpa_int8_math_kernel(query, key, value,
if (dtype == at::ScalarType::Byte) {
#ifdef CPU_CAPABILITY_AVX512
if (at::native::cpublas::could_pack(dtype)) {

Do we always need packing on supported platforms? Are there any cases where packing is slower than the plain format?

Valentine233 (Collaborator, Author)

Thanks for the suggestion!
For the cases we care about, for example the landing zone models, we have confirmed that packing is better.
For the general case, we can tune need_pack in the future; a rough sketch of what that could look like follows below.
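
This is purely hypothetical and not code from this PR: need_pack_heuristic and its size threshold are illustrative assumptions about how a future need_pack tunable could sit on top of the existing could_pack() capability check.

#include <ATen/ATen.h>
#include <ATen/native/CPUBlas.h>

// Hypothetical heuristic: use the packed path only when the platform supports
// packing for this dtype and the problem is large enough for the packing
// overhead to pay off. The cut-off below is a placeholder, not a tuned value.
static bool need_pack_heuristic(const at::Tensor& query, at::ScalarType dtype) {
  if (!at::native::cpublas::could_pack(dtype)) {
    return false;
  }
  const int64_t seq_len = query.size(2);   // assumes a [B, H, S, D] layout
  const int64_t head_dim = query.size(3);
  return seq_len * head_dim >= 4096;       // illustrative threshold
}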

// CPUBLAS_BRGEMM_F8F8F32 is defined if FP8 BRGEMM is supported in PyTorch CPUBlas.
if (at::native::cpublas::could_pack(dtype)) {
at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
std::cout << "fp8_sdpa_fused_kernel" << std::endl;

Same as above.

Valentine233 (Collaborator, Author)

Thanks, modified.

@@ -1834,6 +2424,43 @@ at::Tensor sdpa_int8_math_kernel(
return output;
}

at::Tensor fp8_sdpa_math_kernel(

Are there any tests for the ref implementation?

Valentine233 (Collaborator, Author)

The ref path is exercised when CPUBLAS_BRGEMM_F8F8F32 is not set; that macro is defined by PyTorch when FP8 BRGEMM is supported in CPUBlas. So whether the fused attention kernel or the math ref kernel is taken depends on the PyTorch version. I have validated both cases locally.
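
A minimal sketch of the compile-time selection being described (fp8_sdpa_kernel is an illustrative wrapper name, and the kernel signatures are simplified; the actual guard in the PR sits inside the AVX512-specific code path):

#include <ATen/ATen.h>
#include <ATen/native/CPUBlas.h>

// Simplified declarations of the PR's FP8 kernels (real signatures differ).
at::Tensor fp8_sdpa_fused_kernel(const at::Tensor& q, const at::Tensor& k, const at::Tensor& v);
at::Tensor fp8_sdpa_math_kernel(const at::Tensor& q, const at::Tensor& k, const at::Tensor& v);

at::Tensor fp8_sdpa_kernel(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value) {
#if defined(CPU_CAPABILITY_AVX512) && defined(CPUBLAS_BRGEMM_F8F8F32)
  // Newer PyTorch: FP8 BRGEMM is available, so the fused flash-attention
  // kernel is compiled in and used when packing is supported.
  if (at::native::cpublas::could_pack(query.scalar_type())) {
    return fp8_sdpa_fused_kernel(query, key, value);
  }
#endif
  // Macro not defined (older PyTorch) or packing unsupported:
  // fall back to the math reference kernel.
  return fp8_sdpa_math_kernel(query, key, value);
}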

@Valentine233 requested a review from CaoE on August 25, 2025 07:04
@Valentine233 marked this pull request as ready for review on August 25, 2025 07:46
@Valentine233 requested reviews from jansel, jerryzh168, drisspg and CaoE, and removed the request for CaoE, on August 27, 2025 02:40