diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 43d345bcb480c..8c5132193e0e0 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5966,20 +5966,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class JambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.JAMBA

-    def get_vocab_base_pre(self, tokenizer) -> str:
-        del tokenizer # unused
-
-        return "gpt-2"
-
     def set_vocab(self):
         if (self.dir_model / "tokenizer.model").is_file():
-            # Using Jamba's tokenizer.json causes errors on model load
-            # (something about "byte not found in vocab"),
-            # but there's a working tokenizer.model
             self._set_vocab_sentencepiece()
         else:
-            # Some Jamba models only have a tokenizer.json, which works.
-            self._set_vocab_gpt2()
+            self._set_vocab_llama_hf()
+            self.gguf_writer.add_add_space_prefix(False)

     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index bdcefe7b7ed7a..3024775135966 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

     file(GLOB GGML_SOURCES_CUDA "*.cu")
+    file(GLOB SRCS "template-instances/fattn-tile*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/mmq*.cu")
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index d51abbeafa944..e0abde5427c83 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
 }

 static bool fast_fp16_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
 }

 // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+// Important: do not use this function if dst and src both point at registers.
+// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
+// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
+// If dst and src point at different address spaces then they are guaranteed to not be aliased.
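// [Editor's aside, not part of the patch] Minimal usage sketch for the copy helper documented above.
// It assumes the number of bytes to copy is passed as a compile-time template argument (the template
// parameter list of ggml_cuda_memcpy_1 was lost in extraction below); the kernel and buffer names are
// hypothetical and only illustrate the intended copy directions (VRAM/SRAM <-> registers), never
// register <-> register copies between objects of different types, which is the strict-aliasing case
// the comment warns about.
static __global__ void example_copy_kernel(const float * __restrict__ src, float * __restrict__ dst) {
    __shared__ float tile[4*32];                                // SRAM staging buffer, 4 floats per thread
    float reg[4];                                               // per-thread registers
    ggml_cuda_memcpy_1<sizeof(reg)>(reg, src + 4*threadIdx.x);  // VRAM -> registers as one 16-byte transfer
    ggml_cuda_memcpy_1<sizeof(reg)>(tile + 4*threadIdx.x, reg); // registers -> SRAM
    ggml_cuda_memcpy_1<sizeof(reg)>(dst + 4*threadIdx.x, tile + 4*threadIdx.x); // SRAM -> VRAM
}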
 template
 static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
     if constexpr (alignment != 0) {
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 33d2f0f49e3de..bc0c2523cc82f 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -793,8 +793,6 @@ void launch_fattn(
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
     // multiple sequences of possibly different lengths.
-    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
         const int s31 = mask->nb[1] / sizeof(half2);
         const int s33 = mask->nb[3] / sizeof(half2);
@@ -916,8 +914,7 @@ void launch_fattn(
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
-        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +943,7 @@ void launch_fattn(
     blocks_num.x = ntiles_x;
     blocks_num.y = parallel_blocks;
-    blocks_num.z = Q->ne[2]*Q->ne[3];
+    blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu
index 68de623d80349..3a5806d9091d7 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cu
+++ b/ggml/src/ggml-cuda/fattn-tile.cu
@@ -1,756 +1,45 @@
 #include "common.cuh"
-#include "fattn-common.cuh"
 #include "fattn-tile.cuh"
 #include "fattn-wmma-f16.cuh"

-// kq_stride == number of KQ rows to process per iteration
-// kq_nbatch == number of K columns to load in parallel for KQ calculation
-
-static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
-        if (GGML_CUDA_CC_IS_RDNA(cc)) {
-            switch (D) {
-                case 64:
-                    return 128;
-                case 128:
-                case 256:
-                    return ncols <= 16 ? 128 : 64;
-                default:
-                    GGML_ABORT("fatal error");
-                    return -1;
-            }
-        }
-        switch (D) {
-            case 64:
-                return ncols == 32 ? 128 : 64;
-            case 128:
-                return ncols == 32 ? 64 : 32;
-            case 256:
-                return 32;
-            default:
-                GGML_ABORT("fatal error");
-                return -1;
-        }
-    }
-    if (fast_fp16_available(cc)) {
-        switch (D) {
-            case 64:
-            case 128:
-            case 256:
-                return ncols <= 16 ? 128 : 64;
-            default:
-                GGML_ABORT("fatal error");
-                return -1;
-        }
-    }
-    switch (D) {
-        case 64:
-            return ncols <= 16 ? 128 : 64;
-        case 128:
-            return ncols <= 16 ?
64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - GGML_UNUSED(warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP -#ifdef RDNA - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // RDNA -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else - switch (D) { - case 64: - return 64; - case 128: - return 128; - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { - return 256; - GGML_UNUSED_VARS(cc, ncols); -} - -static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { - return 256; - GGML_UNUSED(ncols); -} - -static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { -#ifdef RDNA - return 3; -#else - return ncols <= 16 ? 
3 : 2; -#endif // RDNA - GGML_UNUSED(ncols); -} - -template // D == head size -__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) -static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef GGML_USE_WMMA_FATTN - NO_DEVICE_CODE; - return; -#endif // GGML_USE_WMMA_FATTN - - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - constexpr int warp_size = 32; - constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size; - constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); - static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); - constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); - static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); - constexpr int cpy_ne = cpy_nb / 4; - - constexpr int cpw = ncols/nwarps; // cols per warp - - // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. - // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. -#ifdef FAST_FP16_AVAILABLE - constexpr int softmax_iter_j = cpw < 2*cpy_ne ? 
cpw : 2*cpy_ne; - - __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ half2 Q_tmp[ncols][D/2]; - __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. - half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#else - constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; - - __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ float Q_tmp[ncols][D]; - __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. - float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#endif // FAST_FP16_AVAILABLE - static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); - - float KQ_max[cpw]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - KQ_max[j0/nwarps] = -FLT_MAX/2.0f; - } - float KQ_sum[cpw] = {0.0f}; - - // Load Q data, convert to FP16 if fast. -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - const int j = j0 + threadIdx.y*cpw; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - float tmp_f[cpy_ne_D] = {0.0f}; - if (ic0 + j < ne01) { - ggml_cuda_memcpy_1(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); - } - -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp_f[i1] *= scale; - } - -#ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { - tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); - } - ggml_cuda_memcpy_1(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); -#else - ggml_cuda_memcpy_1 (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f); -#endif // FAST_FP16_AVAILABLE - } - } - - __syncthreads(); - - // Main loop over KV cache: - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float KQ_max_new[cpw]; -#pragma unroll - for (int j = 0; j < cpw; ++j) { - KQ_max_new[j] = KQ_max[j]; - } - - float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. - - // KQ = K @ Q matrix multiplication: -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], - &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); - } -#else - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? 
cpy_ne : kq_nbatch/warp_size; -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { - half2 tmp_h2[cpy_ne_kqnb/2]; - ggml_cuda_memcpy_1( - tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); - - float2 tmp_f2[cpy_ne_kqnb/2]; -#pragma unroll - for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { - tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); - } - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { - half2 K_k[kq_stride/warp_size][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; -#else -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { - float K_k[kq_stride/warp_size][cpy_ne]; - float Q_k[cpw][cpy_ne]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { -#pragma unroll - for (int k = 0; k < cpy_ne; ++k) { - ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); - } - } - } - } - - if (k_KQ_0 + kq_nbatch < D) { - __syncthreads(); // Sync not needed on last iteration. - } - } - - // Apply logit softcap, mask, update KQ_max: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - - if (use_logit_softcap) { - KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - - KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? 
slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - } - - __syncthreads(); - - // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { -#ifdef FAST_FP16_AVAILABLE - half tmp[kq_stride/warp_size][softmax_iter_j]; -#else - float tmp[kq_stride/warp_size][softmax_iter_j]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_max_new[j0+j1] = warp_reduce_max(KQ_max_new[j0+j1]); - const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); - KQ_max[j0+j1] = KQ_max_new[j0+j1]; - - float KQ_sum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); - KQ_sum_add += val; - tmp[i0/warp_size][j1] = val; - } - KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; - VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - ggml_cuda_memcpy_1( - KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); - } - } - - // VKQ = V @ KQ matrix multiplication: - constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. - static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); -#pragma unroll - for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { - const int k_tile = k1 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], - &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - half2 tmp_h2[cpy_ne_D/2]; - ggml_cuda_memcpy_1( - tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); - - float2 tmp_f2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { - tmp_f2[i1] = __half22float2(tmp_h2[i1]); - } - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - half2 V_k[(D/2)/warp_size]; - half2 KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? 
cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - half tmp[softmax_iter_j]; - ggml_cuda_memcpy_1( - &tmp, KQ[j][k0 + k1]); -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_k[j0+j1] = __half2half2(tmp[j1]); - } - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; - } - } - } -#else -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - float2 V_k[(D/2)/warp_size]; - float KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - ggml_cuda_memcpy_1( - &KQ_k[j0], KQ[j][k0 + k1]); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; - VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; - } - } - } -#endif // FAST_FP16_AVAILABLE - - __syncthreads(); - } - } - - - // Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - float KQ_max_new_j = fmaxf(KQ_max[j0], sink); - KQ_max_new_j = warp_reduce_max(KQ_max_new_j); - - const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); - KQ_max[j0] = KQ_max_new_j; - - const float val = expf(sink - KQ_max[j0]); - KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; - if (threadIdx.x == 0) { - KQ_sum[j0] += val; - } - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size].x *= KQ_max_scale; - VKQ[j0][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - } - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - KQ_sum[j_VKQ_0] = warp_reduce_sum(KQ_sum[j_VKQ_0]); - } - if (gridDim.y == 1) { -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; - } -#else - const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; - VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - - // Write back results: -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? 
cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); - } - ggml_cuda_memcpy_1(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); - } -#endif // FAST_FP16_AVAILABLE - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int warp_size = 32; - - constexpr size_t nbytes_shared = 0; - -#ifdef GGML_USE_HIP - if constexpr (D <= 128) { - if (Q->ne[1] > 32) { - constexpr int cols_per_block = 64; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - } -#endif // GGML_USE_HIP - - if (Q->ne[1] > 16) { - constexpr int cols_per_block = 32; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - - constexpr int cols_per_block = 16; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); -} - -template -static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; case 64: { - launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } 
break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); } break; case 128: { - launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; case 256: { - launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); } break; default: { GGML_ABORT("Unsupported head size"); } break; } } - -void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_switch_head_size(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_switch_head_size(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 10dc22d1bf971..2efc9cc880cf8 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1,3 +1,1216 @@ #include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +// nbatch_fa == number of KQ rows to process per iteration +// nbatch_K == number of K columns to load in parallel for KQ calculation + +// TODO optimize kernel parameters for FP16 NVIDIA (P100) +// TODO optimize kernel parameters for head sizes 40, 80, 96, 112 + +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 
48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 
256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + 
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); + } + if (fast_fp16_available(cc)) { + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +} + +static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef GGML_USE_HIP +#ifdef RDNA + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP +} + +static __host__ int 
ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +// TODO: deduplicate with mma-f16 +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? 
KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + float2 tmp_f2[cpy_ne/2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = __half22float2(tmp_h2[l]); + } + ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template +static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? 
nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + __syncthreads(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template +static __device__ __forceinline__ void flash_attn_tile_iter( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory. + // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs]. +#ifdef FAST_FP16_AVAILABLE + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? 
cpw : 1*cpy_ne; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + __syncthreads(); + } else { + static_assert(cpw == 1, "bad cpw"); + __shared__ float KQ_max_new_shared[nwarps]; + if (threadIdx.x == 0) { + KQ_max_new_shared[threadIdx.y] = KQ_max_new[0]; + } + __syncthreads(); + KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np]; + KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef FAST_FP16_AVAILABLE + half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#else + float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); + if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { + KQ_sum_add += val; + } + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + 
ggml_cuda_memcpy_1( + KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs, + tmp[i0/(np*warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + half2 V_k[(DVp/2)/warp_size]; + half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + half tmp[KQ_cs]; + ggml_cuda_memcpy_1( + &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + float2 V_k[(DVp/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + ggml_cuda_memcpy_1( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } +} + +template // D == head size +__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + + if ( +#ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 512) || +#endif // GGML_USE_WMMA_FATTN + (use_logit_softcap && !(DV == 128 || DV == 256)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + constexpr int warp_size = 32; + constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on. + + const int sequence = blockIdx.z / (ne02/ncols2); + const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
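// [Editor's aside, not part of the patch] Worked example for the GQA ratio above, with illustrative
// numbers: ne02 = 32 query heads and ne12 = 8 KV heads give gqa_ratio = 4, so query heads 0..3 all
// read KV head 0 via (head0 / gqa_ratio), heads 4..7 read KV head 1, and so on.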
+ const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols * DKQ/2]; + __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; + __shared__ half KQ[ncols * nbatch_fa]; + half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#else + __shared__ float Q_tmp[ncols * DKQ]; + __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; + __shared__ float KQ[ncols * nbatch_fa]; + float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? 
cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ncols1 == 1 || col_Q_0 + j < ne01) { + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1( + &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], + tmp_h2); +#else + ggml_cuda_memcpy_1( + &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D], + tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = blockIdx.y*nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_combine = (half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // FAST_FP16_AVAILABLE + float * KQ_sum_combine = (float *) Q_tmp; + + if (threadIdx.y % np != 0) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]); + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // FAST_FP16_AVAILABLE + + if (threadIdx.x == 0) { + KQ_sum_combine[threadIdx.y] = KQ_sum[0]; + } + + return; + } + + __syncthreads(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? 
cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + half2 tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + float tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && blockIdx.y == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = fmaxf(KQ_max[jc0], sink); + const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j); + KQ_max[jc0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[jc0]); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + if (gridDim.y == 1) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]); +#pragma unroll + for (int i = 0; i < (DVp/2)/warp_size; ++i) { + VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv; + } +#else + const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0]; +#pragma unroll + for (int i = 0; i < (DVp/2)/warp_size; ++i) { + VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv; + VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= ne01) { + return; + } + + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + } + if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (DV <= 128) { + if (Q->ne[1] > 32/ncols2) { + constexpr int cols_per_block = 64; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + +#ifndef GGML_USE_HIP + if constexpr (DV <= 256) +#endif // GGML_USE_HIP + { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = 
ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + GGML_ABORT("fatal error"); +} + +template +static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template +void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2(ctx, dst); + } +} void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_cuda_flash_attn_ext_tile_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 1848d08836185..7235f1b77aeed 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "common.cuh" #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0c8e7b3e41904..fe970adaecef3 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; #endif// FLASH_ATTN_AVAILABLE + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = 
dst->src[2]; @@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int gqa_ratio = Q->ne[2] / K->ne[2]; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); - const int cc = ggml_cuda_info().devices[device].cc; + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - // TODO: temporary until support is extended - // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 - if (K->ne[1] % FATTN_KQ_STRIDE != 0) { - return BEST_FATTN_KERNEL_NONE; - } + // The effective batch size for the kernel can be increased by gqa_ratio. + // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, + const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + + const int cc = ggml_cuda_info().devices[device].cc; switch (K->ne[0]) { + case 40: case 64: - case 128: - case 256: - if (V->ne[0] != K->ne[0]) { - return BEST_FATTN_KERNEL_NONE; - } - break; case 80: case 96: + case 128: case 112: + case 256: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } - if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) { - return BEST_FATTN_KERNEL_NONE; - } break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 16 != 0) { return BEST_FATTN_KERNEL_NONE; } break; @@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0; - - // If Turing tensor cores available, use them except for some cases with batch size 1: - if (turing_mma_available(cc)) { - best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16; + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // If Turing tensor cores available, use them: + if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { if (Q->ne[1] <= 2) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } } - if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply. 
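The attention-sink handling in flash_attn_tile above is easier to follow outside the kernel: folding a per-head sink logit into an already-running online softmax means taking a new running maximum and rescaling the running sum and the VKQ accumulators by exp(old_max - new_max), while the sink itself only contributes to the normalization. The following self-contained C++ sketch reproduces just that math on the host; OnlineSoftmax, add and add_sink are illustrative names, not part of the CUDA code.

// Minimal host-side sketch of the rescaling used by the sink handling above.
// Names are illustrative; this is not the kernel, only the same arithmetic.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct OnlineSoftmax {
    float kq_max = -1e30f;   // running maximum of the logits
    float kq_sum = 0.0f;     // running sum of exp(logit - kq_max)
    std::vector<float> vkq;  // accumulated V * softmax(KQ), one value per output dim

    explicit OnlineSoftmax(int dv) : vkq(dv, 0.0f) {}

    void add(float logit, const std::vector<float> & v_col) {
        const float kq_max_new = std::max(kq_max, logit);
        const float scale      = std::exp(kq_max - kq_max_new); // rescale old contributions
        const float p          = std::exp(logit - kq_max_new);
        kq_max = kq_max_new;
        kq_sum = kq_sum*scale + p;
        for (size_t i = 0; i < vkq.size(); ++i) {
            vkq[i] = vkq[i]*scale + p*v_col[i];
        }
    }

    // A sink contributes only to the normalization, never to the V accumulation.
    void add_sink(float sink) {
        const float kq_max_new = std::max(kq_max, sink);
        const float scale      = std::exp(kq_max - kq_max_new);
        kq_max = kq_max_new;
        kq_sum = kq_sum*scale + std::exp(sink - kq_max_new);
        for (float & x : vkq) {
            x *= scale;
        }
    }

    std::vector<float> result() const {
        std::vector<float> out = vkq;
        for (float & x : out) {
            x /= kq_sum;
        }
        return out;
    }
};

int main() {
    OnlineSoftmax acc(2);
    acc.add(1.0f, {1.0f, 0.0f});
    acc.add(3.0f, {0.0f, 1.0f});
    acc.add_sink(2.0f); // like the "sinks" value in the kernel, applied once per head
    const std::vector<float> out = acc.result();
    std::printf("%f %f\n", out[0], out[1]);
    return 0;
}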
+ if (!gqa_opt_applies && Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; } } - return best; - } - - // Use kernels specialized for small batch sizes if possible: - if (Q->ne[1] <= 8 && can_use_vector_kernel) { - return BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_MMA_F16; } - // For large batch sizes, use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc)) { + // Use the WMMA kernel if possible: + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } return BEST_FATTN_KERNEL_WMMA_F16; } - // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } return BEST_FATTN_KERNEL_TILE; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fb691528b7de4..856e9de2e1115 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3867,7 +3867,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { dev_ctx->device = i; dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); - ggml_cuda_set_device(i); cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu new file mode 100644 index 0000000000000..a8b15ad72a916 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu new file mode 100644 index 0000000000000..1da18105508ac --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu new file mode 100644 index 0000000000000..bc65c723eca9d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu new file mode 100644 index 0000000000000..10b330fa6c031 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu new file mode 100644 index 0000000000000..254b7d2e1dc29 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu new file mode 100644 index 0000000000000..5caffac0467d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu new file mode 100644 index 0000000000000..90abb3b186261 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu new file mode 100644 index 0000000000000..7292c0aab8f98 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d410080fab841..81a986f38cacf 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,8 +3,17 @@ from glob import glob import os +HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576] + TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v}); +""" + SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
#include "../fattn-vec.cuh" @@ -51,6 +60,11 @@ def get_short_name(long_quant_name): for filename in glob("*.cu"): os.remove(filename) +for head_size_kq in HEAD_SIZES_KQ: + head_size_v = head_size_kq if head_size_kq != 576 else 512 + with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: + f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) + for type_k in TYPES_KV: for type_v in TYPES_KV: with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: @@ -64,7 +78,9 @@ def get_short_name(long_quant_name): with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + for head_size_kq in HEAD_SIZES_KQ: + if head_size_kq == 40: + continue if head_size_kq != 576 and ncols2 == 16: continue if head_size_kq == 576 and ncols2 != 16: diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 0e2b1847e09e2..934aefdcb45fa 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 1137e210773af..5f9370449bb2d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { !ggml_is_transposed(op->src[1]) && // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - props_dev->has_simdgroup_mm && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 45d91def88bf2..ddc285042d284 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7517,6 
+7516,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7524,24 +7526,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7679,7 +7685,7 @@ void 
kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7731,6 +7736,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7738,20 +7746,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index f8477a2ef356d..d76cb51977f90 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) 
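The ns01 = args.nb01/args.nb00 introduced in the Metal mul_mv kernels above is worth a short illustration: for quantized tensors the byte strides encode the row stride measured in blocks, and that stride only coincides with nb = ne00/QK when src0 is contiguous. A small host-side C++ sketch of the distinction follows; QK, blk_bytes and the "view" layout are illustrative values, not ggml code.

// Why rows are indexed with row*ns01 + ib instead of row*nb + ib above:
// nb00 = bytes per block, nb01 = bytes per row, so nb01/nb00 is the row
// stride in blocks, which can exceed ne00/QK for permuted or otherwise
// non-contiguous src0 tensors.
#include <cstdio>

int main() {
    const int QK        = 32;   // elements per quantization block (illustrative)
    const int ne00      = 256;  // logical row length in elements
    const int blk_bytes = 18;   // e.g. a Q4_0-sized block

    // contiguous layout:
    int nb00 = blk_bytes;
    int nb01 = (ne00/QK) * blk_bytes;
    std::printf("contiguous: nb = %d, ns01 = %d\n", ne00/QK, nb01/nb00); // 8, 8

    // a view whose rows are strided over a wider parent tensor:
    const int parent_ne00 = 512;
    nb01 = (parent_ne00/QK) * blk_bytes;
    std::printf("view:       nb = %d, ns01 = %d\n", ne00/QK, nb01/nb00); // 8, 16
    return 0;
}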
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a5fe5b749c355..36d495d6cfeab 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) { diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 5026edcebed37..d0c44534efcb6 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 60326e8e50efe..cf12805b4998a 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -5401,15 +5401,6 @@ int main(int argc, char ** argv) { const json body = json::parse(req.body); - // TODO: implement - //int top_n = 1; - //if (body.count("top_n") != 1) { - // top_n = body.at("top_n"); - //} else { - // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - // return; - //} - // if true, use TEI API format, otherwise use Jina API format // Jina: https://jina.ai/reranker/ // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank @@ -5434,6 +5425,8 @@ int main(int argc, char ** argv) { return; } + int top_n = json_value(body, "top_n", (int)documents.size()); + // create and queue the task json responses = json::array(); bool error = false; @@ -5474,7 +5467,8 @@ int main(int argc, char ** argv) { body, responses, is_tei_format, - documents); + documents, + top_n); res_ok(res, root); }; diff --git a/tools/server/tests/unit/test_rerank.py b/tools/server/tests/unit/test_rerank.py index 0b63c7821eb98..ded8267109682 100644 --- a/tools/server/tests/unit/test_rerank.py +++ b/tools/server/tests/unit/test_rerank.py @@ -102,3 +102,45 @@ def test_rerank_usage(query, doc1, doc2, n_tokens): assert res.status_code == 200 assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] assert res.body['usage']['prompt_tokens'] == n_tokens + + +@pytest.mark.parametrize("top_n,expected_len", [ + (None, len(TEST_DOCUMENTS)), # no top_n parameter + (2, 2), + (4, 4), + (99, len(TEST_DOCUMENTS)), # higher than available docs +]) +def test_rerank_top_n(top_n, expected_len): + global server + server.start() + data = { + "query": "Machine learning is", + "documents": TEST_DOCUMENTS, + } + if top_n is not None: + data["top_n"] = top_n + + res = server.make_request("POST", "/rerank", data=data) + assert res.status_code == 200 + assert len(res.body["results"]) == expected_len + + +@pytest.mark.parametrize("top_n,expected_len", [ + (None, len(TEST_DOCUMENTS)), # no top_n parameter + (2, 2), + (4, 4), + (99, len(TEST_DOCUMENTS)), # higher than available docs +]) +def test_rerank_tei_top_n(top_n, expected_len): + global server + server.start() + data = { + "query": "Machine learning is", + "texts": TEST_DOCUMENTS, + } + if top_n is not None: + data["top_n"] = top_n + + res = server.make_request("POST", "/rerank", data=data) + assert res.status_code == 200 + assert len(res.body) == expected_len diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index f175115f4fd6a..fd0bc8de533cf 100644 --- a/tools/server/utils.hpp +++ 
b/tools/server/utils.hpp @@ -849,47 +849,44 @@ static json format_response_rerank( const json & request, const json & ranks, bool is_tei_format, - std::vector & texts) { - json res; - if (is_tei_format) { - // TEI response format - res = json::array(); - bool return_text = json_value(request, "return_text", false); - for (const auto & rank : ranks) { - int index = json_value(rank, "index", 0); - json elem = json{ - {"index", index}, - {"score", json_value(rank, "score", 0.0)}, - }; - if (return_text) { - elem["text"] = std::move(texts[index]); - } - res.push_back(elem); - } - } else { - // Jina response format - json results = json::array(); - int32_t n_tokens = 0; - for (const auto & rank : ranks) { - results.push_back(json{ - {"index", json_value(rank, "index", 0)}, - {"relevance_score", json_value(rank, "score", 0.0)}, - }); - - n_tokens += json_value(rank, "tokens_evaluated", 0); - } - - res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{ - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"results", results} + std::vector & texts, + int top_n) { + int32_t n_tokens = 0; + bool return_text = is_tei_format && json_value(request, "return_text", false); + std::vector elements; // Temporary vector to hold unsorted elements + std::string score_label = is_tei_format ? "score" : "relevance_score"; + for (const auto & rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {score_label, json_value(rank, "score", 0.0)}, }; + n_tokens += json_value(rank, "tokens_evaluated", 0); + if (return_text) { + elem["text"] = std::move(texts[index]); + } + elements.push_back(elem); } + std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) { + return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0); + }); + + elements.resize(std::min(top_n, (int)elements.size())); + json results = elements; + + if (is_tei_format) return results; + + json res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{ + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", results} + }; + return res; } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte index 66369b2f1ce10..cc2631b830c3e 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte @@ -2,8 +2,9 @@ import { Check, X } from '@lucide/svelte'; import { Card } from '$lib/components/ui/card'; import { Button } from '$lib/components/ui/button'; - import { ChatAttachmentsList } from '$lib/components/app'; + import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; + import { config } from '$lib/stores/settings.svelte'; import ChatMessageActions from './ChatMessageActions.svelte'; interface Props { @@ -55,6 +56,7 @@ let isMultiline = $state(false); let messageElement: HTMLElement | undefined = $state(); + const currentConfig = config(); $effect(() => { if (!messageElement || !message.content.trim()) return; @@ -123,9 +125,18 @@ class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5" 
data-multiline={isMultiline ? '' : undefined} > - - {message.content} - + {#if currentConfig.renderUserContentAsMarkdown} +
+ +
+ {:else} + + {message.content} + + {/if} {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte index 7c25e59258829..dc617afdcd4cd 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte @@ -80,6 +80,11 @@ key: 'showModelInfo', label: 'Show model information', type: 'checkbox' + }, + { + key: 'renderUserContentAsMarkdown', + label: 'Render user content as Markdown', + type: 'checkbox' } ] }, diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 63e4364ae5656..154ec888ce2dc 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -12,6 +12,7 @@ export const SETTING_CONFIG_DEFAULT: Record = pasteLongTextToFileLen: 2500, pdfAsImage: false, showModelInfo: false, + renderUserContentAsMarkdown: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', temperature: 0.8, @@ -84,6 +85,7 @@ export const SETTING_CONFIG_INFO: Record = { 'Ask for confirmation before automatically changing conversation title when editing the first message.', pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).', showModelInfo: 'Display the model name used to generate each message below the message content.', + renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.', pyInterpreterEnabled: 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.' };
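Returning to the server-side rerank changes earlier in this diff: top_n now defaults to the number of documents, and format_response_rerank sorts the result elements by score in descending order before truncating to min(top_n, n_results). The minimal C++ sketch below shows that sort-and-truncate step with a plain struct instead of the server's json objects; RankedDoc and apply_top_n are illustrative names, not part of the server code.

// Sketch of the top_n handling: sort by score descending, keep the best
// min(top_n, n_results) entries; a top_n larger than the document count
// simply returns everything, matching the behaviour covered by the tests.
#include <algorithm>
#include <cstdio>
#include <vector>

struct RankedDoc {
    int   index;
    float score;
};

static std::vector<RankedDoc> apply_top_n(std::vector<RankedDoc> docs, int top_n) {
    std::sort(docs.begin(), docs.end(),
              [](const RankedDoc & a, const RankedDoc & b) { return a.score > b.score; });
    docs.resize(std::min<size_t>(top_n, docs.size()));
    return docs;
}

int main() {
    std::vector<RankedDoc> docs = {{0, 0.1f}, {1, 0.9f}, {2, 0.4f}, {3, 0.7f}};
    for (const RankedDoc & d : apply_top_n(docs, 2)) { // top_n = 2 -> two best documents
        std::printf("index=%d score=%.1f\n", d.index, d.score);
    }
    std::printf("n=%zu\n", apply_top_n(docs, 99).size()); // top_n > doc count -> all documents
    return 0;
}

The tests added in test_rerank.py exercise exactly these cases: no top_n, top_n smaller than the document count, and top_n larger than it, for both the Jina and the TEI request formats.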