#include "flash_attention_v2/collective/fmha_fusion.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_prefill_epilogue.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_prefill_softmax_epilogue.hpp"
#include "flash_attention_v2/kernel/tile_scheduler.hpp"
#include "flash_attention_v2/kernel/xe_flash_attn_prefill.hpp"

#include "cutlass/gemm/dispatch_policy.hpp"

#include <exception>
#include <iostream>

////////////////////////////////////////////////////////////////////////////////
// PRIVATE FUNCTION
////////////////////////////////////////////////////////////////////////////////

| 17 | +template <typename FMHA> static auto run(typename FMHA::Params params) -> void { |
| 18 | + cute::dim3 const block = FMHA::get_block_shape(); |
| 19 | + cute::dim3 const grid = FMHA::get_grid_shape(params); |
| 20 | + |
| 21 | + int smem_size = FMHA::SharedStorageSize; |
| 22 | + |
| 23 | + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); |
| 24 | + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); |
| 25 | + |
| 26 | +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) |
| 27 | + using namespace syclcompat::experimental; |
| 28 | + auto event = launch<cutlass::device_kernel<FMHA>>( |
| 29 | + launch_policy{ |
| 30 | + sycl_grid, sycl_block, |
| 31 | + local_mem_size{static_cast<std::size_t>(smem_size)}, |
| 32 | + kernel_properties{ |
| 33 | + sycl_exp::sub_group_size<FMHA::DispatchPolicy::SubgroupSize>}}, |
| 34 | + params); |
| 35 | +#else |
| 36 | + syclcompat::experimental::launch_properties launch_props{ |
| 37 | + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), |
| 38 | + }; |
| 39 | + syclcompat::experimental::kernel_properties kernel_props{ |
| 40 | + sycl::ext::oneapi::experimental::sub_group_size< |
| 41 | + FMHA::DispatchPolicy::SubgroupSize>}; |
| 42 | + syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, |
| 43 | + launch_props, kernel_props}; |
| 44 | + auto event = syclcompat::experimental::launch<cutlass::device_kernel<FMHA>>( |
| 45 | + policy, params); |
| 46 | +#endif |
| 47 | + |
| 48 | + EventManager::getInstance().addEvent(event); |
| 49 | +} |
| 50 | + |
| 51 | +template <bool Causal, typename TileShapeQK, typename TileShapePV, |
| 52 | + typename TileShapeOutput, typename SubgroupLayout, int PipelineStages> |
| 53 | +static auto attention_run(const at::Tensor &Q, const at::Tensor &K, |
| 54 | + const at::Tensor &V, at::Tensor &O, int Batch, |
| 55 | + int NumHeadsQ, int NumHeadsKV, int SeqLengthQO, |
| 56 | + int SeqLengthKV, int HeadSizeQK, int HeadSizeVO, |
| 57 | + float sm_scale) -> int { |
| 58 | + RECORD_FUNCTION("cutlass fa", {}); |
| 59 | + |
| 60 | + using ElementAccumulator = float; |
| 61 | + using ElementInputQ = cutlass::half_t; |
| 62 | + using ElementInputKV = cutlass::half_t; |
| 63 | + using ElementOutput = float; |
| 64 | + |
| 65 | + using LayoutQ = cutlass::layout::RowMajor; |
| 66 | + using LayoutK = cutlass::layout::ColumnMajor; |
| 67 | + using LayoutV = cutlass::layout::RowMajor; |
| 68 | + using LayoutO = cutlass::layout::RowMajor; |
| 69 | + |
| 70 | + using GEMMDispatchPolicy = |
| 71 | + cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>; |
| 72 | + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; |
| 73 | + |
| 74 | + using MMAOperation = cute::XE_8x16x16_F32F16F16F32_TT; |
| 75 | + |
| 76 | + using GmemTiledCopyQ = cute::XE_2D_U16x8x32_LD_N; |
| 77 | + using GmemTiledCopyK = cute::XE_2D_U16x16x16_LD_T; |
| 78 | + using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V; |
| 79 | + using GmemTiledCopyStore = cute::XE_2D_U32x8x16_ST_N; |
| 80 | + |
| 81 | + using ProblemShapeType = cute::tuple<int, int, int, int, int, int, int>; |
| 82 | + |
| 83 | + /// MAIN LOOP /// |
| 84 | + |
| 85 | + using CollectiveMainloop = |
| 86 | + cutlass::flash_attention::collective::FlashPrefillMma< |
| 87 | + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, |
| 88 | + cutlass::gemm::TagToStrideA_t<LayoutQ>, ElementInputKV, |
| 89 | + cutlass::gemm::TagToStrideB_t<LayoutK>, ElementInputKV, |
| 90 | + cutlass::gemm::TagToStrideB_t<LayoutV>, MMAOperation, TileShapeQK, |
| 91 | + TileShapePV, SubgroupLayout, |
| 92 | + GmemTiledCopyQ, // Q |
| 93 | + GmemTiledCopyK, // K |
| 94 | + GmemTiledCopyV, // V, |
| 95 | + Causal>; |
| 96 | + |
| 97 | + /// EPILOGUE LOOP /// |
| 98 | + |
| 99 | + using CollectiveSoftmaxEpilogue = |
| 100 | + cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue< |
| 101 | + Causal, EpilogueDispatchPolicy, ElementAccumulator>; |
| 102 | + using CollectiveEpilogue = |
| 103 | + cutlass::flash_attention::collective::FlashPrefillEpilogue< |
| 104 | + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, |
| 105 | + ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, |
| 106 | + ElementOutput, GmemTiledCopyStore>; |
| 107 | + |
| 108 | + /// FA /// |
| 109 | + |
| 110 | + using FMHAPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefill< |
| 111 | + ProblemShapeType, CollectiveMainloop, CollectiveSoftmaxEpilogue, |
| 112 | + CollectiveEpilogue>; |
| 113 | + |
| 114 | + /// FA INVOCATION /// |
| 115 | + |
| 116 | + try { |
| 117 | + /// Buffer Initialization |
| 118 | + const cutlass::half_t *_Q = |
| 119 | + static_cast<const cutlass::half_t *>(Q.data_ptr()); |
| 120 | + const cutlass::half_t *_K = |
| 121 | + static_cast<const cutlass::half_t *>(K.data_ptr()); |
| 122 | + const cutlass::half_t *_V = |
| 123 | + static_cast<const cutlass::half_t *>(V.data_ptr()); |
| 124 | + const float *_O = static_cast<const float *>(O.data_ptr()); |
| 125 | + |
| 126 | + /// Problem size |
| 127 | + using ProblemShapeType = typename FMHAPrefillKernel::ProblemShape; |
| 128 | + ProblemShapeType problem_size = |
| 129 | + ProblemShapeType{Batch, NumHeadsQ, NumHeadsKV, SeqLengthQO, |
| 130 | + SeqLengthKV, HeadSizeQK, HeadSizeVO}; |
| 131 | + |
| 132 | + /// Stride |
| 133 | + using StrideQ = typename FMHAPrefillKernel::StrideQ; |
| 134 | + using StrideK = typename FMHAPrefillKernel::StrideK; |
| 135 | + using StrideV = typename FMHAPrefillKernel::StrideV; |
| 136 | + using StrideO = typename FMHAPrefillKernel::StrideO; |
| 137 | + StrideQ stride_Q = cutlass::make_cute_packed_stride( |
| 138 | + StrideQ{}, |
| 139 | + cute::make_shape(SeqLengthQO, HeadSizeQK, Batch * NumHeadsQ)); |
| 140 | + StrideK stride_K = cutlass::make_cute_packed_stride( |
| 141 | + StrideK{}, |
| 142 | + cute::make_shape(SeqLengthKV, HeadSizeQK, Batch * NumHeadsKV)); |
| 143 | + StrideV stride_V = cutlass::make_cute_packed_stride( |
| 144 | + StrideV{}, |
| 145 | + cute::make_shape(HeadSizeVO, SeqLengthKV, Batch * NumHeadsKV)); |
| 146 | + StrideO stride_O = cutlass::make_cute_packed_stride( |
| 147 | + StrideO{}, |
| 148 | + cute::make_shape(SeqLengthQO, HeadSizeVO, Batch * NumHeadsQ)); |
| 149 | + |
| 150 | + static cutlass::KernelHardwareInfo hw_info; |
| 151 | + if (hw_info.sm_count == 0) { |
| 152 | + hw_info.sm_count = |
| 153 | + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); |
| 154 | + CUTLASS_TRACE_HOST( |
| 155 | + "Query result for SM count per device: " << hw_info.sm_count); |
| 156 | + } |
| 157 | + |
| 158 | + typename FMHAPrefillKernel::Arguments arguments = { |
| 159 | + cutlass::gemm::GemmUniversalMode::kGemm, |
| 160 | + problem_size, |
| 161 | + {_Q, stride_Q, _K, stride_K, _V, stride_V}, |
| 162 | + {sm_scale}, |
| 163 | + {_O, stride_O}, |
| 164 | + hw_info}; |
| 165 | + |
| 166 | + size_t workspace_size = FMHAPrefillKernel::get_workspace_size(arguments); |
| 167 | + cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); |
| 168 | + auto workspace_ptr = workspace.get(); |
| 169 | + |
| 170 | + if (!FMHAPrefillKernel::can_implement(arguments)) { |
| 171 | + std::cout << "Invalid Problem Size: " << Batch << 'x' << NumHeadsQ << 'x' |
| 172 | + << SeqLengthQO << 'x' << SeqLengthKV << 'x' << HeadSizeQK << 'x' |
| 173 | + << HeadSizeVO << (Causal ? "xCausal" : "xNonCausal") |
| 174 | + << std::endl; |
| 175 | + return -1; |
| 176 | + } |
| 177 | + |
| 178 | + CUTLASS_CHECK( |
| 179 | + FMHAPrefillKernel::initialize_workspace(arguments, workspace_ptr)); |
| 180 | + auto params = |
| 181 | + FMHAPrefillKernel::to_underlying_arguments(arguments, workspace_ptr); |
| 182 | + run<FMHAPrefillKernel>(params); |
| 183 | + |
| 184 | + syclcompat::wait(); |
| 185 | + |
| 186 | + } catch (std::exception &e) { |
| 187 | + std::cerr << "Runtime error: " << e.what() << std::endl; |
| 188 | + return -1; |
| 189 | + } catch (...) { |
| 190 | + std::cerr << "Unexpected error" << std::endl; |
| 191 | + return -1; |
| 192 | + } |
| 193 | + |
| 194 | + return 0; |
| 195 | +} |

