From 1f0d33aeb5c00ed1cbf5aa76bbc7b1e74e882935 Mon Sep 17 00:00:00 2001 From: kareem Date: Thu, 25 Sep 2025 07:13:27 +0300 Subject: [PATCH 1/4] Attention sink support - Support for single sink logit in flash attention Decode - Add Sink to Softmax - Cmd line flag added to enable attention sink Signed-off-by: kareem --- .../collective/xe_flash_attn_decode_mma.hpp | 22 +++++---- .../xe_flash_attn_decode_softmax_epilogue.hpp | 16 +++++-- .../kernel/xe_flash_attn_decode.hpp | 13 +++++- .../06_bmg_decode_attention.cpp | 10 ++++- .../bmg_flash_attn_decode_runner.hpp | 45 ++++++++++++++++--- 5 files changed, 83 insertions(+), 23 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_decode_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_decode_mma.hpp index 6f83b63b1b..8eed85cdc2 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_decode_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_decode_mma.hpp @@ -60,8 +60,8 @@ CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { ///////////////////////////////////////////////////////////////////////////////////////////////// template + class ElementV_, class StrideV_, class ElementSink_, class MMAOp_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, + class GmemTiledCopyK_, class GmemTiledCopyV_, bool CausalMask_, bool PagedKV_, bool HasSink_> struct FlashDecodeMma { static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; @@ -69,11 +69,11 @@ struct FlashDecodeMma { ///////////////////////////////////////////////////////////////////////////////////////////////// template + class ElementV_, class StrideV_, class ElementSink_, class MMAOp_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, + class GmemTiledCopyQ_, class GmemTiledCopyK_, class GmemTiledCopyV_, bool CausalMask_, bool PagedKV_, bool HasSink_> struct FlashDecodeMma, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_, - StrideV_, MMAOp_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, - GmemTiledCopyV_, CausalMask_, PagedKV_> { + StrideV_, ElementSink_, MMAOp_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, + GmemTiledCopyV_, CausalMask_, PagedKV_, HasSink_> { // // Type Aliases // @@ -88,6 +88,7 @@ struct FlashDecodeMma, ProblemShapeType_, Ele using StrideK = StrideK_; using ElementV = ElementV_; using StrideV = StrideV_; + using ElementSink = ElementSink_; using GmemTiledCopyQ = GmemTiledCopyQ_; using GmemTiledCopyK = GmemTiledCopyK_; using GmemTiledCopyV = GmemTiledCopyV_; @@ -95,6 +96,7 @@ struct FlashDecodeMma, ProblemShapeType_, Ele static constexpr bool CausalMask = CausalMask_; static constexpr bool PagedKV = PagedKV_; + static constexpr bool HasSink = HasSink_; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; using MmaAtom = MMA_Atom; @@ -174,6 +176,8 @@ struct FlashDecodeMma, ProblemShapeType_, Ele int const* ptr_page_table; int page_size; int const* num_pages_per_seq; + // attention sink + ElementSink const* ptr_Sink; }; struct Params { @@ -186,6 +190,8 @@ struct FlashDecodeMma, ProblemShapeType_, Ele int const* ptr_page_table; int page_size; int const* num_pages_per_seq; + // attention sink + ElementSink const* ptr_Sink; }; // @@ -212,7 +218,7 @@ struct FlashDecodeMma, ProblemShapeType_, Ele XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return 
Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, args.ptr_page_table, args.page_size, args.num_pages_per_seq}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, args.ptr_page_table, args.page_size, args.num_pages_per_seq, args.ptr_Sink}; } template @@ -430,7 +436,7 @@ struct FlashDecodeMma, ProblemShapeType_, Ele XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, params.ptr_page_table, params.page_size, params.num_pages_per_seq}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, params.ptr_page_table, params.page_size, params.num_pages_per_seq, params.ptr_Sink}; } } }; diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp index 60d0ad4b88..bb615b02c5 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp @@ -50,13 +50,13 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template class FlashDecodeSoftmaxEpilogue { +template class FlashDecodeSoftmaxEpilogue { static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); }; -template -class FlashDecodeSoftmaxEpilogue { +template +class FlashDecodeSoftmaxEpilogue { public: // @@ -66,6 +66,7 @@ class FlashDecodeSoftmaxEpilogue using Element = Element_; static constexpr bool CausalMask = CausalMask_; + static constexpr bool HasSink = HasSink_; using GmemTiledCopyOut = void; @@ -149,7 +150,7 @@ class FlashDecodeSoftmaxEpilogue } template - CUTLASS_DEVICE void operator()(bool is_first, FragAcc &frag_s, Element& max_val, FragSum& sum, + CUTLASS_DEVICE void operator()(bool is_first, FragAcc &frag_s, Element& max_val, FragSum& sum, FragSum& sink_token, STensorMax& shmem_tensor_max, FragOut& out) { using FragAccLayout = typename FragAcc::layout_type; using FragOutLayout = typename FragOut::layout_type; @@ -162,6 +163,13 @@ class FlashDecodeSoftmaxEpilogue reduce_max(frag_s, shmem_tensor_max, max_val); + if constexpr (HasSink) { + if (syclcompat::get_nd_item<3>().get_local_linear_id() == 0) { + Element max_scale{max_val * params.scale}; + sum += sycl::native::exp2((sink_token- max_scale)); + } + } + if (!is_first) { auto sg = compat::get_nd_item<1>().get_sub_group(); const int sg_group_id = sg.get_group_id()[0]; diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp index 035e82584a..6bb44d38ef 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp @@ -69,6 +69,7 @@ class FMHADecode { using StrideK = typename CollectiveMainloop::StrideK; using ElementV = typename CollectiveMainloop::ElementV; using StrideV = typename CollectiveMainloop::StrideV; + using ElementSink = typename CollectiveMainloop::ElementSink; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -101,6 +102,7 @@ class FMHADecode { static constexpr bool CausalMask = CollectiveMainloop::CausalMask; static constexpr bool PagedKV = 
CollectiveMainloop::PagedKV; + static constexpr bool HasSink = CollectiveMainloop::HasSink; static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 @@ -340,7 +342,14 @@ class FMHADecode { CollectiveMainloop collective_mma; ElementAccumulator max_reg = ElementAccumulator{-INFINITY}; + ElementAccumulator sink_token = ElementAccumulator{ 0 }; auto sum_reg = ElementAccumulator{0}; + if constexpr (HasSink) { + if (syclcompat::get_nd_item<3>().get_local_linear_id() == 0) { + max_reg = static_cast(mainloop_params.ptr_Sink[num_heads_coord]); + sink_token = max_reg; + } + } Tensor out_reg = make_tensor(AccumShape{}); clear(out_reg); @@ -391,7 +400,7 @@ class FMHADecode { } CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax.template operator()(split == 0, tSr, max_reg, sum_reg, shmem_max_tensor, out_reg); + softmax.template operator()(split == 0, tSr, max_reg, sum_reg, sink_token, shmem_max_tensor, out_reg); collective_mma.template mmaPV(out_reg, tSr, gV, out_reg, mainloop_params, is_KV_cache, curr_kv_tile_idx); @@ -455,7 +464,7 @@ class FMHADecode { } CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax.template operator()((kv_splits - 1) == 0, tSr, max_reg, sum_reg, shmem_max_tensor, out_reg); + softmax.template operator()((kv_splits - 1) == 0, tSr, max_reg, sum_reg, sink_token, shmem_max_tensor, out_reg); collective_mma.template mmaPV(out_reg, tSr, gV, out_reg, mainloop_params, false, curr_kv_tile_idx); diff --git a/examples/06_bmg_flash_attention/06_bmg_decode_attention.cpp b/examples/06_bmg_flash_attention/06_bmg_decode_attention.cpp index 0a92f1ca72..c129b1d867 100644 --- a/examples/06_bmg_flash_attention/06_bmg_decode_attention.cpp +++ b/examples/06_bmg_flash_attention/06_bmg_decode_attention.cpp @@ -69,8 +69,14 @@ int run_decode(Options const& options) { #endif - return options.is_causal ? FMHAConfig::run(options) - : FMHAConfig::run(options); + if (options.is_causal) { + return options.use_sink_attn ? FMHAConfig::run(options) + : FMHAConfig::run(options); + } + else { + return options.use_sink_attn ? 
FMHAConfig::run(options) + : FMHAConfig::run(options); + } } int main(int argc, const char **argv) { diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp index 6bb7ad487f..3727410a9e 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp @@ -60,6 +60,7 @@ struct Options { bool help; bool error; bool is_causal; + bool use_sink_attn; bool varlen = false; bool use_paged_kv = false; std::string scheduler; @@ -68,7 +69,7 @@ struct Options { float softmax_scale; Options() - : help(false), error(false), is_causal(false), varlen(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(1), head_size_qk(128), + : help(false), error(false), is_causal(false), use_sink_attn(false), varlen(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(1), head_size_qk(128), seq_len_kv(512), seq_len_kv_cache(0), page_size(128), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} // Parses the command line @@ -84,6 +85,10 @@ struct Options { is_causal = true; } + if (cmd.check_cmd_line_flag("use_sink_attn")) { + use_sink_attn = true; + } + if (cmd.check_cmd_line_flag("varlen")) { varlen = true; } @@ -120,6 +125,7 @@ struct Options { << "Options:\n\n" << " --help If specified, displays this usage statement\n\n" << " --is_causal Apply Causal Mask to the output of first Matmul\n" + << " --use_sink_attn Apply Attention Sink\n" << " --varlen Enable variable sequence length\n" << " --scheduler Only Individual Scheduler supported\n" << " --batch= Sets the Batch Size of the Multi-Head Self Attention module\n" @@ -155,6 +161,7 @@ template struct ExampleRunner { using ElementQ = typename FMHAKernel::ElementQ; using ElementK = typename FMHAKernel::ElementK; using ElementV = typename FMHAKernel::ElementV; + using ElementSink = typename FMHAKernel::ElementSink; using ElementAcc = typename FMHAKernel::ElementAccumulator; using CollectiveEpilogue = typename FMHAKernel::CollectiveEpilogue; @@ -181,6 +188,7 @@ template struct ExampleRunner { cutlass::DeviceAllocation block_Q; cutlass::DeviceAllocation block_K; cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_Sink; cutlass::DeviceAllocation block_K_cache; cutlass::DeviceAllocation block_V_cache; cutlass::DeviceAllocation block_O; @@ -216,7 +224,7 @@ template struct ExampleRunner { // // Methods // - bool verify(ProblemShapeType problem_size, bool is_causal, bool use_kv_cache) { + bool verify(ProblemShapeType problem_size, bool is_causal, bool use_kv_cache, bool sink_attn) { if constexpr (isVarLen) { int max_seq_len_q = static_cast(get<3>(problem_size)); @@ -240,6 +248,13 @@ template struct ExampleRunner { int offset_k_cache = 0; int offset_v_cache = 0; int offset_o = 0; + std::vector host_Sink; + + if (sink_attn) { + host_Sink.resize(block_Sink.size()); + syclcompat::memcpy(host_Sink.data(), block_Sink.get(), host_Sink.size()); + syclcompat::wait(); + } int q_group_size = num_heads_q / num_heads_kv; // loop over the batch dimension to compute the output @@ -352,6 +367,11 @@ template struct ExampleRunner { if (max_vec[max_idx] < host_S[idx]) max_vec[max_idx] = host_S[idx]; } + if (sink_attn) { + ElementAccumulator sink_val = static_cast(host_Sink[h]); + if (max_vec[max_idx] < sink_val) + max_vec[max_idx] = sink_val; + } } // compute exp of S @@ -372,6 +392,12 @@ template struct 
ExampleRunner { sum_vec[sum_idx] += host_S[idx]; } + if (sink_attn) { + ElementAccumulator sink_val = static_cast(host_Sink[h]); + auto exp_sink = expf((sink_val - max_vec[row]) / sqrt(static_cast((head_size_qk)))); + sum_vec[sum_idx] += exp_sink; + } + // scale each row with the sum to compute softmax idx = row * seq_len_kv_total; sum_idx = row; @@ -553,6 +579,7 @@ template struct ExampleRunner { block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); + block_Sink.reset(num_heads_kv); block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); @@ -592,6 +619,7 @@ template struct ExampleRunner { initialize_block(block_Q, seed + 2021); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2023); + initialize_block(block_Sink, seed + 2021); initialize_block(block_K_cache, seed + 2024); initialize_block(block_V_cache, seed + 2025); @@ -664,7 +692,8 @@ template struct ExampleRunner { block_V_cache.get(), stride_V_cache, options.use_paged_kv ? paged_kv_cache.page_table.get() : nullptr, options.use_paged_kv ? paged_kv_cache.page_size : 0, - options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr}, + options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr, + block_Sink.get()}, {options.softmax_scale}, {block_O.get(), stride_O}, hw_info}; @@ -691,7 +720,7 @@ template struct ExampleRunner { // Verify that the result is correct bool use_kv_cache = options.seq_len_kv_cache > 0; - bool passed = verify(problem_size, options.is_causal, use_kv_cache); + bool passed = verify(problem_size, options.is_causal, use_kv_cache, options.use_sink_attn); std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; if (!passed) { @@ -742,9 +771,11 @@ template , ElementOutput, GmemTiledCopyStore>; - using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashDecodeSoftmaxEpilogue; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashDecodeSoftmaxEpilogue; using ProblemShapeRegular = cute::tuple; using namespace cutlass::fmha::collective; @@ -778,9 +809,9 @@ template , ElementInputKV, - cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, + cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, ElementInputSink, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, GmemTiledCopyQ/* Q */, GmemTiledCopyK/* K */, - GmemTiledCopyV/* V */, Causal, PagedKV>; + GmemTiledCopyV/* V */, Causal, PagedKV, hasSink>; using FMHAKernel = cutlass::flash_attention::kernel::FMHADecode; From 4f95e335076ebe1c650457daa51502730d6fcfb2 Mon Sep 17 00:00:00 2001 From: kareem Date: Mon, 13 Oct 2025 08:01:43 +0300 Subject: [PATCH 2/4] Fix Review comments Signed-off-by: kareem --- .../06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp index 3727410a9e..8a8c3df37b 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp @@ -394,7 +394,7 @@ template struct ExampleRunner { if (sink_attn) { ElementAccumulator sink_val = static_cast(host_Sink[h]); - auto exp_sink = expf((sink_val - max_vec[row]) / sqrt(static_cast((head_size_qk)))); + auto exp_sink = expf((sink_val - max_vec[row]); sum_vec[sum_idx] += exp_sink; } @@ -579,7 +579,7 @@ template struct ExampleRunner { block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); - block_Sink.reset(num_heads_kv); + block_Sink.reset(num_heads_q); block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); From 85f65073627e3dd2f6278717e45b953c81b353d3 Mon Sep 17 00:00:00 2001 From: kareem Date: Mon, 13 Oct 2025 09:12:39 +0300 Subject: [PATCH 3/4] fix compile errors Signed-off-by: kareem --- .../06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp index 8a8c3df37b..228f244e61 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp @@ -394,7 +394,7 @@ template struct ExampleRunner { if (sink_attn) { ElementAccumulator sink_val = static_cast(host_Sink[h]); - auto exp_sink = expf((sink_val - max_vec[row]); + auto exp_sink = expf(sink_val - max_vec[row]); sum_vec[sum_idx] += exp_sink; } From fcf6928bfb862688044f8e5338d28291bf4b44f4 Mon Sep 17 00:00:00 2001 From: kareem Date: Mon, 13 Oct 2025 11:19:42 +0300 Subject: [PATCH 4/4] Add Sink Attention unit test Signed-off-by: kareem --- .../xe_flash_attn_decode_softmax_epilogue.hpp | 2 +- .../kernel/xe_flash_attn_decode.hpp | 2 +- 
.../bmg_flash_attn_decode_runner.hpp | 4 +-- .../flash_decode_testbed_3x.hpp | 36 ++++++++++++++++--- ...decode_bf16_fp32_fp32_h64_512_nonpaged.cpp | 7 ++++ 5 files changed, 42 insertions(+), 9 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp index bb615b02c5..a7485a8b8e 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp @@ -164,7 +164,7 @@ class FlashDecodeSoftmaxEpilogue(frag_s, shmem_tensor_max, max_val); if constexpr (HasSink) { - if (syclcompat::get_nd_item<3>().get_local_linear_id() == 0) { + if (compat::get_nd_item<3>().get_local_linear_id() == 0) { Element max_scale{max_val * params.scale}; sum += sycl::native::exp2((sink_token- max_scale)); } diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp index 6bb44d38ef..b362e54957 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp @@ -345,7 +345,7 @@ class FMHADecode { ElementAccumulator sink_token = ElementAccumulator{ 0 }; auto sum_reg = ElementAccumulator{0}; if constexpr (HasSink) { - if (syclcompat::get_nd_item<3>().get_local_linear_id() == 0) { + if (compat::get_nd_item<3>().get_local_linear_id() == 0) { max_reg = static_cast(mainloop_params.ptr_Sink[num_heads_coord]); sink_token = max_reg; } diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp index 228f244e61..c82a7a973e 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp @@ -252,8 +252,8 @@ template struct ExampleRunner { if (sink_attn) { host_Sink.resize(block_Sink.size()); - syclcompat::memcpy(host_Sink.data(), block_Sink.get(), host_Sink.size()); - syclcompat::wait(); + compat::memcpy(host_Sink.data(), block_Sink.get(), host_Sink.size()); + compat::wait(); } int q_group_size = num_heads_q / num_heads_kv; diff --git a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp index 30a09d7a69..c61bda8d63 100644 --- a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp @@ -109,7 +109,7 @@ using GmemTiledCopyStoreU16 = cute::XE_2D_U16x1x16_ST_N; template + typename TiledCopyV, typename TiledCopyStore, bool PagedKV, bool hasSink = false> struct XE_Flash_Attention_Decode { using LayoutQ = cutlass::layout::RowMajor; using LayoutK = cutlass::layout::ColumnMajor; @@ -121,6 +121,7 @@ struct XE_Flash_Attention_Decode { using ElementInputQ = ElementInputType; using ElementInputKV = ElementInputType; using ElementOutput = ElementOutputType; + using ElementInputSink = ElementInputType; using ProblemShapeRegular = cute::tuple; using ProblemShapeVarlen = cute::tuple, ElementOutput, GmemTiledCopyStore>; using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashDecodeSoftmaxEpilogue< - HasCausalMask, EpilogueDispatchPolicy, ElementAccumulator>; + HasCausalMask, hasSink, EpilogueDispatchPolicy, ElementAccumulator>; // 
Mainloop using CollectiveMainloop = cutlass::flash_attention::collective::FlashDecodeMma< GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, ElementInputKV, - cutlass::gemm::TagToStrideB_t, MMAOperation, + cutlass::gemm::TagToStrideB_t, ElementInputSink, + MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, GmemTiledCopyQ, // Q GmemTiledCopyK, // K GmemTiledCopyV, // V, - HasCausalMask, PagedKV>; + HasCausalMask, PagedKV, hasSink>; using Kernel = cutlass::flash_attention::kernel::FMHADecode; @@ -177,6 +179,7 @@ struct TestbedImpl { using ElementQ = typename FlashDecode::ElementQ; using ElementK = typename FlashDecode::ElementK; using ElementV = typename FlashDecode::ElementV; + using ElementSink = typename FlashDecode::ElementQ; using ElementAcc = typename FlashDecode::ElementAccumulator; using CollectiveMainloop = typename FlashDecode::CollectiveMainloop; @@ -189,6 +192,7 @@ struct TestbedImpl { static constexpr bool HasCausalMask = CollectiveMainloop::CausalMask; static constexpr bool isVarLen = CollectiveMainloop::is_var_len; static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + static constexpr bool HasSink = CollectiveMainloop::HasSink; StrideQ stride_Q; StrideK stride_K; @@ -208,6 +212,7 @@ struct TestbedImpl { cutlass::DeviceAllocation block_Q; cutlass::DeviceAllocation block_K; cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_Sink; cutlass::DeviceAllocation block_K_cache; cutlass::DeviceAllocation block_V_cache; cutlass::DeviceAllocation block_O; @@ -261,6 +266,7 @@ struct TestbedImpl { block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); + block_Sink.reset(num_heads_q); block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); @@ -300,6 +306,7 @@ struct TestbedImpl { initialize_block(block_Q, seed + 2023); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2021); + initialize_block(block_Sink, seed + 2021); initialize_block(block_K_cache, seed + 2024); initialize_block(block_V_cache, seed + 2025); @@ -426,6 +433,13 @@ struct TestbedImpl { int offset_v_cache = 0; int offset_o = 0; + std::vector host_Sink; + if (HasSink) { + host_Sink.resize(block_Sink.size()); + compat::memcpy(host_Sink.data(), block_Sink.get(), host_Sink.size()); + compat::wait(); + } + int q_group_size = num_heads_q / num_heads_kv; // loop over the batch dimension to compute the output // to avoid the risk of running out of device memory @@ -537,6 +551,11 @@ struct TestbedImpl { if (max_vec[max_idx] < host_S[idx]) max_vec[max_idx] = host_S[idx]; } + if (HasSink) { + ElementAccumulator sink_val = static_cast(host_Sink[h]); + if (max_vec[max_idx] < sink_val) + max_vec[max_idx] = sink_val; + } } // compute exp of S @@ -557,6 +576,12 @@ struct TestbedImpl { sum_vec[sum_idx] += host_S[idx]; } + if (HasSink) { + ElementAccumulator sink_val = static_cast(host_Sink[h]); + auto exp_sink = expf(sink_val - max_vec[row]); + sum_vec[sum_idx] += exp_sink; + } + // scale each row with the sum to compute softmax idx = row * seq_len_kv_total; sum_idx = row; @@ -676,7 +701,8 @@ struct TestbedImpl { block_V_cache.get(), stride_V_cache, PagedKV ? paged_kv_cache.page_table.get() : nullptr, PagedKV ? 
paged_kv_cache.page_size : 0, - PagedKV ? paged_kv_cache.num_pages_per_seq.get() : nullptr}, + PagedKV ? paged_kv_cache.num_pages_per_seq.get() : nullptr, + block_Sink.get()}, {softmax_scale}, {block_O.get(), stride_O}, hw_info}; diff --git a/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp index 1dcd52cfc2..669f476e0d 100644 --- a/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp +++ b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp @@ -72,4 +72,11 @@ TEST(XE_Flash_Attention_Decode_bf16_fp32_fp32_NonPaged_KVTile512_h64, varlen_non EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll(64)); } +TEST(XE_Flash_Attention_Decode_bf16_fp32_fp32_NonPaged_KVTile512_h64, sink_attn) { + using Kernel = test::flash_attention::XE_Flash_Attention_Decode::Kernel; + EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll(64)); +} + } // namespace cutlass
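
Reviewer note (illustration only, not part of the patch series): across these commits the decode kernel seeds `max_reg` with the per-head sink logit from `ptr_Sink`, and the softmax epilogue adds `exp2(sink - max * scale)` to the running sum, so the sink participates in the row maximum and the normalization but contributes no value row. Below is a minimal host-side sketch of that normalization, assuming logits that are already scaled by the softmax scale and an unscaled sink logit, as in the patch's host reference; `softmax_with_sink` and the sample values are illustrative, not part of the change.

```cpp
// Softmax with an attention sink: the per-head sink logit joins the running
// max and the denominator but has no value vector, so it only attenuates the
// attention weights. Assumes the logits are already scaled (e.g. by 1/sqrt(d)).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> softmax_with_sink(const std::vector<float>& logits, float sink) {
  // Track the sink in the max for numerical stability, mirroring how the
  // kernel initializes max_reg from ptr_Sink[head] before the KV loop.
  float m = sink;
  for (float s : logits) m = std::max(m, s);

  // exp(sink - m) enters the denominator only; there is no V row for it.
  float denom = std::exp(sink - m);
  std::vector<float> p(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    p[i] = std::exp(logits[i] - m);
    denom += p[i];
  }
  for (float& v : p) v /= denom;
  return p;
}

int main() {
  std::vector<float> logits = {0.5f, 1.25f, -0.75f, 2.0f};
  auto weights = softmax_with_sink(logits, /*sink=*/1.0f);
  float total = 0.f;
  for (float w : weights) { std::printf("%f ", w); total += w; }
  // The weights sum to less than 1; the missing mass exp(sink - m) / denom
  // is the attention absorbed by the sink.
  std::printf("\nsum = %f\n", total);
  return 0;
}
```

To exercise the new path with the example runner, pass the flag added by this series, e.g. `06_bmg_decode_attention --use_sink_attn` (the exact binary path depends on the build layout); without the flag the existing non-sink kernels are instantiated as before.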