@@ -38,6 +38,7 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "fmha_fusion.hpp"
#include "xe_rotary.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -61,7 +62,7 @@ CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {

template <class DispatchPolicy, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_ = false>
struct FlashPrefillMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};
@@ -70,9 +71,9 @@ struct FlashPrefillMma {

template <int Stages, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_>
struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_,
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> {
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, RopeMask_> {
//
// Type Aliases
//
@@ -96,6 +97,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
using TiledMmaPV = typename TiledMMAHelper<MmaAtom, Layout<TileShapePV>, SubgroupLayout>::TiledMMA;
using ElementAccumulator = typename TiledMmaQK::ValTypeC;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool rope_enabled = RopeMask_;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using MmaAtomShape = typename MmaAtom::Shape_MNK;
@@ -157,12 +159,19 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
StrideK dK;
ElementV const *ptr_V;
StrideV dV;
// RoPE: pointers to precomputed rotary cos/sin tables (left as nullptr when RoPE is not used)
ElementQ const *ptr_cos = nullptr;
ElementQ const *ptr_sin = nullptr;
};

struct Params {
XE_Copy_Q gmem_tiled_copy_q;
XE_Copy_K gmem_tiled_copy_k;
XE_Copy_V gmem_tiled_copy_v;
XE_Copy_Q gmem_tiled_copy_q_cos;
XE_Copy_Q gmem_tiled_copy_q_sin;
XE_Copy_K gmem_tiled_copy_k_cos;
XE_Copy_K gmem_tiled_copy_k_sin;
};

//
@@ -180,18 +189,87 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv), args.dV));

auto tensorCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};

XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorSin)};

XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorSin)};

return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
template <class FragQccum, class TensorQ, class TensorK, class FragSrc, class TensorCosQ, class TensorSinQ, class TensorCosK, class TensorSinK>
CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, FragSrc const &frag_src,
int const &k_tile_count, Params const &params) {
int const &k_tile_count, TensorCosQ gQCos, TensorSinQ gQSin, TensorCosK gKCos, TensorSinK gKSin, Params const &params, ProblemShapeType const &problem_shape) {


int thread_idx = static_cast<int>(ThreadIdxX());
if constexpr (rope_enabled) {
// calculate the base_ptr and offset for Q, K.
// also calculate the layout for Q, K.
// then apply RoPE on Q, K accordingly
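// Assumed semantics of apply_rope_interleaved_gmem (defined in xe_rotary.h, not shown in this diff):
// each adjacent element pair of a row is rotated in place by that pair's angle, i.e.
//   x[2i]   <- x[2i] * cos - x[2i+1] * sin
//   x[2i+1] <- x[2i] * sin + x[2i+1] * cos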
auto [coord_q_x, coord_q_y, coord_q_z] = *gQ.data();
auto [coord_k_x, coord_k_y, coord_k_z] = *gK.data();

auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;

int offset_q = seq_len_qo*head_size_qk*coord_q_z + head_size_qk*coord_q_x + coord_q_y; // row major
// int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_y + coord_k_x; // col major
int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_x + coord_k_y; // row major

auto q_traits = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q);
ElementQ* base_ptr_q = (ElementQ*)q_traits.base_ptr;

auto q_traits_cos = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

// auto layout_q = gQ.layout();
constexpr auto static_shape_q = make_shape(size<0>(gQ), size<1>(gQ));
// constexpr auto layout_q = LayoutQ::packed({size<0>(gQ), size<1>(gQ)});
constexpr auto layout_q = make_layout(static_shape_q, LayoutRight{});

auto k_traits = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k);
ElementK* base_ptr_k = (ElementK*)k_traits.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

// auto layout_k = gK.layout();
constexpr auto static_shape_k = make_shape(size<0>(gK), size<1>(gK));
constexpr auto layout_k = make_layout(static_shape_k, LayoutRight{});

for (int i = 0; i < size<2>(gQ) && thread_idx < size<0>(gQ); i++) {
auto tensorQ = make_tensor(make_gmem_ptr(base_ptr_q+offset_q), layout_q);
auto tensorCosQ = make_tensor(make_gmem_ptr(base_ptr_q_cos+offset_q), layout_q);
auto tensorSinQ = make_tensor(make_gmem_ptr(base_ptr_q_sin+offset_q), layout_q);
cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorQ, tensorCosQ, tensorSinQ, tensorQ);
offset_q += QK_BLK_M*QK_BLK_K;
}

for (int i = 0; i < size<2>(gK) && thread_idx < size<0>(gK); i++) {
auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k+offset_k), layout_k);
auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos+offset_k), layout_k);
auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin+offset_k), layout_k);
cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
offset_k += QK_BLK_N*QK_BLK_K;
}
}


auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx);
auto thr_copy_K = params.gmem_tiled_copy_k.get_slice(thread_idx);
// Instantiate the MMA object
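apply_rope_interleaved_gmem comes from the newly included xe_rotary.h, which is not part of this diff. As a rough sketch only (illustrative name and signature, not the actual helper), an in-place interleaved rotary update over one row of a row-major (rows x head_dim) tile, assuming the cos/sin tables are laid out like the tile with the angle duplicated per pair, looks like:

// Illustrative sketch only -- not the actual xe_rotary.h implementation.
template <class T>
void rope_interleaved_row_sketch(T *row, T const *cos_row, T const *sin_row, int head_dim) {
  for (int d = 0; d + 1 < head_dim; d += 2) {
    T x0 = row[d];
    T x1 = row[d + 1];
    T c = cos_row[d];
    T s = sin_row[d];
    row[d]     = x0 * c - x1 * s;  // x'[2i]   = x[2i]*cos - x[2i+1]*sin
    row[d + 1] = x0 * s + x1 * c;  // x'[2i+1] = x[2i]*sin + x[2i+1]*cos
  }
}

In the mainloop above, this per-pair rotation is applied tile by tile to Q and K in global memory before the Q*K GEMM, using the cos/sin base pointers recovered from the copy traits.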
43 changes: 41 additions & 2 deletions applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
@@ -132,6 +132,7 @@ class FMHAPrefill {
using AccumeShape = decltype(make_shape(Int<Vec>{}, Int<FragsM>{}, get<1>(TileShapePV{})/get<1>(MmaAtomShape()), Int<VSlicer>{}));

static constexpr bool is_var_len = CollectiveMainloop::is_var_len;
static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled;

// Kernel level shared memory storage
struct SharedStorage {
@@ -272,10 +273,24 @@ class FMHAPrefill {
Tensor mK_nk = mK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
Tensor mV_nk = mV_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)

Tensor mCosQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l)
Tensor mSinQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l)
Tensor mCosK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_head_kv)); // (n, k, l)
Tensor mSinK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_head_kv)); // (n, k, l)
Tensor mCosQ_mk = mCosQ_mkl(_, _, blk_l_coord); // (m,k)
Tensor mSinQ_mk = mSinQ_mkl(_, _, blk_l_coord); // (m,k)
Tensor mCosK_nk = mCosK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
Tensor mSinK_nk = mSinK_nkl(_, _, blk_l_coord/group_heads_q);

auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
auto gV = local_tile(mV_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});

auto gCosQ = local_tile(mCosQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gSinQ = local_tile(mSinQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gCosK = local_tile(mCosK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
auto gSinK = local_tile(mSinK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});

auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, sequence_length_shape, batch_coord);
// we limit the horizontal size to two subgroups; empirical results show that reading two cachelines side by side gives better performance and
// anything beyond that has no effect on performance. // (64 here for float b float when possible and loop over to cover all the data needed)
@@ -289,6 +304,12 @@ class FMHAPrefill {
auto pKgK = thr_prefetch_K.partition_S(gK);
auto pVgV = thr_prefetch_V.partition_S(gV);

// Prefetch partitions for the RoPE cos/sin coordinate tensors
auto pCosQgCosQ = thr_prefetch_Q.partition_S(gCosQ);
auto pSinQgSinQ = thr_prefetch_Q.partition_S(gSinQ);
auto pCosKgCosK = thr_prefetch_K.partition_S(gCosK);
auto pSinKgSinK = thr_prefetch_K.partition_S(gSinK);

for (int i = 0; i < size<3>(pQgQ); i++) {
prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
}
@@ -299,6 +320,18 @@ class FMHAPrefill {
}
}

for (int i = 0; i < size<3>(pQgQ); i++) {
prefetch(tiled_prefetch_q, pCosQgCosQ(_, _, _, i));
prefetch(tiled_prefetch_q, pSinQgSinQ(_, _, _, i));
}
for (int j = 0; j < size<4>(pKgK); j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < DispatchPolicy::Stages; i++) {
prefetch(tiled_prefetch_k, pCosKgCosK(_, _, _ , i, j));
prefetch(tiled_prefetch_k, pSinKgSinK(_, _, _ , i, j));
}
}

// Allocate the tiled_mma and the accumulators for the (M,N) workgroup_shape
Tensor out_reg = make_tensor<ElementAccumulator>(AccumeShape{});

@@ -325,7 +358,7 @@ class FMHAPrefill {
clear(tSr);

// 3) Perform GEMM S = Q*K
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params);
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock, _), tSr, ceil_div(head_size_qk, QK_BLK_K), gCosQ, gSinQ, gCosK(_, _, nblock, _), gSinK(_, _, nblock, _), mainloop_params, params.problem_shape);

// we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
// prefetching it the same way as cutlass K matrix does not make sense
@@ -343,6 +376,12 @@ class FMHAPrefill {
for (int j = 0; j < size<4>(pKgK); j++) {
prefetch(tiled_prefetch_k, pKgK(_, _, _, nblock + DispatchPolicy::Stages, j));
}

for (int j = 0; j < size<4>(pKgK); j++) {
prefetch(tiled_prefetch_k, pCosKgCosK(_, _, _, nblock + DispatchPolicy::Stages, j));
prefetch(tiled_prefetch_k, pSinKgSinK(_, _, _, nblock + DispatchPolicy::Stages, j));
}

barrier_wait(barrier_scope);
}

@@ -351,7 +390,7 @@ class FMHAPrefill {
Tensor tSr = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
clear(tSr);
// 3) Perform GEMM S = Q*K
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_limit - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params);
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_limit - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), gCosQ, gSinQ, gCosK(_, _, nblock_limit - 1, _), gSinK(_, _, nblock_limit - 1, _), mainloop_params, params.problem_shape);
// we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
// prefetching it the same way as cutlass K matrix does not make sense
for(int i=0; i< size<1>(pVgV); i++) {