Re-implement FlashAttention with new Xe atoms #547
Conversation
I will break up this large commit into self-contained smaller commits after review is complete.
Why is this here? This isn't FlashAttention-specific, is it?
No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).
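For reference, a minimal sketch of what a more generalized version of such a helper might look like. The name `store_to_slm` and the plain-copy fallback are placeholders, not code from this PR; the `static_assert` mirrors the one in the epilogue helper quoted further down.

```c++
#include <cute/tensor.hpp>

// Hypothetical generalized rmem->smem copy helper (sketch only). A production
// version would select the widest block message the layouts allow and fall
// back to an element-wise scatter otherwise.
template <class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void
store_to_slm(cute::Tensor<SrcEngine, SrcLayout> const& src,
             cute::Tensor<DstEngine, DstLayout>      & dst)
{
  static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");

  // Fallback path: plain element-wise copy through cute::copy.
  cute::copy(src, dst);
}
```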
```c++
FragSRow k_rem_mask;
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
  k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
```
If the original S already contains NaN, then fmin(NaN, NaN) = NaN and the NaN propagates into the softmax. This can corrupt the row-wise sum and max, leading to NaN in the final output O. Could we have a better k_rem_mask here to avoid this case?
@ClarkChin08 Can you explain your concern a bit more? If the original S has a NaN value in bounds, then that indicates either an overflow from very badly scaled data or an inf/NaN input, and there's no safe way to numerically recover from that (we can't easily guess what the right value should be in place of that NaN). If S has a NaN value out of bounds, then the fmin with -inf will produce -inf, so the NaN will be removed and will not corrupt the softmax value.
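For concreteness, a tiny standalone check of the fmin semantics being relied on here; this only illustrates IEEE fmin behaviour, not the kernel's actual masking code:

```c++
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  float nan  = std::numeric_limits<float>::quiet_NaN();
  float ninf = -std::numeric_limits<float>::infinity();

  assert(std::fmin(1.5f, nan)  == 1.5f);    // in-bounds mask entry (NaN): S is left untouched
  assert(std::fmin(1.5f, ninf) == ninf);    // out-of-bounds entry: S is forced to -inf
  assert(std::isnan(std::fmin(nan, nan)));  // a NaN already present in S still propagates
  return 0;
}
```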
I agree that if NaN appears in the valid range of S, it's likely a symptom of upstream issues like bad scaling or invalid inputs, and trying to "fix" it in the kernel can be tricky, especially in low-precision formats like fp8/fp4 where overflows are common.
Perhaps adding an optional debug mode to scan for NaNs/invalid inputs in S before softmax could help users identify issues early.
I see, yes, that could be helpful. Perhaps this could take the form of a general helper function that scans for NaNs in an arbitrary tensor and aborts if any are found.
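A minimal sketch of such a helper, assuming a cute fragment and an assert-based abort; the name `assert_no_nan` and the trap mechanism are placeholders, not part of this PR:

```c++
#include <cutlass/cutlass.h>
#include <cute/tensor.hpp>
#include <cassert>

// Hypothetical debug-only helper: scan a fragment for NaNs and abort if any
// are found. Intended to be compiled out of release builds.
template <class Engine, class Layout>
CUTLASS_DEVICE void assert_no_nan(cute::Tensor<Engine, Layout> const& t)
{
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < cute::size(t); i++) {
    bool is_nan = (t(i) != t(i));  // NaN is the only value that compares unequal to itself
    assert(!is_nan && "NaN detected in tensor");
  }
}
```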
```c++
using ElementS = typename TiledMMAQK::ValTypeD;

using SingleFragA = FragC<TiledMMAPV>;                       // (atom val,q',v')
using FragA = expand_sg_fragment_t<SingleFragA, 1, VTiles>;  // (atom val,q',v',VV)
```
FragA's size scales with q' * v' * VTiles * sizeof(ElementA), which causes high register pressure, especially for large head_dim (>128), leading to register spills to SLM and performance degradation. Maybe we could have a more register-friendly design?
@ClarkChin08 Thanks for the comments. A is the accumulator for the outer k-loop, so all of it needs to be maintained in registers (or at least in SLM, but I don't think SLM would work well: you lose L1, and SLM is only 1/4 the size of the register space). If the register pressure is too high, you can parallelize across workgroups in the V/O head dimension (i.e., the n tile size of TiledMMAPV can be < headdim_vo).
Anyway, this is the same design as the previous FlashAttention implementation. We can try other designs that are more friendly for larger head dimensions in future PRs. For instance, the oneDNN implementation has a different vectorization and parallelization approach that is better at handling larger head dimensions.
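As a rough illustration of the scaling in question; every number below is hypothetical and not taken from this PR's tile configuration:

```c++
#include <cstdio>

int main() {
  // Per-work-item accumulator footprint when the whole O tile
  // (q' rows x head_dim_vo columns) is kept in registers.
  constexpr int q_rows_per_sg = 8;    // assumed q' rows owned by one sub-group
  constexpr int head_dim_vo   = 256;  // the "large head_dim" case from the comment
  constexpr int sg_size       = 16;   // Xe sub-group size
  constexpr int acc_bytes     = 4;    // fp32 accumulator element

  constexpr int bytes_per_wi = q_rows_per_sg * head_dim_vo * acc_bytes / sg_size;
  std::printf("accumulator bytes per work-item: %d\n", bytes_per_wi);  // prints 512

  // Splitting the head dimension across work-groups (n tile < headdim_vo)
  // divides this footprint accordingly.
  return 0;
}
```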
The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256. However, when seq_len_kv is changed to 512 or higher, the example passes successfully.
@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?
Force-pushed from af2f402 to 326669e.
Yes, passed now.
```c++
{
  static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");

  auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{};   // TODO: larger block messages
```
Why not use template parameters for the XE_1D_STSM Sdtype/Ddtype? If you only load one bf16/fp16 element, there will be an error.
```c++
auto atom_shape = make_shape(_1{}, size(SrcLayout{}));

auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}), make_layout(atom_shape, Stride<_0, _16>{}));
```
Does the _16 mean the subgroup size must be 16?
```c++
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA.size(); i++)
  rA(i) *= broadcast<0>(rA_sum, rA, i);
```
Why not broadcast according to rA_sum.size()? For example (pseudocode):

loop i over rows:
    rA_sum_recip(i) = ElementA(1) / rA_sum(i);
    loop j over columns:
        rA(i, j) = rA(i, j) * rA_sum_recip(i);
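A plain C++ sketch of the suggested pattern, one division per row followed by cheap multiplies across the columns; the array types and names are illustrative only:

```c++
#include <array>

// Scale each row of A by the reciprocal of its row sum, computing the
// (expensive) division once per row instead of once per element.
template <int Rows, int Cols>
void scale_rows(std::array<std::array<float, Cols>, Rows>& A,
                std::array<float, Rows> const& row_sum)
{
  for (int i = 0; i < Rows; i++) {
    float recip = 1.0f / row_sum[i];   // one division per row
    for (int j = 0; j < Cols; j++)
      A[i][j] *= recip;                // multiply per element
  }
}
```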
```c++
}

clear(tArA);
fill(tA_max, ElementA(-1000.0f));
```
Why not fill with -inf?
```c++
/* K prefetch */
for (int D = 0; D < size<4>(pKgK); D++) {
  prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D));
```
Could this prefetch go out of bounds?
```c++
auto &p = params.kernel;
ProblemShape const& s = p.shape;
int head_group_q = s.num_heads_q / s.num_heads_kv;
```
I think we should use template parameters to define head_group_q.
In GQA calculations, the size of the computation assigned to each work group should be determined at compile time.
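For context, head_group_q is the number of query heads that share one KV head; a small worked example with illustrative head counts (not taken from this PR's examples):

```c++
#include <cassert>

int main() {
  // Illustrative GQA configuration.
  int num_heads_q  = 32;
  int num_heads_kv = 8;

  int head_group_q = num_heads_q / num_heads_kv;  // 4 query heads per KV head
  assert(head_group_q == 4);

  // A query head h typically maps to KV head h / head_group_q.
  int q_head = 13;
  assert(q_head / head_group_q == 3);  // query head 13 reads KV head 3
  return 0;
}
```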
This PR updates FlashAttention to the new copy/MMA atoms.
Changes:
Current status: prefill/decode examples working, with similar or better performance than the old examples.
Known issues:
Additional features (causal masking, variable sequence lengths, etc.) to be added later.
Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.