|
| 1 | +/****************************************************************************** |
| 2 | + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. |
| 3 | + ******************************************************************************/ |
| 4 | + |
#pragma once

#include <cstdio>

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"

#include "kernel_traits.h"
#include "mainloop_attn.hpp"
#include "softmax.hpp"
| 20 | + |
| 21 | +using namespace cute; |
| 22 | + |
// Builds the global-memory layout for a Q/K/V/O tensor viewed as
// (token, head_dim, head): consecutive head_dim elements are contiguous
// (stride 1), heads are head_dim apart, and tokens are head_num * head_dim
// apart — i.e. the buffer is stored [token][head][head_dim].
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
    auto shape  = make_shape(token_num, kHeadDim, head_num);
    auto stride = make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim);
    return make_layout(shape, stride);
}
| 29 | + |
// Warp-specialized attention kernel.
//
// Grid layout: x = query-tile index (m_block), y = head index, z = batch index.
// Block: Ktraits::kNWarps warps. Warp-group 0 acts as the producer (loads
// K/V through pipeline_k / pipeline_v); the remaining warp-group(s) are the
// consumers that run the matmuls + softmax and store the output tile.
// Dynamic shared memory must hold Ktraits::SharedStorage (set by the host
// launcher via cudaFuncSetAttribute).
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
    compute_attn_ws(
        CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
        CUTE_GRID_CONSTANT Flash_mask_params const data_params) {

    using Element = typename Ktraits::Element;
    using ElementAccum = typename Ktraits::ElementAccum;
    using SoftType = ElementAccum;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    constexpr int kHeadDim = Ktraits::kHeadDim;
    constexpr bool NeedMask = Ktraits::NeedMask;

    using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    // Dynamic shared memory carries the mainloop's staging buffers and barriers.
    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

    // Per-row mask values for this query tile, staged in static shared memory.
    __align__(16) __shared__ int mask[kBlockM];

    const int m_block = blockIdx.x;  // query tile index
    const int bidh = blockIdx.y;     // head index
    const int bidb = blockIdx.z;     // batch index

    if constexpr (NeedMask) {
        // Cooperative block-wide copy of this tile's mask slice. Visibility to
        // the read below is guaranteed by the __syncthreads() that precedes
        // the first use of mask[].
        const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;

        for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
            mask[i] = mask_this_batch[i];
        }
    }

    const int seq_len_q = data_params.seq_len_encoder[bidb];
    // K length for this batch from the cumulative-length (ragged) offsets.
    const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];

    // Block-uniform early exit (m_block and seq_len_q are identical for every
    // thread of the block), so no thread can miss the barriers below.
    if (m_block * kBlockM >= seq_len_q) {
        return;
    }

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();

    // One elected thread prefetches the TMA descriptors.
    if (warp_idx == 0 && lane_predicate) {
        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    }

    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    // K/V pipeline configuration: warp-group 0 produces, all MMA threads consume.
    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    pipeline_params.role = warp_group_idx == 0
        ? MainloopPipeline::ThreadCategory::Producer
        : MainloopPipeline::ThreadCategory::Consumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    if (warp_idx == 0 && lane_predicate) {
        // Q barrier expects a single arrival (one TMA transaction).
        shared_storage.barrier_Q.init(1);
    }

    MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
    MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});

    // Publishes barrier/pipeline init and the mask[] staging block-wide.
    __syncthreads();

    CollectiveMainloop collective_mainloop;

    // Number of valid query rows remaining in this (possibly partial) tile.
    const int real_seq = seq_len_q - m_block * kBlockM;

    // Upper bound on the number of K/V tiles to visit. With a mask, the last
    // valid row's mask value bounds the k-extent (assumes mask[] is
    // non-decreasing across rows -- TODO confirm against mask producer);
    // otherwise use the causal bound shifted by the q/k length difference.
    const int n_block_max = NeedMask ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);

    if (warp_group_idx == 0) { // Producer
        // Producer warps need few registers; release them to the consumers.
        cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();

        // Broadcast from lane 0 so the branch below is warp-uniform.
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
            PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
            PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();

            collective_mainloop.load(
                mainloop_params,
                pipeline_k,
                pipeline_v,
                smem_pipe_write_k,
                smem_pipe_write_v,
                shared_storage,
                n_block_max,
                m_block,
                bidh,
                bidb,
                data_params.cu_seq_q,
                data_params.cu_seq_k,
                seq_len_q,
                seq_len_k);
        }
    } else { // Consumer
        // Reclaim the registers the producer gave up for the MMA warp-groups.
        cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
        typename Ktraits::TiledMma1 tiled_mma1;

        PipelineState smem_pipe_read_k, smem_pipe_read_v;

        // Output accumulator fragment for the second GEMM.
        Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
        Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;

        collective_mainloop.mma(
            mainloop_params,
            pipeline_k,
            pipeline_v,
            smem_pipe_read_k,
            smem_pipe_read_v,
            tOrO,
            softmax,
            mask,
            n_block_max,
            threadIdx.x - NumCopyThreads,  // thread index within MMA warp-groups
            m_block,
            seq_len_q,
            seq_len_k,
            shared_storage);

        // Output buffer is [token][head][head_dim]; offset to this tile's
        // first row for this (batch, head) pair.
        const int o_head_stride = data_params.head_num * kHeadDim;
        const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;

        collective_mainloop.store<NumMmaThreads>(
            mainloop_params,
            tOrO,
            shared_storage,
            tiled_mma1,
            threadIdx.x - NumCopyThreads,
            o_head_stride,
            real_seq,  // clamps partial tiles so no rows past seq_len_q are written
            reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
    }

}
| 176 | + |
| 177 | + |
// Host-side launcher: builds the mainloop parameters, sizes the grid/cluster,
// opts the kernel into large dynamic shared memory, and launches it on
// `stream` via the CUTLASS cluster launcher.
//
// Grid: x = ceil(max_seq_len_q / kBlockM) rounded up to the cluster size,
//       y = head_num, z = batch_size.
// The launch is asynchronous; runtime errors are reported on stderr.
template<typename Kernel_traits>
void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
    using Element = typename Kernel_traits::Element;
    using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

    using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    // Q/K/V live in [token][head][head_dim] order; K/V use the KV head count,
    // which may differ from head_num (grouped-query attention).
    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(params.q_ptr),
            get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
            static_cast<Element const*>(params.k_ptr),
            get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
            static_cast<Element const*>(params.v_ptr),
            get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
            params.scale_softmax_log2
        });

    int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
    // Round the m-dimension up to a multiple of the cluster's m-extent.
    num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});

    void *kernel = (void *)compute_attn_ws<Kernel_traits>;
    int smem_size = sizeof(typename Kernel_traits::SharedStorage);

    if (smem_size >= 48 * 1024) {
        // >48 KB of dynamic shared memory requires an explicit opt-in. If this
        // fails, the launch below would fail anyway, so report and bail early
        // instead of failing silently (the original ignored this status).
        cudaError_t attr_status =
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        if (attr_status != cudaSuccess) {
            fprintf(stderr, "run_flash_mask: cudaFuncSetAttribute(%d bytes) failed: %s\n",
                    smem_size, cudaGetErrorString(attr_status));
            return;
        }
    }

    dim3 grid_dims;
    grid_dims.x = num_blocks_m;
    grid_dims.y = params.head_num;
    grid_dims.z = params.batch_size;

    // Use the named constant rather than a literal 32, consistent with the
    // kernel's __launch_bounds__ expression.
    static constexpr int ctaSize = Kernel_traits::kNWarps * cutlass::NumThreadsPerWarp;
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);

    // Surface launch-configuration errors (bad cluster shape, smem overflow).
    // Execution errors still surface at the next synchronizing call.
    cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess) {
        fprintf(stderr, "run_flash_mask: kernel launch failed: %s\n",
                cudaGetErrorString(launch_status));
    }
}
| 221 | + |
// Dispatch entry for head_dim == 128: picks the kernel traits for the given
// tile shape / mask mode / input type and forwards to run_flash_mask.
template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
void flash_attn_headdim128(Flash_mask_params &params, cudaStream_t stream) {
    static constexpr int kHeadDimDispatch = 128;
    // Warp count grows with the m-tile: one warp per 16 query rows plus a
    // fixed group of four.
    static constexpr int kWarpCount = kBlockM / 16 + 4;
    static constexpr int kPipelineStages = 2;

    using Traits = Flash_mask_kernel_traits<
        kHeadDimDispatch, kBlockM, kBlockN, kWarpCount, kPipelineStages, NeedMask, InputType>;
    run_flash_mask<Traits>(params, stream);
}
0 commit comments