petercad commented Oct 4, 2025

This PR updates FlashAttention to the new copy/MMA atoms.

Changes:

  • Prefill and decode unified into a single implementation, allowing simultaneous K and Q subgroup-level parallelization rather than an either-or.
  • GEMMs and softmax grouped together and the full k loop consolidated into an FMHA mainloop class.
    • This will facilitate further manual pipelining/overlap of GEMM with softmax.
  • Use new copy/MMA atoms and reorders to transparently support arbitrary data types.
  • Automatic copy/MMA operator selection.

Current status: prefill/decode examples are working, with similar or better performance compared to the old examples.

Known issues:

  • Head size 192 decode config doesn't compile yet. Edit: fixed.
  • Strange SYCL compiler behavior/bug with tSrS->tArP reorder. Apparently the compiler believes there is UB somewhere and will omit a large section of the kernel as a result. For the moment, there's a direct copy as a workaround while I pin down the issue. I'm not able to reproduce this behavior with the reorder in isolation.

Additional features (causal masking, variable sequence lengths, etc.) to be added later.

Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.

petercad changed the title from "[Umbrella commit] Re-implement FlashAttention with new Xe atoms" to "Re-implement FlashAttention with new Xe atoms" on Oct 4, 2025
petercad (Author) commented Oct 4, 2025

I will break up this large commit into self-contained smaller commits after review is complete.

Review comment:

why is this here? This isn't flash attention specific, is it?

petercad (Author) replied:

No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).

FragSRow k_rem_mask;
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
  // NaN for in-bounds columns (passes S through under fmin), -inf for out-of-bounds columns.
  k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
}

ClarkChin08 commented:

If the original S already contains a NaN, then fmin(NaN, NaN) = NaN will propagate the NaN into the softmax. This can corrupt the row-wise sum and max, leading to NaN in the final output O. Could we have a better k_rem_mask here to avoid this case?

petercad (Author) replied on Oct 21, 2025:

@ClarkChin08 Can you explain your concern a bit more? If original S has a NaN value in bounds, then that indicates either an overflow from very badly scaled data or an inf/NaN input, and there's no safe way to numerically recover from that (we can't easily guess what the right value should be in place of that NaN). If S has a NaN value out of bounds, then the fmin with -inf will produce -inf, so the NaN will be removed and not corrupt the softmax value.
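
To spell out the fmin behavior being relied on here, a minimal host-side illustration (standalone C++, not code from this PR; the kernel applies the mask with an fmin-style minimum per the snippet above):

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  float nan  = std::numeric_limits<float>::quiet_NaN();
  float ninf = -std::numeric_limits<float>::infinity();
  assert(std::fmin(nan, 3.0f) == 3.0f);    // in-bounds column: NaN mask entry passes S through
  assert(std::fmin(nan, ninf) == ninf);    // out-of-bounds column: even a NaN in S becomes -inf
  assert(std::isnan(std::fmin(nan, nan))); // both NaN: NaN survives (the in-bounds case raised above)
  return 0;
}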

ClarkChin08 replied:

I agree that if NaN appears in the valid range of S, it's likely a symptom of upstream issues like bad scaling or invalid inputs, and trying to "fix" it in the kernel can be tricky, especially in low-precision formats like fp8/fp4 where overflows are common.
Perhaps adding an optional debug mode to scan for NaNs/invalid inputs in S before softmax could help users identify issues early.

petercad (Author) replied:

I see, yes, that could be helpful. Perhaps this could take the form of a general helper function that scans for NaNs in an arbitrary tensor and aborts if any are found.
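
A minimal sketch of what such a helper might look like (the name debug_assert_no_nan, the FMHA_DEBUG_NAN_CHECK guard, and the float cast are illustrative assumptions, not part of this PR):

#include <cute/tensor.hpp>
#include <sycl/sycl.hpp>

template <class Engine, class Layout>
CUTE_HOST_DEVICE void debug_assert_no_nan(cute::Tensor<Engine, Layout> const& t)
{
#ifdef FMHA_DEBUG_NAN_CHECK
  for (int i = 0; i < cute::size(t); ++i) {
    if (sycl::isnan(static_cast<float>(t(i)))) {
      assert(!"NaN detected in tensor before softmax");   // abort in debug builds
    }
  }
#endif
}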

using ElementS = typename TiledMMAQK::ValTypeD;

using SingleFragA = FragC<TiledMMAPV>; // (atom val,q',v')
using FragA = expand_sg_fragment_t<SingleFragA, 1, VTiles>; // (atom val,q',v',VV)

ClarkChin08 commented:

FragA's size scales with q' * v' * VTiles * sizeof(ElementA), which causes high register pressure, especially for large head_dim (>128), leading to register spills to SLM and performance degradation. Could we have a more register-friendly design?

petercad (Author) replied on Oct 21, 2025:

@ClarkChin08 Thanks for the comments. A is the accumulator for the outer k-loop, so all of it needs to be maintained in registers (or at least in SLM -- but I don't think SLM would work well; you lose L1 and anyway SLM is only 1/4 the size of register space). If the register pressure is too high, you can parallelize across workgroups in the V/O head dimension. (i.e. the n tile size of TiledMMAPV can be < headdim_vo).

Anyway, this is the same design as the previous FlashAttention implementation. We can try other designs that are more friendly for larger head dimensions in future PRs. For instance, the oneDNN implementation has a different vectorization and parallelization approach that is better at handling larger head dimensions.
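
For a rough sense of the scaling, an illustrative calculation (the tile shapes are assumptions, not the exact ones in this PR):

  32 q-rows x 128 v-columns x 4 B (fp32 accumulation) = 16 KiB of O accumulator per subgroup
  16 KiB / 16 work-items (sg_size 16)                 = 1 KiB of accumulator per work-item

Doubling headdim_vo to 256 doubles that footprint, which is where splitting the V/O head dimension across workgroups (n tile of TiledMMAPV < headdim_vo) starts to pay off.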

ClarkChin08 commented Oct 23, 2025

The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256

output [991]: 2.696791 vs -nan

./examples/06_bmg_flash_attention/06_xe_fmha_fwd_decode_hdim128 --iterations=10 --batch=1 --num_heads_q=8 --seq_len_kv=256 --seq_len_qo=1 --num_heads_kv=8

However, when seq_len_kv is changed to 512 or higher, the example passes successfully.

petercad (Author) commented Oct 23, 2025

The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256

@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?

petercad force-pushed the petercad/rearch_sdpa branch from af2f402 to 326669e on October 23, 2025 03:54
ClarkChin08 replied:
The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256

@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?

Yes, passed now.

{
  static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");

  auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages

Review comment:

Why not use template parameters for the XE_1D_STSM source/destination data types? If you only load a single bf16/fp16, there will be an error.
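
A minimal sketch of that suggestion (assuming XE_1D_STSM can be instantiated for 16-bit element types and that SrcEngine::value_type names the fragment's element; not code from this PR):

  using SrcType = typename SrcEngine::value_type;
  auto atom_r2s = Copy_Atom<XE_1D_STSM<SrcType>, SrcType>{}; // match the fragment's element type instead of hard-coding float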


auto atom_shape = make_shape(_1{}, size(SrcLayout{}));

auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}), make_layout(atom_shape, Stride<_0, _16>{}));

Review comment:

Does _16 mean the subgroup size must be 16?
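
If a subgroup size of 16 is indeed the intent, one option is to spell it with the same constant used elsewhere in this PR (a sketch, assuming intel::sg_size is a compile-time constant):

  auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}),
                               make_layout(atom_shape, Stride<_0, Int<intel::sg_size>>{}));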


CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA.size(); i++)
  rA(i) *= broadcast<0>(rA_sum, rA, i);

Review comment:

Why not broadcast according to rA_sum.size()? Something like:

loop i over rows:
  rA_sum_recip(i) = ElementA(1) / rA_sum(i);   // one reciprocal per row
  loop j over columns:
    rA(i, j) = rA(i, j) * rA_sum_recip(i);     // broadcast the reciprocal along the row
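
A minimal sketch of that suggestion, reusing the names from the snippet above and assuming the row sums in rA_sum are not needed afterwards (not code from this PR):

  // Compute each row's reciprocal once...
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < rA_sum.size(); i++)
    rA_sum(i) = ElementA(1) / rA_sum(i);

  // ...then scale the accumulator, broadcasting each row's reciprocal along the row.
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < rA.size(); i++)
    rA(i) *= broadcast<0>(rA_sum, rA, i);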

}

clear(tArA);
fill(tA_max, ElementA(-1000.0f));

Review comment:

Why not fill with -inf?


/* K prefetch */
for (int D = 0; D < size<4>(pKgK); D++) {
  prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D));

Review comment:

Is this prefetch out of bounds when K + Stages runs past the end of K?
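
A minimal sketch of a guard for that case (the name k_blocks for the k-block count is illustrative; not code from this PR):

  if (K + Stages < k_blocks) {
    for (int D = 0; D < size<4>(pKgK); D++)
      prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D));
  }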


auto &p = params.kernel;
ProblemShape const& s = p.shape;
int head_group_q = s.num_heads_q / s.num_heads_kv;

Review comment:

I think we should use template parameters to define head_group_q.


In GQA calculations, the amount of work each work group computes should be determined at compile time.
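
A minimal sketch of making the GQA group size a compile-time quantity (the parameter name HeadGroupQ and the check are illustrative; not code from this PR):

  template <int HeadGroupQ, class ProblemShape>
  CUTLASS_DEVICE void check_gqa_group(ProblemShape const& s) {
    static_assert(HeadGroupQ >= 1, "GQA group size must be positive");
    // The runtime shape must agree with the compile-time group size.
    assert(s.num_heads_q == HeadGroupQ * s.num_heads_kv);
  }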