////////////////////////////////////////////////////////////////////////////////
// PUBLIC FUNCTION
////////////////////////////////////////////////////////////////////////////////

| 201 | +using FARunPtr = int (*)(const at::Tensor &Q, const at::Tensor &K, |
| 202 | + const at::Tensor &V, at::Tensor &O, int Batch, |
| 203 | + int NumHeadsQ, int NumHeadsKV, int SeqLengthQO, |
| 204 | + int SeqLengthKV, int HeadSizeQK, int HeadSizeVO, |
| 205 | + float sm_scale); |
| 206 | + |
| 207 | +auto attention(const at::Tensor &Q, const at::Tensor &K, const at::Tensor &V, |
| 208 | + at::Tensor &O, int Batch, int NumHeadsQ, int NumHeadsKV, |
| 209 | + int SeqLengthQO, int SeqLengthKV, int HeadSizeQK, int HeadSizeVO, |
| 210 | + bool Causal, float sm_scale) -> int { |
| 211 | + constexpr int PipelineStages = 2; |
| 212 | + FARunPtr f = nullptr; |
| 213 | + |
| 214 | + if (HeadSizeVO == 64) { |
| 215 | + using ShapeQK = cute::Shape<cute::_128, cute::_64, cute::_64>; |
| 216 | + using ShapePV = cute::Shape<cute::_128, cute::_32, cute::_64>; |
| 217 | + using ShapeOutPut = cute::Shape<cute::_128, cute::_64, cute::_64>; |
| 218 | + using SubgroupLayout = |
| 219 | + cute::Layout<cute::Shape<cute::_8, cute::_1, cute::_1>, |
| 220 | + cute::Stride<cute::_1, cute::_1, cute::_1>>; |
| 221 | + |
| 222 | + f = Causal ? attention_run<true, ShapeQK, ShapePV, ShapeOutPut, |
| 223 | + SubgroupLayout, PipelineStages> |
| 224 | + : attention_run<false, ShapeQK, ShapePV, ShapeOutPut, |
| 225 | + SubgroupLayout, PipelineStages>; |
| 226 | + |
| 227 | + } else if (HeadSizeVO == 128) { |
| 228 | + using ShapeQK = cute::Shape<cute::_128, cute::_64, cute::_64>; |
| 229 | + using ShapePV = cute::Shape<cute::_128, cute::_32, cute::_64>; |
| 230 | + using ShapeOutPut = cute::Shape<cute::_128, cute::_128, cute::_64>; |
| 231 | + using SubgroupLayout = |
| 232 | + cute::Layout<cute::Shape<cute::_16, cute::_1, cute::_1>, |
| 233 | + cute::Stride<cute::_1, cute::_1, cute::_1>>; |
| 234 | + |
| 235 | + f = Causal ? attention_run<true, ShapeQK, ShapePV, ShapeOutPut, |
| 236 | + SubgroupLayout, PipelineStages> |
| 237 | + : attention_run<false, ShapeQK, ShapePV, ShapeOutPut, |
| 238 | + SubgroupLayout, PipelineStages>; |
| 239 | + } else { |
| 240 | + std::cerr << "Unsupported HeadSizeVO: " << HeadSizeVO << std::endl; |
| 241 | + return -1; |
| 242 | + } |
| 243 | + |
| 244 | + return f(Q, K, V, O, Batch, NumHeadsQ, NumHeadsKV, SeqLengthQO, SeqLengthKV, |
| 245 | + HeadSizeQK, HeadSizeVO, sm_scale); |
| 246 | +} |