From 0c3a3ffee6d5d783082bc6703c3bb17bfa24ceb4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 24 Aug 2025 14:57:21 +0300 Subject: [PATCH 01/23] WIP: adding mainline mmq_id implementation --- ggml/src/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/mmq_id.cu | 4320 ++++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmq_id.cuh | 7 + ggml/src/ggml-cuda/quantize_id.cu | 132 + ggml/src/ggml-cuda/quantize_id.cuh | 16 + 5 files changed, 4476 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-cuda/mmq_id.cu create mode 100644 ggml/src/ggml-cuda/mmq_id.cuh create mode 100644 ggml/src/ggml-cuda/quantize_id.cu create mode 100644 ggml/src/ggml-cuda/quantize_id.cuh diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 5bc773a95..98cba1081 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2316,7 +2316,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); + //CUDA_CHECK(cudaStreamSynchronize(stream)); return is_ser; } diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu new file mode 100644 index 000000000..cb5d7f31e --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -0,0 +1,4320 @@ +#include "mmq_id.cuh" +#include "quantize_id.cuh" + +#include "vecdotq.cuh" +#include "mma_new.cuh" + +#include +#include +#include + +using namespace ggml_cuda_mma; + +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. +#define MMQ_ITER_K 256 +#define MMQ_NWARPS 8 + +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, + float * __restrict__ dst, const int stride, const int i_max, const int j_max); + +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + +struct block_q8_1_mmq { + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. 
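+    // As a concrete check of the layout described above (illustrative arithmetic, consistent with the
+    // static_asserts below): each block_q8_1_mmq covers 4*QK8_1 = 128 y values, i.e. 128 bytes of int8
+    // quants plus a 16 byte scale/partial-sum header, 144 bytes total, exactly the size of 4 plain block_q8_1.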
+ union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each +}; +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); + +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_MXFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ABORT("fatal error"); + break; + } +} + +struct tile_x_sizes { + int qs; + int dm; + int sc; +}; + +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 +#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 + +// AMD +// GCN/CDNA, wave size is 64 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 + +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 + +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc 
>= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
+
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+
+#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
+#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
+#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
+
+#ifdef __CUDACC__
+template <typename... Args>
+__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {}
+#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__)
+#else
+#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0)
+#endif // __CUDACC__
+
+static bool amd_mfma_available(const int cc) {
+#if !defined(GGML_HIP_NO_MMQ_MFMA)
+    return GGML_CUDA_CC_IS_CDNA(cc);
+#else
+    return false;
+#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
+}
+static bool turing_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING;
+}
+
+static int get_mmq_x_max_host(const int cc) {
+    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
+        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_VOLTA ?
+#ifdef GGML_CUDA_FORCE_MMQ
+            128 : 64;
+#else
+            MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
+static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
+#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+    return 64;
+#else
+    return 32;
+#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+}
+static constexpr int ggml_cuda_get_physical_warp_size_host() {
+#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+    return 64;
+#else
+    return 32;
+#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+}
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+    return (float) e;
+#else
+    uint32_t bits;
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+#endif // CUDART_VERSION >= 12080
+}
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __any_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+        }
+        return x;
+    }
+}
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+    return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
+    }
+    return x;
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+}
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
+    }
+    return a;
+}
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
+static bool fp16_mma_hardware_available(const int cc) {
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+}
+
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    return 128;
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+#if defined(GGML_USE_HIP)
+    return 64;
+#else // defined(GGML_USE_HIP)
+
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#ifdef GGML_CUDA_FORCE_MMQ
+    return 128;
+#else // GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    return 64;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+
+#endif // defined(GGML_USE_HIP)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+}
+
+static int get_mmq_y_host(const int cc) {
+    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
+}
+
+static constexpr __device__ int get_mmq_y_device() {
+#if defined(GGML_USE_HIP)
+#if defined(RDNA1)
+    return 64;
+#else
+    return 128;
+#endif // defined(RDNA1)
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    return 128;
+#else
+    return 64;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(GGML_USE_HIP)
+}
+
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64
+// (typically for TILE_X_K) 32 bit elements for the quantized data (does not include scales).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in K direction is padded to avoid shared memory bank conflicts,
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
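+// For example (illustrative arithmetic, consistent with the static_asserts further down): for Q8_0 with mma,
+// MMQ_MMA_TILE_X_K_Q8_0 = 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4 = 64 + 8 + 4 = 76 ints per row, and 76 % 8 == 4.
+// For the Q4_0 dp4a path each x_qs row is MMQ_TILE_NE_K + 1 = 33 ints wide, i.e. K % 2 == 1.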
+#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} + +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } +} + +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q2_K: return 
MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } +} + +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + if (amd_mfma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (turing_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } +} + +#if defined(AMD_MFMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(TURING_MMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { + return 8; +} +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) +static int mmq_get_nwarps_host(const int cc, const int warp_size) { + return amd_mfma_available(cc) ? 8 : 256/warp_size; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { + return 256/warp_size; +} +#endif // (GGML_USE_HIP) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(AMD_MFMA_AVAILABLE) + return 8; +#else + return 256/ggml_cuda_get_physical_warp_size(); +#endif // AMD_MFMA_AVAILABLE +} + +// ------------------------------------------------------------ + +template static __device__ __forceinline__ void load_tiles_q4_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / 
threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = 
(int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_b4(bxi->qs, kqsx); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_mxfp4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI_MXFP4; + const int kqsx = txi % QI_MXFP4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b1(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + const int k0 = kbx * (2 * QI_MXFP4) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#else + x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
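+    // Rough sanity check on the numbers above (assuming the MFMA granularity defined earlier in this file):
+    // mmq_get_granularity_device() returns 32 for mmq_x >= 128 and 16 otherwise, and tile_C::I == 16 here,
+    // so each warp handles ntx == 2 or ntx == 1 row mini-tiles of the x tile per iteration.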
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + tile_B B; + float dB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * 
__restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
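+    // Sketch of the math behind the two accumulation terms below: Q4_1/Q5_1-style blocks store x = d*q + m,
+    // and the q8_1 y blocks store {dB, partial sum s} per 32 values, so dot(x, y) = d*dB*dot(qx, qy) + m*s.
+    // dmA unpacks to {d, m} and dsB to {dB, s}, hence dmA.x*dsB.x*C.x[l] plus the dmA.y*dsB.y correction.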
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B; + float2 dsB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) +} + +// Used for Q3_K, IQ2_S, and IQ2_XS +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +// Used for Q3_K, IQ2_S, and IQ2_XS: +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. 
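+            // The two 8x4 B tiles below cover consecutive 4-int halves of the y data for this k01 step
+            // (tile_B::J == 4), pairing with A[n][k01/4 + 0] and A[n][k01/4 + 1] in the mma calls further down.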
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum), GGML_UNUSED(k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + +#pragma unroll + for (int l = 0; l < QR2_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int sc_m = bxi->scales[kqsx]; +#ifdef FAST_FP16_AVAILABLE + const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); +#else + const float2 bxi_dmf = __half22float2(bxi->dm); + const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); +#endif // FAST_FP16_AVAILABLE + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; +#else + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + float2 y_df[mmq_x/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + y_df[j0/nwarps] = 
__half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } + + // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. + // As a workaround 2 separate loops are used instead. +#pragma unroll + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? 
__half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + float mA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { + const int k0 = k00 + k01; + + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float2 dB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B[2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + + tile_C Cm[2]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm[0], A1, B[0]); + mma(Cm[1], A1, B[1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd[2]; + + mma(Cd[0], A[n][k01/4 + 0], B[0]); + mma(Cd[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y);
+                }
+            }
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
+            float2 sB[tile_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
+
+                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                }
+            }
+        }
+    }
+#else
+    GGML_UNUSED_VARS(x, y, sum, k00);
+    NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int * x_sc = (int *) (x_df + txs.dm);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+        for (int l = 0; l < QR3_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
+            const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
+
+            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
+#else
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        }
+    }
+
+    constexpr int rows_per_warp = warp_size / 4;
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        const int ksc = threadIdx.x % 4;
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        const int8_t * sc8 = (const int8_t *) &sc;
+        const float d = bxi->d;
+
+#pragma unroll
+        for (int l = 0; l < int(sizeof(int)); ++l) {
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
+        }
+#else
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
+#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + x_df[i] = bxi->d; + } +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +} + +template +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int qs0 = get_int_b4(bxi->qs, txi); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += 
nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const int ky = QR5_K*txi; + + const int ql = get_int_b4(bxi->qs, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm 
= bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + + const int ql = get_int_b2(bxi->ql, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; + + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K 
+ k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
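+    // NOTE: the warp index is split two ways below: threadIdx.y % ntx selects this warp's
+    // column offset into the y tile, while threadIdx.y / ntx selects the block of x rows
+    // (i0) whose partial results the warp accumulates into sum[].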
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + int scA[ntx][tile_C::ne/2][8]; + float dA[ntx][tile_C::ne/2]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { + const int k0 = k00 + k01; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float tmp[ntx][tile_C::ne] = {{0.0f}}; + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; + } + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_iq4_nl( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b2(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#else + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + 
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void 
load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 
2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); 
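+        // NOTE: x_ds packs the block scale d1q in .x and d1q*delta in .y. Every value of an
+        // IQ1_S block carries the same constant shift delta, so the q8_1 dot product can fold
+        // the shift in once per block via the q8_1 partial sum instead of per-value additions.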
+ +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq4_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 8; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const float d = __half2float(bxi->d); + + const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) + | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void mmq_write_back_dp4a( + const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +template +static __device__ __forceinline__ void mmq_write_back_mma( + const float * 
__restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; + constexpr int rows_per_warp = 2 * granularity; +#endif // defined(AMD_MFMA_AVAILABLE) + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l); + + if (j > j_max) { + continue; + } + + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; + } + } + } +} + +// ------------------------------------------------------------------------------------------------------------------------------------- + +template +struct mmq_type_traits; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr 
vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr 
load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +static __device__ __forceinline__ void mul_mat_q_process_tile( + const char * __restrict__ x, const int offset_x, const int * __restrict__ y, + const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + + extern __shared__ int data_mul_mat_q[]; + int * tile_y = data_mul_mat_q + mmq_x; + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { + load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); + + { + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); + + __syncthreads(); + } + + if (fixup) { + write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + } else { + write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); + } +} + + +// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 + +template +#if defined(GGML_USE_HIP) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) +#else + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +static __global__ void mul_mat_q( + const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, + const 
int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ncols_max) { + + // Skip unused template specializations for faster compilation: + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { + NO_DEVICE_CODE; + return; + } + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + + // Initialize the ids for writing back data with just the index. + // For regular matrix multiplications this is never changed. + // For MoE the correct indices are loaded from ids_dst. + extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + + // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: +#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + { + const int wt = blockIdx.z / nchannels_y; + const int zt = blockIdx.z - wt*nchannels_y; + const int jt = blockIdx.y; + const int it = blockIdx.x; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // __syncthreads(); // There is no previous tile that could cause a race condition. 
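
[Editor's note] As an aside, the expert_bounds / ids_dst bookkeeping used in this branch can be illustrated with a small host-side sketch. It is not part of the patch and the helper name tile_columns_for_expert is made up; it only mirrors the kernel's indexing convention: the compacted columns belonging to expert zt live in [expert_bounds[zt], expert_bounds[zt+1]), a tile jt of width mmq_x that starts past that range has no work, and ids_dst maps each compacted column back to the column of dst it must be written to.

// Host-side illustration (hypothetical helper, not part of ggml) of how a tile
// (zt, jt) of width mmq_x finds its columns and their destinations in dst.
#include <cstdio>
#include <vector>

static void tile_columns_for_expert(const std::vector<int> & expert_bounds,
                                    const std::vector<int> & ids_dst,
                                    int zt, int jt, int mmq_x) {
    const int col_low  = expert_bounds[zt];     // first compacted column of expert zt
    const int col_high = expert_bounds[zt + 1]; // one past its last compacted column
    const int col_diff = col_high - col_low;

    if (jt*mmq_x >= col_diff) {                 // tile lies past this expert's columns
        printf("tile (zt=%d, jt=%d): no work\n", zt, jt);
        return;
    }

    printf("tile (zt=%d, jt=%d):", zt, jt);
    for (int j = 0; j < mmq_x && jt*mmq_x + j < col_diff; ++j) {
        // ids_dst holds, per compacted column, the dst column to write back to.
        printf("  j=%d -> dst column %d", j, ids_dst[col_low + jt*mmq_x + j]);
    }
    printf("\n");
}

int main() {
    // Two experts: expert 0 owns compacted columns 0..2, expert 1 owns 3..4.
    const std::vector<int> expert_bounds = {0, 3, 5};
    const std::vector<int> ids_dst       = {0, 2, 5, 1, 4};
    tile_columns_for_expert(expert_bounds, ids_dst, /*zt=*/0, /*jt=*/0, /*mmq_x=*/8);
    tile_columns_for_expert(expert_bounds, ids_dst, /*zt=*/1, /*jt=*/1, /*mmq_x=*/8); // past range -> no work
    return 0;
}
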
+#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = false; + mul_mat_q_process_tile + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + return; + } +#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + + const int64_t blocks_per_ne00 = ncols_x / qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + + // kb0 == k index when doing the matrix multiplication for an output tile. + int kb0_start = kbc % blocks_per_ne00; + int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + + continue; + } + + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. 
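
[Editor's note] To make the stream-k index arithmetic above concrete, here is a host-side sketch (illustrative only, with made-up sizes) that reproduces it: the flat range [0, nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00) is split roughly evenly over gridDim.x CUDA blocks, the endpoints are snapped down to an iteration boundary, and each flat index kbc decomposes into tile coordinates (it, wt, zt, jt) plus an offset kb0 inside the tile's K dimension.

// Host-side sketch of the stream-k work split used above (illustration only).
#include <cstdint>
#include <cstdio>

int main() {
    // Toy configuration (assumed values, just for the printout):
    const int ntx = 3, nty = 4, nchannels_y = 1, nsamples_y = 1;
    const int blocks_per_ne00 = 16;   // ncols_x / qk
    const int blocks_per_iter = 8;    // MMQ_ITER_K / qk
    const int grid_dim_x      = 5;    // number of CUDA blocks (== number of SMs)

    const int64_t total = (int64_t) nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00;

    for (int bidx = 0; bidx < grid_dim_x; ++bidx) {
        int64_t kbc      = (int64_t) bidx     *total / grid_dim_x;
        int64_t kbc_stop = (int64_t)(bidx + 1)*total / grid_dim_x;
        kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
        kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

        // Decompose the first index of this block's range into tile coordinates:
        int64_t tmp = kbc;
        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        tmp -= (int64_t) it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
        tmp -= (int64_t) wt * (nchannels_y*ntx*blocks_per_ne00);
        const int zt = tmp / (ntx*blocks_per_ne00);
        tmp -= (int64_t) zt * (ntx*blocks_per_ne00);
        const int jt  = tmp / blocks_per_ne00;
        const int kb0 = tmp % blocks_per_ne00;

        // If a block's range ends in the middle of an output tile, that partial tile is
        // written to the fixup buffer (fixup == true) and is later added to dst by
        // mul_mat_q_stream_k_fixup, run by the block that finishes the tile.
        printf("block %d: kbc=[%lld, %lld) -> it=%d wt=%d zt=%d jt=%d kb0=%d\n",
               bidx, (long long) kbc, (long long) kbc_stop, it, wt, zt, jt, kb0);
    }
    return 0;
}
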
+ mul_mat_q_process_tile + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); + + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // The memory layout for the fixup buffer is always contiguous, therefore reset ids: + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + mul_mat_q_process_tile + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); +} + + +template +static __global__ void mul_mat_q_stream_k_fixup( + const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, + const int ncols_max) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; + + const int bidx0 = blockIdx.x; + + // kbc == k block continuous, current index in continuous ijk space. 
+ int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; + const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + bool any_fixup = false; + + // Iterate over previous blocks and sum up partial sums written to fixup buffer. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int64_t bidx = bidx0 - 1; + int64_t kbc_stop = kbc0; + while(true) { + int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + any_fixup = true; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + } + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + break; + } + bidx--; + kbc_stop = kbc; + } + + if (!any_fixup) { + return; + } + + int tmp = kbc0; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + if (!ids_dst) { + const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = ncols_dst - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } + return; + } + + __shared__ int ids_dst_shared[mmq_x]; + const int col_low = expert_bounds[zt + 0]; + const int col_high = expert_bounds[zt + 1]; + const int col_diff = col_high - col_low; + + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { + ids_dst_shared[j] = ids_dst[col_low + j]; + } + __syncthreads(); + + const int offset_dst = it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = col_diff - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } 
+ + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +struct mmq_args { + const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; + int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; + int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; + int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; + bool use_stream_k; int64_t ncols_max; +}; + +template +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const size_t nbs_ids = mmq_x*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); +} + +template +static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + const int mmq_y = get_mmq_y_host(cc); + + const dim3 block_dims(warp_size, nwarps, 1); + + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); + + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + + const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; + const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; + const int ntzw = args.nchannels_y * args.nsamples_y; + const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + + GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); + GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); + const int channel_ratio = args.nchannels_y / args.nchannels_x; + const int sample_ratio = args.nsamples_y / args.nsamples_x; + + if (!args.use_stream_k) { + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } + return; + } + + const dim3 block_nums_stream_k(nsm, 1, 1); + const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + + ggml_cuda_pool 
& pool = ctx.pool(id); + ggml_cuda_pool_alloc tmp_fixup(pool); + if (fixup_needed) { + tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); + } + + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } +} + +template +void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + + const int mmq_x_max = get_mmq_x_max_host(cc); + const int mmq_y = get_mmq_y_host(cc); + + int mmq_x_best = 0; + int ntiles_x_best = INT_MAX; + + for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); + + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { + continue; + } + + const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; + + if (ntiles_x < ntiles_x_best) { + mmq_x_best = mmq_x; + ntiles_x_best = ntiles_x; + } + } + + switch (mmq_x_best) { + case 8: + launch_mul_mat_q(ctx, args, stream); + break; + case 16: + launch_mul_mat_q(ctx, args, stream); + break; + case 24: + launch_mul_mat_q(ctx, args, stream); + break; + case 32: + launch_mul_mat_q(ctx, args, stream); + break; + case 40: + launch_mul_mat_q(ctx, args, stream); + break; + case 48: + launch_mul_mat_q(ctx, args, stream); + break; + case 56: + launch_mul_mat_q(ctx, args, stream); + break; + case 64: + launch_mul_mat_q(ctx, args, stream); + break; + case 72: + launch_mul_mat_q(ctx, args, stream); + break; + case 80: + launch_mul_mat_q(ctx, args, stream); + break; + case 88: + launch_mul_mat_q(ctx, args, stream); + break; + case 96: + launch_mul_mat_q(ctx, args, stream); + break; + case 104: + launch_mul_mat_q(ctx, args, stream); + break; + case 112: + launch_mul_mat_q(ctx, args, stream); + break; + case 
120: + launch_mul_mat_q(ctx, args, stream); + break; + case 128: + launch_mul_mat_q(ctx, args, stream); + break; + default: + fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); + GGML_ABORT("fatal error"); + break; + } +} + +#define DECL_MMQ_CASE(type) \ + template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \ + +extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); + +// ------------------------------------------------------------------------------------------------------------------------- + +static bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mmq_ids_helper_store { + uint32_t data; + + __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mmq_ids_helper[]; + mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. 
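
[Editor's note] The transformation that mmq_ids_helper performs is easier to follow as a sequential host-side reference. The sketch below is illustrative only: it assumes a densely packed ids tensor (row stride equal to n_expert_used, i.e. si1 == n_expert_used in the kernel's terms), assumes each expert appears at most once per token, and leaves out the 22/10-bit mmq_ids_helper_store packing, which is purely a shared-memory optimization.

// Sequential reference for what mmq_ids_helper computes (illustration only).
#include <cstdint>
#include <cstdio>
#include <vector>

struct ids_helper_result {
    std::vector<int32_t> ids_src1;      // per compacted column: flattened source column in src1
    std::vector<int32_t> ids_dst;       // per compacted column: flattened destination column in dst
    std::vector<int32_t> expert_bounds; // expert e owns compacted columns [expert_bounds[e], expert_bounds[e+1])
};

static ids_helper_result mmq_ids_helper_ref(
        const std::vector<int32_t> & ids, // ids[it*n_expert_used + iex] = expert used by token it in slot iex
        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int sis1) {
    ids_helper_result r;
    r.expert_bounds.push_back(0);
    for (int expert = 0; expert < n_experts; ++expert) {
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*n_expert_used + iex] != expert) {
                    continue;
                }
                // Same formulas as the device kernel:
                r.ids_src1.push_back(it*sis1 + iex % nchannels_y);
                r.ids_dst .push_back(it*n_expert_used + iex);
            }
        }
        r.expert_bounds.push_back((int32_t) r.ids_src1.size());
    }
    return r;
}

int main() {
    // Token 0 uses experts {0, 2}, token 1 uses {1, 0}, token 2 uses {2, 1}.
    const std::vector<int32_t> ids = {0, 2, 1, 0, 2, 1};
    const ids_helper_result r = mmq_ids_helper_ref(ids, /*n_experts=*/3, /*n_tokens=*/3,
                                                   /*n_expert_used=*/2, /*nchannels_y=*/1, /*sis1=*/2);
    for (int32_t b : r.expert_bounds) {
        printf("%d ", b);   // prints: 0 2 4 6 (each expert is used by two tokens)
    }
    printf("\n");
    return 0;
}
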
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mmq_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= offset) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mmq_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < gridDim.x - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); + mmq_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +static 
void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + switch (args.type_x) { + case GGML_TYPE_Q4_0: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_MXFP4: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_q_case(ctx, args, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(ids_tensor->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(ids_tensor->nb[0] == ggml_type_size(ids_tensor->type)); + + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const char * src0_d = (const char *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + // If src0 is a temporary compute buffer, clear any potential padding. 
+ if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc); + + const int64_t n_expert_used = ids_tensor->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + GGML_ASSERT(ne1 == n_expert_used); + + ggml_cuda_pool_alloc ids_src1_local(ctx.pool()); + ggml_cuda_pool_alloc ids_dst_local(ctx.pool()); + ggml_cuda_pool_alloc expert_bounds_local(ctx.pool()); + + int32_t * ids_src1, *ids_dst, *expert_bounds; + if (ids_data) { + ids_src1 = (int32_t *)ids_data; + ids_dst = ids_src1 + ne_get_rows; + expert_bounds = ids_dst + ne_get_rows; + } + else { + GGML_ASSERT(ids_tensor->nb[0] == ggml_element_size(ids_tensor)); + + ids_src1_local.alloc(ne_get_rows); + ids_dst_local.alloc(ne_get_rows); + expert_bounds_local.alloc(ne02 + 1); + + ids_src1 = ids_src1_local.get(); + ids_dst = ids_dst_local.get(); + expert_bounds = expert_bounds_local.get(); + + const int si1 = ids_tensor->nb[1] / ggml_element_size(ids_tensor); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t ne11_flat = ne12*n_expert_used; + const int64_t ne12_flat = 1; + const int64_t ne13_flat = 1; + + const size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + + ggml_cuda_pool_alloc src1_q8_1_local(ctx.pool()); + + char * src1_q8_1; + + if (src1_quantized_data) { + src1_q8_1 = src1_quantized_data; + } else { + + src1_q8_1_local.alloc(nbytes_src1_q8_1); + 
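
[Editor's note] For orientation, the size of the quantized src1 buffer allocated just above can be estimated by hand. The sketch below uses assumed example shapes and the q8_1 block sizes (36 bytes for block_q8_1, 144 bytes for block_q8_1_mmq, matching the static_assert earlier in this file); mmq_x_max = 128 and MATRIX_ROW_PADDING = 512 are taken as typical values rather than read from the device.

// Back-of-the-envelope check of the q8_1 buffer size (illustrative values only).
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed example shapes for one MoE matmul:
    const int64_t ne10               = 4096;   // src1 row length (K)
    const int64_t ne12               = 32;     // tokens
    const int64_t n_expert_used      = 8;
    const int64_t MATRIX_ROW_PADDING = 512;    // row padding granularity assumed here
    const int64_t QK8_1              = 32;     // values per q8_1 block
    const int64_t sizeof_block_q8_1     = 36;  // 32 x int8 + half2 (d, s)
    const int64_t sizeof_block_q8_1_mmq = 144; // 4*QK8_1 bytes of quants + 4 half2 scales/sums
    const int64_t mmq_x_max          = 128;    // assumed get_mmq_x_max_host(cc)

    const int64_t ne10_padded = ((ne10 + MATRIX_ROW_PADDING - 1)/MATRIX_ROW_PADDING)*MATRIX_ROW_PADDING;
    const int64_t ne11_flat   = ne12*n_expert_used; // one flattened row per (token, expert) pair

    const int64_t nbytes = ne11_flat*ne10_padded*sizeof_block_q8_1/QK8_1
                         + mmq_x_max*sizeof_block_q8_1_mmq;
    printf("q8_1 buffer: %lld bytes (~%.1f MiB)\n", (long long) nbytes, nbytes/(1024.0*1024.0));
    return 0;
}
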
src1_q8_1 = src1_q8_1_local.get(); + + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[2] / ts_src1; + quantize_mmq_q8_1_cuda_id(src1_d, ids_src1, src1_q8_1, src0->type, + ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s13 = ne12*s12; + + // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1, ids_dst, expert_bounds, dst_d, + ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, + ne02, ne02, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k, ne12}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); +} + +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { +#ifdef GGML_CUDA_FORCE_CUBLAS + return false; +#endif // GGML_CUDA_FORCE_CUBLAS + + bool mmq_supported; + + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + mmq_supported = true; + break; + default: + mmq_supported = false; + break; + } + + if (!mmq_supported) { + return false; + } + + if (turing_mma_available(cc)) { + return true; + } + + if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + if (amd_mfma_available(cc)) { + // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT) + // performs better but is currently suffering from a crash on this architecture. 
+ // TODO: Revisit when hipblaslt is fixed on CDNA3 + if (GGML_CUDA_CC_IS_CDNA3(cc)) { + return true; + } + if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + return true; + } + if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { + return true; + } + return false; + } + + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; +} diff --git a/ggml/src/ggml-cuda/mmq_id.cuh b/ggml/src/ggml-cuda/mmq_id.cuh new file mode 100644 index 000000000..bc5d7c616 --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_mul_mat_q_id( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, + ggml_tensor * dst, char * ids_data, char * src1_quantized_data); diff --git a/ggml/src/ggml-cuda/quantize_id.cu b/ggml/src/ggml-cuda/quantize_id.cu new file mode 100644 index 000000000..80a9f2220 --- /dev/null +++ b/ggml/src/ggml-cuda/quantize_id.cu @@ -0,0 +1,132 @@ +#include "quantize_id.cuh" +#include "mmq.cuh" +#include + +template +static __global__ void quantize_mmq_q8_1( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { + + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; + + if (i0 >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t i00 = i0; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + const float4 * x4 = (const float4 *) x; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel + const int64_t iqs = i0 % (4*QK8_1); // quant index in block + + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); + + // Exchange max. abs. value between vals_per_scale/4 threads. +#pragma unroll + for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); + } + + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; + + // Calculate sums across vals_per_sum/4 threads. 
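
[Editor's note] The quantization arithmetic in this kernel can be sanity-checked against a scalar reference. The sketch below (illustrative only) processes one 128-value block the way the MMQ_Q8_1_DS_LAYOUT_D4 layout does: each group of 32 values gets a scale d = amax/127 and is rounded to int8. The per-group partial sums stored by the DS4 and D2S6 layouts are omitted, and a zero-amax guard is added for safety in the standalone sketch.

// Scalar reference for one block_q8_1_mmq in the D4 layout (illustration only).
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const int QK8_1 = 32;
    float x[4*QK8_1];                     // 128 input values (one block_q8_1_mmq)
    for (int i = 0; i < 4*QK8_1; ++i) {
        x[i] = sinf(0.1f*i);              // arbitrary test data
    }

    int8_t qs[4*QK8_1];
    float  d4[4];
    for (int g = 0; g < 4; ++g) {         // 4 groups of 32 values, one scale each
        float amax = 0.0f;
        for (int i = 0; i < QK8_1; ++i) {
            amax = fmaxf(amax, fabsf(x[g*QK8_1 + i]));
        }
        const float d     = amax/127.0f;  // stored scale
        const float d_inv = amax > 0.0f ? 127.0f/amax : 0.0f;
        d4[g] = d;
        for (int i = 0; i < QK8_1; ++i) {
            qs[g*QK8_1 + i] = (int8_t) roundf(x[g*QK8_1 + i]*d_inv);
        }
    }

    // d4[g]*qs[g*QK8_1 + i] approximates x[g*QK8_1 + i]:
    printf("x[5]=%.4f  dequantized=%.4f\n", x[5], d4[0]*qs[5]);
    return 0;
}
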
+#pragma unroll + for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); + } + } + + const float d_inv = 127.0f / amax; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + y[ib].d2s6[iqs/64] = d; + + return; + } + + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); + } else { + y[ib].d4[iqs/32] = d; + } +} + +void quantize_mmq_q8_1_cuda_id( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne0 % (4*QK8_1) == 0); + GGML_ASSERT(ids); + + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_src0)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1 + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1 + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1 + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/ggml/src/ggml-cuda/quantize_id.cuh b/ggml/src/ggml-cuda/quantize_id.cuh new file mode 100644 index 000000000..d37539236 --- /dev/null +++ b/ggml/src/ggml-cuda/quantize_id.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include "common.cuh" + +#include + +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128 + +//static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access."); +//static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); + +void quantize_mmq_q8_1_cuda_id( + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); From af6b5365ccda18185db269de99514769448fd0a3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 24 Aug 2025 16:46:41 +0300 Subject: [PATCH 02/23] This seems to work --- ggml/src/ggml-cuda.cu | 5 + ggml/src/ggml-cuda/mmq_id.cu | 249 ++++++++++++++++++++--------------- 2 files changed, 150 insertions(+), 104 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 98cba1081..93fa4ff4b 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -39,6 +39,7 @@ #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/add-id.cuh" #include "ggml-cuda/graph.cuh" +#include "ggml-cuda/mmq_id.cuh" #include #include @@ -2392,6 +2393,10 @@ static bool 
ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); + return false; + + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index cb5d7f31e..7c0a76fc7 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -134,6 +134,40 @@ struct tile_x_sizes { #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#define GGML_CUDA_ASSUME(x) __builtin_assume(x) +#else +#define GGML_CUDA_ASSUME(x) +#endif // CUDART_VERSION >= 11010 + +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define FP16_MMA_AVAILABLE +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) +#define FP16_MMA_AVAILABLE +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) + +#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#define TURING_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define AMPERE_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + #ifdef __CUDACC__ template __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) 
noexcept {} @@ -2961,7 +2995,7 @@ template static __device__ __forceinline__ void loa } template -static __device__ __forceinline__ void mmq_write_back_dp4a( +static __device__ __forceinline__ void mmq_write_back_dp4a_id( const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -2989,7 +3023,7 @@ static __device__ __forceinline__ void mmq_write_back_dp4a( } template -static __device__ __forceinline__ void mmq_write_back_mma( +static __device__ __forceinline__ void mmq_write_back_mma_id( const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { @@ -3040,10 +3074,10 @@ static __device__ __forceinline__ void mmq_write_back_mma( // ------------------------------------------------------------------------------------------------------------------------------------- template -struct mmq_type_traits; +struct mmq_type_traits_id; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3051,7 +3085,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; @@ -3059,7 +3093,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3067,7 +3101,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; @@ -3075,7 +3109,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3083,7 +3117,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3091,7 +3125,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; @@ -3099,7 +3133,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; @@ -3107,7 +3141,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t 
load_tiles = load_tiles_q4_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; @@ -3115,7 +3149,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; @@ -3123,7 +3157,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; @@ -3131,7 +3165,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3139,7 +3173,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; @@ -3147,7 +3181,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; @@ -3155,7 +3189,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3163,7 +3197,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3171,7 +3205,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; @@ -3179,7 +3213,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3187,7 +3221,7 @@ struct mmq_type_traits { }; template -struct mmq_type_traits { +struct mmq_type_traits_id { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3195,7 +3229,7 @@ struct mmq_type_traits { }; template -static __device__ __forceinline__ void mul_mat_q_process_tile( +static __device__ __forceinline__ void mul_mat_q_process_tile_id( const char * __restrict__ x, const int offset_x, const int * __restrict__ y, const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int stride_row_x, const int ncols_y, const int stride_col_dst, @@ -3205,18 +3239,18 @@ static __device__ __forceinline__ void 
mul_mat_q_process_tile( constexpr int nwarps = mmq_get_nwarps_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits_id::load_tiles; extern __shared__ int data_mul_mat_q[]; int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; - constexpr mmq_write_back_t write_back = mmq_write_back_mma; + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma_id; #else - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; - constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a_id; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3267,7 +3301,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( } -// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 +// The mul_mat_q_id kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 template #if defined(GGML_USE_HIP) @@ -3281,7 +3315,7 @@ template __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -static __global__ void mul_mat_q( +static __global__ void mul_mat_q_id( const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, @@ -3370,7 +3404,7 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; - mul_mat_q_process_tile + mul_mat_q_process_tile_id (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; @@ -3448,7 +3482,7 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. - mul_mat_q_process_tile + mul_mat_q_process_tile_id (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); @@ -3515,14 +3549,14 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. 
- mul_mat_q_process_tile + mul_mat_q_process_tile_id (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } template -static __global__ void mul_mat_q_stream_k_fixup( +static __global__ void mul_mat_q_stream_k_fixup_id( const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, @@ -3673,7 +3707,7 @@ static __global__ void mul_mat_q_stream_k_fixup( } } -struct mmq_args { +struct mmq_args_id { const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; @@ -3692,7 +3726,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int } template -static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { +static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; @@ -3704,14 +3738,17 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; const int ntzw = args.nchannels_y * args.nsamples_y; const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + if (args.nchannels_y % args.nchannels_x) { + printf("Oops: args.nchannels_y = %d, args.nchannels_x = %d\n", args.nchannels_y, args.nchannels_x); + } GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); const int channel_ratio = args.nchannels_y / args.nchannels_x; @@ -3720,7 +3757,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + mul_mat_q_id<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3728,7 +3765,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a args.ncols_max); } else { constexpr bool need_check = true; - mul_mat_q<<>> + mul_mat_q_id<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3749,7 +3786,7 @@ static void 
launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + mul_mat_q_id<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3760,13 +3797,13 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup_id<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, args.ncols_max); } else { constexpr bool need_check = true; - mul_mat_q<<>> + mul_mat_q_id<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3777,7 +3814,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup_id<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, args.ncols_max); @@ -3785,7 +3822,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a } template -void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { +void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; const size_t smpbo = ggml_cuda_info().devices[id].smpbo; @@ -3815,52 +3852,52 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda switch (mmq_x_best) { case 8: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 16: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 24: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 32: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 40: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 48: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 56: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 64: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 72: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 80: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 88: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 96: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 104: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 112: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 
120: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; case 128: - launch_mul_mat_q(ctx, args, stream); + launch_mul_mat_q_id(ctx, args, stream); break; default: fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); @@ -3870,27 +3907,27 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda } #define DECL_MMQ_CASE(type) \ - template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \ - -extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); -extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); -extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); -extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); -extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); -extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); -extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); -extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); -extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); -extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); -extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); -extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); -extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); -extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); -extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); -extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); -extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); -extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); -extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); + template void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) \ + +DECL_MMQ_CASE(GGML_TYPE_Q4_0); +DECL_MMQ_CASE(GGML_TYPE_Q4_1); +DECL_MMQ_CASE(GGML_TYPE_Q5_0); +DECL_MMQ_CASE(GGML_TYPE_Q5_1); +DECL_MMQ_CASE(GGML_TYPE_Q8_0); +DECL_MMQ_CASE(GGML_TYPE_MXFP4); +DECL_MMQ_CASE(GGML_TYPE_Q2_K); +DECL_MMQ_CASE(GGML_TYPE_Q3_K); +DECL_MMQ_CASE(GGML_TYPE_Q4_K); +DECL_MMQ_CASE(GGML_TYPE_Q5_K); +DECL_MMQ_CASE(GGML_TYPE_Q6_K); +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); // ------------------------------------------------------------------------------------------------------------------------- @@ -4029,64 +4066,64 @@ static void launch_mmq_ids_helper( (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); } -static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { +static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { switch (args.type_x) { case GGML_TYPE_Q4_0: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q4_1: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q5_0: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q5_1: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q8_0: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_MXFP4: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q2_K: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q3_K: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q4_K: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q5_K: - mul_mat_q_case(ctx, args, 
stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_Q6_K: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ2_XXS: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ2_XS: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ2_S: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ3_XXS: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ3_S: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ1_S: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; case GGML_TYPE_IQ4_NL: - mul_mat_q_case(ctx, args, stream); + mul_mat_q_case_id(ctx, args, stream); break; default: GGML_ABORT("fatal error"); @@ -4236,14 +4273,18 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. - const mmq_args args = { + const mmq_args_id args = { src0_d, src0->type, (const int *) src1_q8_1, ids_dst, expert_bounds, dst_d, ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, ne02, ne02, s02, s12, s2, ne03, ne13, s03, s13, s3, use_stream_k, ne12}; - ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); + //printf("ne00 = %ld, ne01 = %ld, ne_get_rows = %ld, s01 = %ld, s1 = %ld\n", ne00, ne01, ne_get_rows, s01, s1); + //printf("ne02 = %ld, s02 = %ld, s12 = %ld, s2 = %ld\n", ne02, s02, s12, s2); + //printf("ne03 = %ld, s03 = %ld, s13 = %ld, s3 = %ld\n", ne03, s03, s13, s3); + + ggml_cuda_mul_mat_q_switch_type_id(ctx, args, stream); } bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { From 406da6670de0c023965315e24fe4833758af87fe Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 24 Aug 2025 18:59:18 +0300 Subject: [PATCH 03/23] Now also -fmoe works --- ggml/src/ggml-cuda.cu | 396 +++++++++++++++++----------------- ggml/src/ggml-cuda/mmq_id.cu | 40 ++++ ggml/src/ggml-cuda/mmq_id.cuh | 4 + 3 files changed, 242 insertions(+), 198 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 93fa4ff4b..c234ec07b 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -40,6 +40,7 @@ #include "ggml-cuda/add-id.cuh" #include "ggml-cuda/graph.cuh" #include "ggml-cuda/mmq_id.cuh" +#include "ggml-cuda/quantize_id.cuh" #include #include @@ -2393,6 +2394,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + //printf("src0(%s): %ld x %ld x %ld, src1: %ld x %ld x %ld dst: ids: %ld x %ld x %ld, %ld x %ld x %ld\n", + // src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], + // ids->ne[0], ids->ne[1], ids->ne[2], dst->ne[0], dst->ne[1], dst->ne[2]); + ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); return false; @@ -2667,28 +2672,75 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } } - GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers"); + 
GGML_TENSOR_BINARY_OP_LOCALS + cudaStream_t stream = ctx.stream(); + ggml_tensor src0_1_row = *src0_1; + ggml_tensor src0_2_row = *src0_2; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + ggml_tensor final_dst; + ggml_tensor final_src; + const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + if (src1->ne[2] <= 2048 && + ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 && + ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + + const int64_t ne_get_rows = ne12 * n_ids; + ggml_cuda_pool_alloc ids_device(ctx.pool(), ne_get_rows + ne_get_rows + n_as + 1); + auto ids_src1 = ids_device.get(); + auto ids_dst = ids_src1 + ne_get_rows; + auto expert_bounds = ids_dst + ne_get_rows; + + compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream); + + const int64_t ne11_flat = ne12*n_ids; + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_quantized(ctx.pool(), nbytes_src1_q8_1); + + size_t ts_src1 = ggml_type_size(src1->type); + quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(), + src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1, + ne10_padded, ne11_flat, 1, 1, stream); + + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst->data); + + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && + ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + //ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr); + ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr); + return true; + } + + return false; + } + std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_1_row = *src0_1; - ggml_tensor src0_2_row = *src0_2; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; - ggml_tensor final_dst; - ggml_tensor final_src; - char * src0_1_original = (char *) src0_1->data; char * src0_2_original = (char *) src0_2->data; char * src1_original = (char *) src1->data; @@ -2728,222 +2780,170 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor final_src.nb[3] = final_src.nb[2]; } - if (false && ne12 == 1) { - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), 
sizeof(float)*dst_row.ne[0]); - if (fuse_down) { - final_dst.src[1] = &dst_row; + ggml_cuda_pool_alloc src1_quantized(ctx.pool()); + bool use_quantized_src1 = false; + int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; + if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { + if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); + src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + src1_quantized.alloc(src1_quantized_size); + use_quantized_src1 = true; } - for (int64_t id = 0; id < n_ids; id++) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]); - - if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(i02 >= 0 && i02 < n_as); - - const int64_t i11 = id % ne11; - const int64_t i12 = 0; - - const int64_t i1 = id; - const int64_t i2 = i12; - - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - //dst_row.data = dst_original + i1*nb1 + i2*nb2; - - dst_row.data = dst_up_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); - - dst_row.data = dst_gate_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); - - if (fuse_down) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); - CUDA_CHECK(cudaGetLastError()); + } + ggml_cuda_pool_alloc src1_contiguous(ctx.pool()); + if (!use_quantized_src1) { + src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); + } + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + if (fuse_down) { + final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); + final_dst.src[1] = &dst_row; + } - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_contiguous.get(); - } else { + bool first = false; //true; - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); - CUDA_CHECK(cudaGetLastError()); + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; - } - } - } else { - //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]); - ggml_cuda_pool_alloc src1_quantized(ctx.pool()); - bool use_quantized_src1 = false; - int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; - if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { - if 
(ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { - src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); - src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); - src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); - src1_quantized.alloc(src1_quantized_size); - use_quantized_src1 = true; - } - } - ggml_cuda_pool_alloc src1_contiguous(ctx.pool()); - if (!use_quantized_src1) { - src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); - } - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + if (is_ser) { if (fuse_down) { - final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); - final_dst.src[1] = &dst_row; + CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); + } else { + CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); } + } - src1_row.data = src1_contiguous.get(); - - bool first = false; //true; + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = moe_counts[i02]; - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); - std::vector moe_counts, cum_moe_counts; + if (num_src1_rows == 0) continue; + size_t mapping_offset = cum_moe_counts[i02]; - bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); - if (is_ser) { - if (fuse_down) { - CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); - } else { - CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); - } + if (use_quantized_src1) { + quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), + src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_quantized.get(); + } + else { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_contiguous.get(); } - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = moe_counts[i02]; - - if (num_src1_rows == 0) continue; - size_t mapping_offset = cum_moe_counts[i02]; - - if (use_quantized_src1) { - quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), - src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); - CUDA_CHECK(cudaGetLastError()); - src1_row.data = src1_quantized.get(); - } - else { - dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_src_to_contiguous<<>>( - src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); - CUDA_CHECK(cudaGetLastError()); - src1_row.data = src1_contiguous.get(); - } + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; + 
GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); - GGML_ASSERT(nb11 == sizeof(float)*ne10); - GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; + src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; + src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; - src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; - src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; - src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; - dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; + dst_row.data = dst_up_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + } + CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_up_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); - } else { - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - } + if (dst->src[4]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); CUDA_CHECK(cudaGetLastError()); + } - if (dst->src[4]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } + dst_row.data = dst_gate_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + } + CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_gate_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); - } else { - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); - } + if (dst->src[5]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); CUDA_CHECK(cudaGetLastError()); + } - if (dst->src[5]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float 
*)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } + CUDA_CHECK(cudaGetLastError()); - auto unary_op = (ggml_unary_op)dst->op_params[0]; - if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { - ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], - 1.702f, 7.0f, stream); - } else { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get()); + if (fuse_down) { + + final_dst.ne[1] = num_src1_rows; + final_dst.nb[1] = final_dst.ne[0]*sizeof(float); + final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + if (first) { + printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, + (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], + (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); + printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", + (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], + (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], + (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); + first = false; } + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); CUDA_CHECK(cudaGetLastError()); - if (fuse_down) { - - final_dst.ne[1] = num_src1_rows; - final_dst.nb[1] = final_dst.ne[0]*sizeof(float); - final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - if (first) { - printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, - (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], - (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], - (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); - printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", - (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], - (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], - (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); - first = false; - } - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - //ggml_cuda_mul_mat(ctx, 
next->src[0], &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); - - dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - (char *)next->data, final_dst_contiguous.get(), - dev_row_mapping.get() + mapping_offset, - next->ne[0], - next->nb[1], next->nb[2]); - CUDA_CHECK(cudaGetLastError()); - - } - else { + dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + (char *)next->data, final_dst_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + next->ne[0], + next->nb[1], next->nb[2]); + CUDA_CHECK(cudaGetLastError()); - dim3 block_dims(std::min((unsigned int)ne0, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - dst_original, dst_gate_contiguous.get(), - dev_row_mapping.get() + mapping_offset, - ne0, - nb1, nb2); - CUDA_CHECK(cudaGetLastError()); - } + } + else { + + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_gate_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); } } diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 7c0a76fc7..78282997f 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -4131,6 +4131,46 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, } } +void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, + cudaStream_t stream) { + + const int si1 = nb21 / sizeof(int); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); +} + void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) { GGML_ASSERT( src1->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-cuda/mmq_id.cuh b/ggml/src/ggml-cuda/mmq_id.cuh index bc5d7c616..c85c468fe 100644 --- a/ggml/src/ggml-cuda/mmq_id.cuh +++ b/ggml/src/ggml-cuda/mmq_id.cuh @@ -5,3 +5,7 @@ void ggml_cuda_mul_mat_q_id( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, char * ids_data, char * src1_quantized_data); + 
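+// Fills ids_src1, ids_dst and expert_bounds from the expert ids tensor for the
+// mul_mat_id MMQ path; dispatches launch_mmq_ids_helper, specialized for
+// n_expert_used = 2/4/6/8/16/32 with a generic fallback otherwise.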
+void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream); + From 916733144c7c441145965bc714c7c11b3205cd8c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 24 Aug 2025 19:37:02 +0300 Subject: [PATCH 04/23] WIP --- ggml/src/ggml-cuda/mmq_id.cu | 4008 +------------------------- ggml/src/ggml-cuda/mmq_id_common.cuh | 3932 +++++++++++++++++++++++++ ggml/src/ggml-cuda/mmq_id_kernels.cu | 22 + 3 files changed, 3958 insertions(+), 4004 deletions(-) create mode 100644 ggml/src/ggml-cuda/mmq_id_common.cuh create mode 100644 ggml/src/ggml-cuda/mmq_id_kernels.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 78282997f..17e6798bf 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -1,3937 +1,10 @@ +#include "mmq_id_common.cuh" #include "mmq_id.cuh" #include "quantize_id.cuh" -#include "vecdotq.cuh" -#include "mma_new.cuh" - -#include -#include -#include - -using namespace ggml_cuda_mma; - -#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -#define MMQ_ITER_K 256 -#define MMQ_NWARPS 8 - -typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); -typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); -typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, - float * __restrict__ dst, const int stride, const int i_max, const int j_max); - -enum mmq_q8_1_ds_layout { - MMQ_Q8_1_DS_LAYOUT_D4, - MMQ_Q8_1_DS_LAYOUT_DS4, - MMQ_Q8_1_DS_LAYOUT_D2S6, -}; - -struct block_q8_1_mmq { - // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. - // The y float data is first grouped as blocks of 128 values. - // These blocks are then treated as individual data values and transposed. - // - // To avoid shared memory bank conflicts each block is padded with 16 bytes. - // This padding is also used to store block scales/partial sums. - // The scales multiplied with the quantized data are equal to the unquantized values. - // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) - // and are only needed for performance reasons. - // - // The exact data stored depends on the x data type. 
- union { - float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 - half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 - half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, - // stored as d0,d1,s1,s2,s3,s4,s5 - }; - int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each -}; -static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); -static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); - -static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { - switch (type_x) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return MMQ_Q8_1_DS_LAYOUT_DS4; - case GGML_TYPE_Q5_0: - return MMQ_Q8_1_DS_LAYOUT_D4; - case GGML_TYPE_Q5_1: - return MMQ_Q8_1_DS_LAYOUT_DS4; - case GGML_TYPE_Q8_0: - return MMQ_Q8_1_DS_LAYOUT_D4; - case GGML_TYPE_MXFP4: - return MMQ_Q8_1_DS_LAYOUT_D4; - case GGML_TYPE_Q2_K: - return MMQ_Q8_1_DS_LAYOUT_D2S6; - case GGML_TYPE_Q3_K: - return MMQ_Q8_1_DS_LAYOUT_D4; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return MMQ_Q8_1_DS_LAYOUT_DS4; - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - return MMQ_Q8_1_DS_LAYOUT_D4; - case GGML_TYPE_IQ1_S: - return MMQ_Q8_1_DS_LAYOUT_DS4; - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - return MMQ_Q8_1_DS_LAYOUT_D4; - default: - GGML_ABORT("fatal error"); - break; - } -} - -struct tile_x_sizes { - int qs; - int dm; - int sc; -}; - -#define GGML_CUDA_CC_PASCAL 600 -#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define GGML_CUDA_CC_VOLTA 700 -#define GGML_CUDA_CC_TURING 750 -#define GGML_CUDA_CC_AMPERE 800 -#define GGML_CUDA_CC_ADA_LOVELACE 890 -#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 -#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 - -// AMD -// GCN/CDNA, wave size is 64 -#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 -#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue -#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing -#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 - -// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 -#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 -#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a -#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA -#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 - -#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) -#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) -#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) -#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) -#define GGML_CUDA_CC_IS_CDNA(cc) (cc 
>= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) - -#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 -#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 -#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD - -#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) -#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) -#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) -#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) - -#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) -#define GGML_CUDA_ASSUME(x) __builtin_assume(x) -#else -#define GGML_CUDA_ASSUME(x) -#endif // CUDART_VERSION >= 11010 - -#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) -#define GGML_USE_VMM -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) - -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) -#define FP16_MMA_AVAILABLE -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) - -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) - -#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) -#define AMD_MFMA_AVAILABLE -#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) - -#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#define TURING_MMA_AVAILABLE -#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - -#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE -#define AMPERE_MMA_AVAILABLE -#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - -#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE -#define CP_ASYNC_AVAILABLE -#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - -#ifdef __CUDACC__ -template -__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {} -#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__) -#else -#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0) -#endif // __CUDACC__ - -static bool amd_mfma_available(const int cc) { -#if !defined(GGML_HIP_NO_MMQ_MFMA) - return GGML_CUDA_CC_IS_CDNA(cc); -#else - return false; -#endif //!defined(GGML_HIP_NO_MMQ_MFMA) -} -static bool turing_mma_available(const int cc) { - return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING; -} - -static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : - GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_VOLTA ? 
-#ifdef GGML_CUDA_FORCE_MMQ - 128 : 64; -#else - MMQ_DP4A_MAX_BATCH_SIZE : 64; -#endif // GGML_CUDA_FORCE_MMQ -} -static constexpr __device__ int ggml_cuda_get_physical_warp_size() { -#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) - return 64; -#else - return 32; -#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) -} -static constexpr int ggml_cuda_get_physical_warp_size_host() { -#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) - return 64; -#else - return 32; -#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) -} -static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { -#if CUDART_VERSION >= 12080 - const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); - return (float) e; -#else - uint32_t bits; - if (x == 0) { - bits = 0x00400000; - } else { - bits = (uint32_t) x << 23; - } - - float result; - memcpy(&result, &bits, sizeof(float)); - return result; -#endif // CUDART_VERSION >= 12050 -} -template - -static __device__ __forceinline__ int warp_reduce_any(int x) { - if (width == ggml_cuda_get_physical_warp_size()) { - return __any_sync(0xffffffff, x); - } else { -#pragma unroll - for (int offset = width/2; offset > 0; offset >>= 1) { - x = __shfl_xor_sync(0xffffffff, x, offset, width) || x; - } - return x; - } -} -template -static __device__ __forceinline__ int warp_reduce_sum(int x) { -#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - return __reduce_add_sync(0xffffffff, x); -#else -#pragma unroll - for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, offset, width); - } - return x; -#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE -} -template -static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { -#pragma unroll - for (int offset = width/2; offset > 0; offset >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); - a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); - } - return a; -} -template -static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#ifdef FP16_AVAILABLE -#pragma unroll - for (int offset = width/2; offset > 0; offset >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); - } - return a; - -#else - NO_DEVICE_CODE; - return a; -#endif // FP16_AVAILABLE -} - -static bool fp16_mma_hardware_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) || - (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); -} - - -static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - -#if defined(GGML_USE_HIP) - return 64; -#else // defined(GGML_USE_HIP) - -#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#ifdef GGML_CUDA_FORCE_MMQ - return 128; -#else // GGML_CUDA_FORCE_MMQ - return MMQ_DP4A_MAX_BATCH_SIZE; -#endif // GGML_CUDA_FORCE_MMQ -#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - return 64; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - -#endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) -} - -static int get_mmq_y_host(const int cc) { - return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : - ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 
128 : 64); -} - -static constexpr __device__ int get_mmq_y_device() { -#if defined(GGML_USE_HIP) -#if defined(RDNA1) - return 64; -#else - return 128; -#endif // defined RDNA1 -#else -#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - return 128; -#else - return 64; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#endif // defined(GGML_USE_HIP) -} - -// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. -// The K dimension of the tiles has either, -// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), -// 32 bit elements for the quantized data (does not include scales). -// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. -// The final tile size in K direction is padded to avoid shared memory bank conflicts, -// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. -#define MMQ_TILE_NE_K 32 - -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} -#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} -#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} - -static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { - switch (type) { - case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; - case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; - case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; - case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; - case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; - case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; - case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; - case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; - case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; - case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; - case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; - case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; - case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; - default: return tile_x_sizes{0, 0, 0}; - } -} - -#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) -#define MMQ_MMA_TILE_X_K_Q3_K 
(2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) - -static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); - -static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; - case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; - case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; - case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; - case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; - case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; - case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; - case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; - case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; - case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; - case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; - case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; - default: return 0; - } -} - -// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) -#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) - -static int mmq_get_granularity_host(const int mmq_x, const int cc) { - if (amd_mfma_available(cc)) { - return mmq_x >= 128 ? 32 : 16; - } else if (turing_mma_available(cc) && mmq_x >= 48) { - return 16; - } else { - return 8; - } -} - -#if defined(AMD_MFMA_AVAILABLE) -static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { - return mmq_x >= 128 ? 32 : 16; -} -#elif defined(TURING_MMA_AVAILABLE) -static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { - return mmq_x >= 48 ? 16 : 8; -} -#else -static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { - return 8; -} -#endif // AMD_MFMA_AVAILABLE - -#if defined(GGML_USE_HIP) -static int mmq_get_nwarps_host(const int cc, const int warp_size) { - return amd_mfma_available(cc) ? 
8 : 256/warp_size; -} -#else -static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { - return 256/warp_size; -} -#endif // (GGML_USE_HIP) - -static constexpr __device__ int mmq_get_nwarps_device() { -#if defined(AMD_MFMA_AVAILABLE) - return 8; -#else - return 256/ggml_cuda_get_physical_warp_size(); -#endif // AMD_MFMA_AVAILABLE -} - -// ------------------------------------------------------------ - -template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI4_0; - const int kqsx = txi % QI4_0; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; - const int qs0 = get_int_b2(bxi->qs, kqsx); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); -#else - x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; -#else - x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template -static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for 
(int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); - - int u[2*VDR_Q4_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; - } - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl - (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, - x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI4_1; - const int kqsx = txi % QI4_1; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; - const int qs0 = get_int_b4(bxi->qs, kqsx); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; -#else - x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; -#else - x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template -static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - 
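As a reference for the nibble handling in load_tiles_q4_0/load_tiles_q4_1 above: the (qs >> 0) & 0x0F0F0F0F and (qs >> 4) & 0x0F0F0F0F masks split one packed 32-bit word into its low-nibble and high-nibble quants. The standalone host-side sketch below (not part of the patch, packing order simplified relative to the real block layout) checks exactly that:

// Illustrative only: how eight 4-bit quants packed into one 32-bit word are
// recovered by the nibble masks used in the tile loaders above.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t q[8] = {1, 5, 9, 14, 0, 7, 12, 3}; // toy 4-bit quants
    uint32_t packed = 0;
    for (int b = 0; b < 4; ++b) {
        // byte b carries one quant in its low nibble and one in its high nibble
        packed |= uint32_t(q[b] | (q[b + 4] << 4)) << (8*b);
    }

    const uint32_t lo = (packed >> 0) & 0x0F0F0F0F; // bytes now hold q0..q3
    const uint32_t hi = (packed >> 4) & 0x0F0F0F0F; // bytes now hold q4..q7

    for (int b = 0; b < 4; ++b) {
        assert(((lo >> (8*b)) & 0xFF) == q[b]);
        assert(((hi >> (8*b)) & 0xFF) == q[b + 4]);
    }
    // For Q4_0 the kernel additionally re-centers every byte around zero by
    // subtracting 8 per byte (__vsubss4(x, 0x08080808)); Q4_1 keeps the offset
    // in the block's dm minimum instead.
    printf("lo=%08x hi=%08x\n", (unsigned) lo, (unsigned) hi);
    return 0;
}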
const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; - } - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl - (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, - x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI5_0; - const int kqsx = txi % QI5_0; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; - - const int ql = get_int_b2(bxi->qs, kqsx); - const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; -#else - x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI5_1; - const int kqsx = txi % QI5_1; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; - - const int ql = get_int_b4(bxi->qs, kqsx); - const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; -#else - x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp - constexpr int threads_per_row = 32; - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI8_0; - const int kqsx = txi % QI8_0; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; -#else - x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_mxfp4( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI_MXFP4; - const int kqsx = txi % QI_MXFP4; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; - - const int aux_q4 = get_int_b1(bxi->qs, kqsx); - const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); - const int k0 = kbx * (2 * QI_MXFP4) + kqsx; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#else - x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template -static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], - x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); - } - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
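The dp4a variants above all reduce to the same per-block arithmetic: an integer dot product of packed int8 quants, scaled by the product of the two block scales. A minimal standalone host-side reference (not part of the patch; dp4a_ref stands in for the __dp4a intrinsic used on the device):

// Illustrative only: the scalar arithmetic behind one block of the dp4a path.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Reference for one step of __dp4a: adds the four bytewise products of a and b to c.
static int32_t dp4a_ref(int32_t a, int32_t b, int32_t c) {
    const int8_t * pa = (const int8_t *) &a;
    const int8_t * pb = (const int8_t *) &b;
    for (int i = 0; i < 4; ++i) {
        c += int32_t(pa[i]) * int32_t(pb[i]);
    }
    return c;
}

int main() {
    // One toy block of 8 int8 quants per operand with per-block float scales,
    // mirroring d_x * d_y * sum_i(qx[i]*qy[i]) from the q8_0/q8_1 dot products.
    const int8_t qx[8] = {12, -3, 7, 0, -128, 5, 9, -1};
    const int8_t qy[8] = {-2, 4, 1, 30, 3, -7, 2, 8};
    const float  dx = 0.05f, dy = 0.02f;

    int32_t sumi = 0;
    for (int i = 0; i < 8; i += 4) {
        int32_t a, b;
        std::memcpy(&a, qx + i, 4);
        std::memcpy(&b, qy + i, 4);
        sumi = dp4a_ref(a, b, sumi);
    }
    printf("block dot = %f\n", double(dx*dy*sumi));
    return 0;
}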
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - float dB; - const int j = j0 + tile_C::get_j(0); - if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { - dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; - } else { - dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_A::I + tile_C::get_i(l); - const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; - } - } - } - } -#else - typedef tile<16, 8, int> tile_A; - typedef tile< 8, 8, int> tile_B; - typedef tile<16, 8, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - const half2 * y_ds = (const half2 *) y; - - tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; - float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; - - const int i0 = (threadIdx.y/ntx)*rows_per_warp; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); - } - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { - const int k0 = k00 + k01; - - dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { - tile_B B; - float dB[tile_C::ne/2]; - - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int j = j0 + tile_C::get_j(l); - - if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { - dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; - } else { - dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n][k01/QI8_0], B); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; - } - } - } - } -#endif // defined(AMD_MFMA_AVAILABLE) -} - -template -static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( - const int * __restrict__ x, const int * 
__restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; - const int * y_qs = (const int *) y + 4; - const half2 * y_dm = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_A::I + tile_C::get_i(l); - float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; - sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; - } - } - } - } -#else - typedef tile<16, 8, int> tile_A; - typedef tile< 8, 8, int> tile_B; - typedef tile<16, 8, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
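Unlike the q8_0 path, the *_q8_1 dot products also carry a per-block minimum m on the x side and a partial sum s of the y block; the dmA.y*dsB.y terms in the surrounding mma code are exactly that m*s correction. A small standalone sketch (illustrative only) of the identity being exploited, sum_i (d*q_i + m)*y_i = d*sum_i q_i*y_i + m*s:

// Illustrative only: why storing the y-block partial sum lets the x-block
// minimum be applied without touching the quantized data.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const int   n = 32;                 // one quantization block
    const float d = 0.1f, m = -1.5f;    // x block: scale and minimum (Q4_1/Q5_1 style)
    int   q[32];
    float x[32], y[32], s = 0.0f;
    for (int i = 0; i < n; ++i) {
        q[i] = i % 16;                  // 4-bit style quants
        x[i] = d*q[i] + m;              // dequantized x values
        y[i] = 0.25f*i - 3.0f;          // toy activation block
        s   += y[i];                    // the partial sum stored alongside the y scale
    }

    float ref = 0.0f, folded = 0.0f;
    for (int i = 0; i < n; ++i) {
        ref    += x[i]*y[i];            // exact dot product with dequantized x
        folded += d*q[i]*y[i];          // scale times the integer-style dot product
    }
    folded += m*s;                      // the minimum enters only through s

    assert(std::fabs(ref - folded) < 1e-3f);
    printf("ref=%f folded=%f\n", double(ref), double(folded));
    return 0;
}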
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; - const int * y_qs = (const int *) y + 4; - const half2 * y_dm = (const half2 *) y; - - tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; - float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; - - const int i0 = (threadIdx.y/ntx)*rows_per_warp; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); - } - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { - const int k0 = k00 + k01; - - dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { - tile_B B; - float2 dsB[tile_C::ne/2]; - - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int j = j0 + tile_C::get_j(l); - - dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n][k01/QI8_1], B); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; - sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; - } - } - } - } -#endif // defined(AMD_MFMA_AVAILABLE) -} - -// Used for Q3_K, IQ2_S, and IQ2_XS -template -static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( - &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], - &y_qs[j*MMQ_TILE_Y_K + k01], - &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], - y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -// Used for Q3_K, IQ2_S, and IQ2_XS: -template -static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
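As a quick check of the tile bookkeeping shared by these mma paths: with the Turing-style tile<16, 8, int> accumulator, a warp covers rows_per_warp = 2*granularity rows of the x tile and therefore ntx = rows_per_warp/tile_C::I minitiles. The standalone sketch below (illustrative only, Turing branch constants) tabulates this for a few mmq_x values:

// Illustrative only: granularity -> rows_per_warp -> ntx for the Turing mma branch.
#include <cstdio>
#include <initializer_list>

int main() {
    const int tile_C_I = 16; // rows of one 16x8 int accumulator tile
    for (int mmq_x : {16, 32, 48, 64, 128}) {
        const int granularity   = mmq_x >= 48 ? 16 : 8;    // mmq_get_granularity_device
        const int rows_per_warp = 2*granularity;
        const int ntx           = rows_per_warp/tile_C_I;  // x minitiles per warp
        printf("mmq_x=%3d granularity=%2d rows_per_warp=%2d ntx=%d\n",
               mmq_x, granularity, rows_per_warp, ntx);
    }
    return 0;
}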
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(TURING_MMA_AVAILABLE) - - typedef tile<16, 4, int> tile_A; - typedef tile<16, 8, int> tile_A_8; - typedef tile< 8, 4, int> tile_B; - typedef tile<16, 8, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - - tile_A A[ntx][8]; - float dA[ntx][tile_C::ne/2][8]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { - const int k0 = k00 + k01; - - load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { - tile_B B[2]; - float dB[tile_C::ne/2]; - - // Here load_generic is faster than load_ldmatrix. 
- load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int j = j0 + tile_C::get_j(l); - - dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C[2]; - mma(C[0], A[n][k01/4 + 0], B[0]); - mma(C[1], A[n][k01/4 + 1], B[1]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum), GGML_UNUSED(k00); - NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE -} - -template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); - constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; - - const int x_ql_0 = get_int_b2(bxi->qs, kqsx); - -#pragma unroll - for (int l = 0; l < QR2_K; ++l) { - const int k = (kqsx/8)*32 + l*8 + kqsx % 8; - - const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const int sc_m = bxi->scales[kqsx]; -#ifdef FAST_FP16_AVAILABLE - const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); -#else - const float2 bxi_dmf = __half22float2(bxi->dm); - const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); -#endif // FAST_FP16_AVAILABLE - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; -#else - x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template -static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - float2 y_df[mmq_x/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - y_df[j0/nwarps] = 
__half22float2(y_ds[j*MMQ_TILE_Y_K]); - } - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - constexpr int ns = 2; - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, - &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } - } - } - - // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. - // As a workaround 2 separate loops are used instead. -#pragma unroll - for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - constexpr int ns = 1; - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, - &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? 
__half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(TURING_MMA_AVAILABLE) - - typedef tile<16, 4, int> tile_A; - typedef tile<16, 8, int> tile_A_8; - typedef tile< 8, 4, int> tile_B; - typedef tile<16, 8, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - - tile_A A[ntx][8]; - float dA[ntx][tile_C::ne/2][8]; - float mA[ntx][tile_C::ne/2][8]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { - const int k0 = k00 + k01; - - load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { - const int k0 = k00 + k01; - - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); - - dA[n][l][k01/(QI8_1/2)] = dm.x; - mA[n][l][k01/(QI8_1/2)] = dm.y; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - float2 dB[tile_C::ne/2]; - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int j = j0 + tile_C::get_j(l); - - dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); - } - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { - tile_B B[2]; - - // Here load_generic is faster than load_ldmatrix. - load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); - - tile_C Cm[2]; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm[0], A1, B[0]); - mma(Cm[1], A1, B[1]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd[2]; - - mma(Cd[0], A[n][k01/4 + 0], B[0]); - mma(Cd[1], A[n][k01/4 + 1], B[1]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y); - } - } - } - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { - float2 sB[tile_C::ne/2]; - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int j = j0 + tile_C::get_j(l); - - sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; - } - } - } - } -#else - GGML_UNUSED_VARS(x, y, sum, k00); - NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE -} - -template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); - int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - - const int x_ql_0 = get_int_b2(bxi->qs, kqsx); - const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); - -#pragma unroll - for (int l = 0; l < QR3_K; ++l) { - const int k = (kqsx/8)*32 + l*8 + kqsx % 8; - - const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; - const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; - - const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - } - - constexpr int rows_per_warp = warp_size / 4; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { - int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - - const int ksc = threadIdx.x % 4; - - const int ksc_low = ksc % (QI3_K/8); - const int shift_low = 4 * (ksc / (QI3_K/8)); - const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; - - const int ksc_high = QI3_K/8; - const int shift_high = 2 * ksc; - const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; - - const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - const int8_t * sc8 = (const int8_t *) ≻ - const float d = bxi->d; - -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; - } -#else - x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - -#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { - int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - - x_df[i] = bxi->d; - } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) -} - -template -static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * x_sc = (const int *) x_df + txs.dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( - &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, - x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { - // scale arrangement after the following two lines: - // - ksc == 0: sc0, sc1, sc2, sc3 - // - ksc == 1: sc4, sc5, sc6, sc7 - // - ksc == 2: m0, m1, m2, m3 - // - ksc == 3: m4, m5, m6, m7 - return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits - ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits -} - -template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + txs.qs); - int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const int qs0 = get_int_b4(bxi->qs, txi); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; -#else - x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - constexpr int rows_per_warp = warp_size / 2; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { -#if defined(AMD_MFMA_AVAILABLE) - // Need if on AMD instead of % because warp_size == 64 - // This causes double work and throughput loss (MI300X) - // H100 loses about 100 t/s with 'if' condition over '%' - int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; - if (i < mmq_y) { -#else - int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; - { -#endif // defined(AMD_MFMA_AVAILABLE) - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % 2; - - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); - - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; - - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - - #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); - } - } - } -#else -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { - int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - - x_dm[i] = bxi->dm; - } - constexpr int rows_per_warp = warp_size / 4; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { - int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); - - const int * scales = (const int *) bxi->scales; - - const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); - const int scales8 = unpack_scales_q45_K(scales, ksc); - - x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; - } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) -} - -template -static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - const int * x_sc = (const int *) x_dm + txs.dm; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += 
nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( - &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, - x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); - int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + txs.qs); - int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - const int ky = QR5_K*txi; - - const int ql = get_int_b4(bxi->qs, txi); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; - - const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - constexpr int rows_per_warp = warp_size / 2; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { -#if defined(AMD_MFMA_AVAILABLE) - // Need if on AMD instead of % because warp_size == 64 - // This causes double work and throughput loss (MI300X) - // H100 loses about 100 t/s with 'if' condition over '%' - int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; - if (i < mmq_y) { -#else - int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; - { -#endif // defined(AMD_MFMA_AVAILABLE) - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % 2; - - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); - - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; - - const half2 dm 
= bxi->dm * make_half2(1.0f, -1.0f); - -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); - } - } - } -#else -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { - int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - - x_dm[i] = bxi->dm; - } - - constexpr int rows_per_warp = warp_size / 4; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { - int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - - const int * scales = (const int *) bxi->scales; - - const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); - const int scales8 = unpack_scales_q45_K(scales, ksc); - - x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; - } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) -} - -template -static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - const int * x_sc = (const int *) x_dm + txs.dm; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( - &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, - x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - -template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); - int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); - int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - - const int ql = get_int_b2(bxi->ql, txi); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); - const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; - const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; - - const int kq0 = 2*txi - txi % (QI6_K/2) + 0; - const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { - int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; -#else - x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int rows_per_warp = warp_size / 4; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { - int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); -#else - x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template -static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * x_sc = (const int *) x_df + txs.dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -// #pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); - - sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( - &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, - x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K 
+ k01/QI8_1]); - } - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(TURING_MMA_AVAILABLE) - - typedef tile<16, 4, int> tile_A; - typedef tile< 8, 4, int> tile_B; - typedef tile<16, 8, int> tile_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
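// ---------------------------------------------------------------------------
// Illustration: accumulator bookkeeping for the mma path above. A rough
// consistency check of how sum[] is indexed: each warp owns ntx row-minitiles
// of tile_C::I rows and walks mmq_x in steps of ntx*tile_C::J columns, leaving
// tile_C::ne partial results per thread per C fragment. The concrete numbers
// below (warp size, granularity, mmq_y) are assumptions chosen for the example,
// not values taken from this file.
constexpr int ex_warp_size   = 32;                                   // assumed
constexpr int ex_nwarps      = 8;                                    // MMQ_NWARPS
constexpr int ex_mmq_x       = 64;                                   // one of the instantiated tile widths
constexpr int ex_mmq_y       = 128;                                  // assumed result of get_mmq_y_device()
constexpr int ex_granularity = 16;                                   // assumed result of mmq_get_granularity_device(64)
constexpr int ex_tile_C_I    = 16;
constexpr int ex_tile_C_J    = 8;
constexpr int ex_tile_C_ne   = ex_tile_C_I*ex_tile_C_J/ex_warp_size; // 4 ints per thread
constexpr int ex_rows_per_warp = 2*ex_granularity;                   // 32
constexpr int ex_ntx           = ex_rows_per_warp/ex_tile_C_I;       // 2
// Slots touched via sum[(j0/tile_C::J + n)*tile_C::ne + l]:
constexpr int ex_from_mma  = (ex_mmq_x/(ex_ntx*ex_tile_C_J))*ex_ntx*ex_tile_C_ne; // 4*2*4 = 32
// Slots declared in mul_mat_q_process_tile_id as float sum[mmq_x*mmq_y/(nwarps*warp_size)]:
constexpr int ex_from_decl = ex_mmq_x*ex_mmq_y/(ex_nwarps*ex_warp_size);          // 32
static_assert(ex_from_mma == ex_from_decl, "per-thread accumulator counts agree for this example");
// ---------------------------------------------------------------------------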
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - - tile_A A[ntx][8]; - int scA[ntx][tile_C::ne/2][8]; - float dA[ntx][tile_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); - load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { - const int k0 = k00 + k01; - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; - const int8_t * sc = (const int8_t *) &sc_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][k01/4 + ksc] = sc[ksc]; - } - } - } - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); - - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - float tmp[ntx][tile_C::ne] = {{0.0f}}; - -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { - tile_B B[2]; - float dB[tile_C::ne/2]; - - // Here load_generic is faster than load_ldmatrix. - load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); - load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < tile_C::ne/2; ++l) { - const int j = j0 + tile_C::get_j(l); - - dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C[2]; - mma(C[0], A[n][k01/4 + 0], B[0]); - mma(C[1], A[n][k01/4 + 1], B[1]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; - } - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; - } - } - } -#else - GGML_UNUSED_VARS(x, y, sum, k00); - NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE -} - -template static __device__ __forceinline__ void load_tiles_iq4_nl( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); - constexpr int nrows = warp_size / threads_per_row; - const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; - const int kbx = txi / QI4_NL; - const int kqsx = txi % QI4_NL; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; - - const int aux_q4 = get_int_b2(bxi->qs, kqsx); - const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); - const int k0 = kbx * (2 * QI4_NL) + kqsx; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; - constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); -#else - x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_iq2_xxs( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; - - const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); - const uint8_t * aux8 = (const uint8_t *) &q2; - const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); - -#pragma unroll - for (int l = 0; l < QR2_XXS; ++l) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; - - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); - - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const int ls = aux32 >> 28; - const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; -#else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_iq2_xs( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; - - const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); - const uint16_t * q2 = (const uint16_t *) &q2_packed; - - #pragma unroll - for (int l = 0; l < QR2_XS; ++l) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); - - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; - 
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const int ls = bxi->scales[kqsx]; - const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#else - x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_iq2_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; - - const int qs_packed = get_int_b2(bxi->qs, kqsx); - const uint8_t * qs = (const uint8_t *) &qs_packed; - - const int qh = bxi->qh[kqsx]; - - const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); - const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; - -#pragma unroll - for (int l = 0; l < QR2_S; ++l) { - const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); - - const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); - const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); - - const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); - const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const int ls = bxi->scales[kqsx]; - const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#else - x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void 
load_tiles_iq3_xxs( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; - - const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); - const uint8_t * q3 = (const uint8_t *) &q3_packed; - const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); - -#pragma unroll - for (int l = 0; l < QR3_XXS; ++l) { - const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); - - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); - - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const int ls = aux32 >> 28; - const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; -#else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_iq3_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; - - const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 
2*kqsx+1)); - const uint8_t * qs = (const uint8_t *) &qs_packed; - - const int qh = bxi->qh[kqsx]; - - const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); - const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; - -#pragma unroll - for (int l = 0; l < QR3_S; ++l) { - const int2 grid_pos = make_int2( - iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], - iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); - - const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); - const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); - - const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); - const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; -#else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_iq1_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); - int * x_qs = (int *) x_tile; - half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { - int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; - - const int qs_packed = get_int_b2(bxi->qs, kqsx); - const uint8_t * qs = (const uint8_t *) &qs_packed; - - const int qh = bxi->qh[kqsx]; - - #pragma unroll - for (int l = 0; l < QR1_S/2; ++l) { - const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; - - const int grid0 = (grid >> 0) & 0x0F0F0F0F; - const int grid1 = (grid >> 4) & 0x0F0F0F0F; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); - const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); 
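// ---------------------------------------------------------------------------
// Illustration: why x_ds packs (d1q, d1q*delta) for IQ1_S. With x_i roughly
// equal to d1q*(q_i + delta) and the q8_1 tile providing both its scale d8 and
// the partial sum s8 of the unquantized y values, the block dot product splits
// into a quant*quant term and a delta*sum(y) term:
//     sum_i x_i*y_i  ~=  d1q*d8*sum_i(q_i*q8_i) + (d1q*delta)*s8
// which is the (scale, offset) combination consumed by the q8_1 x q8_1 dot
// kernels that IQ1_S is wired to in the traits table further down. A scalar
// host-side reference of that reconstruction (names are illustrative only):
#include <cstdint>
static float iq1s_block_dot_ref(const int8_t * qx, float d1q, float delta, // unpacked IQ1_S values + scales
                                const int8_t * q8, float d8,  float s8,    // q8_1 quants, scale, partial sum
                                int n) {
    int sumi = 0;
    for (int i = 0; i < n; ++i) {
        sumi += qx[i]*q8[i]; // the integer part, done via dp4a/mma in the kernel
    }
    return d1q*d8*sumi + d1q*delta*s8; // matches the make_half2(d1q, d1q*delta) stored below
}
// ---------------------------------------------------------------------------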
- -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); -#else - x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template static __device__ __forceinline__ void load_tiles_iq4_xs( - const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); -#else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); - int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); - constexpr int nrows = warp_size / threads_per_row; - const int kqsx = threadIdx.x % threads_per_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { - int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); - - if (need_check) { - i = min(i, i_max); - } - - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; - - const int aux_q4 = get_int_b4(bxi->qs, kqsx); - const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); - const int k0 = 8 * (kqsx / 4) + kqsx % 4; - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; -#else - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; - x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } - - constexpr int rows_per_warp = warp_size / 8; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { - int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; - - const float d = __half2float(bxi->d); - - const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) - | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); -#else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } -} - -template -static __device__ __forceinline__ void mmq_write_back_dp4a_id( - const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, - const int stride, const int i_max, const int j_max) { - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j > j_max) { - return; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } - } -} - -template -static __device__ __forceinline__ void mmq_write_back_mma_id( - const 
float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, - const int stride, const int i_max, const int j_max) { - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int nwarps = mmq_get_nwarps_device(); - -#if defined(AMD_MFMA_AVAILABLE) - constexpr int tileC_IJ = mmq_get_granularity_device(0); - typedef tile tile_C; - constexpr int rows_per_warp = granularity; -#else - typedef tile<16, 8, int> tile_C; - constexpr int rows_per_warp = 2 * granularity; -#endif // defined(AMD_MFMA_AVAILABLE) - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); -#else - GGML_UNUSED(nwarps); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l); - - if (j > j_max) { - continue; - } - - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - - if (need_check && i > i_max) { - continue; - } - - dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; - } - } - } -} - -// ------------------------------------------------------------------------------------------------------------------------------------- - -template -struct mmq_type_traits_id; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = 
load_tiles_q2_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -struct mmq_type_traits_id 
{ - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; -}; - -template -static __device__ __forceinline__ void mul_mat_q_process_tile_id( - const char * __restrict__ x, const int offset_x, const int * __restrict__ y, - const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int stride_row_x, const int ncols_y, const int stride_col_dst, - const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int mmq_y = get_mmq_y_device(); - constexpr load_tiles_mmq_t load_tiles = mmq_type_traits_id::load_tiles; - - extern __shared__ int data_mul_mat_q[]; - int * tile_y = data_mul_mat_q + mmq_x; - int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); - -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_mma; - constexpr mmq_write_back_t write_back = mmq_write_back_mma_id; -#else - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_dp4a; - constexpr mmq_write_back_t write_back = mmq_write_back_dp4a_id; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - - constexpr int blocks_per_iter = MMQ_ITER_K / qk; - - float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; - - for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); - - { - const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); -#pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { - int l = l0 + threadIdx.y*warp_size + threadIdx.x; - - tile_y[l] = by0[l]; - } - } - - __syncthreads(); - - vec_dot(tile_x, tile_y, sum, 0); - - __syncthreads(); - - { - const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); -#pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { - int l = l0 + threadIdx.y*warp_size + threadIdx.x; - - tile_y[l] = by0[l]; - } - } - - __syncthreads(); - - vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); - - __syncthreads(); - } - - if (fixup) { - write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); - } else { - write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); - } -} - - -// The mul_mat_q_id kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 - -template -#if defined(GGML_USE_HIP) -#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) - __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) -#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) -#else -#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) -#else - __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#endif // defined(GGML_USE_HIP) -static __global__ void mul_mat_q_id( - const 
char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, - const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, - const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const int ncols_max) { - - // Skip unused template specializations for faster compilation: - if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { - NO_DEVICE_CODE; - return; - } - - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int mmq_y = get_mmq_y_device(); - - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y - - // Initialize the ids for writing back data with just the index. - // For regular matrix multiplications this is never changed. - // For MoE the correct indices are loaded from ids_dst. - extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { - const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - - if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { - break; - } - - ids_dst_shared[j] = j; - } - __syncthreads(); - - // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - { - const int wt = blockIdx.z / nchannels_y; - const int zt = blockIdx.z - wt*nchannels_y; - const int jt = blockIdx.y; - const int it = blockIdx.x; - - // Defaults for regular matrix multiplication: - int col_low = 0; - int col_high = ncols_dst; - int col_diff = ncols_dst; - int offset_y = wt*stride_sample_y + zt*stride_channel_y; - int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; - - if (ids_dst) { - col_low = expert_bounds[zt + 0]; - col_high = expert_bounds[zt + 1]; - col_diff = col_high - col_low; - - offset_y = 0; - offset_dst = 0; - - if (jt*mmq_x >= col_diff) { - return; - } - - // __syncthreads(); // There is no previous tile that could cause a race condition. 
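// ---------------------------------------------------------------------------
// Illustration: how expert_bounds/ids_dst select the MoE work for this tile.
// For expert zt the compacted y columns live in [expert_bounds[zt], expert_bounds[zt+1]);
// tile jt covers mmq_x of them and ids_dst maps each one back to the dst column it
// belongs to (used as dst[ids_dst[j]*stride_col_dst + i] in the write-back). A small
// host-side sketch of the same lookup; the function name is made up for illustration:
#include <cstdint>
#include <cstdio>
static void print_tile_dst_columns(const int32_t * expert_bounds, const int32_t * ids_dst,
                                   int zt, int jt, int mmq_x) {
    const int col_low  = expert_bounds[zt + 0];
    const int col_high = expert_bounds[zt + 1];
    const int col_diff = col_high - col_low;
    if (jt*mmq_x >= col_diff) {
        return; // this tile has no columns for expert zt, same early-out as above
    }
    for (int j = 0; j < mmq_x && jt*mmq_x + j < col_diff; ++j) {
        // same indexing as ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j] below
        printf("expert %d, tile column %d -> dst column %d\n", zt, jt*mmq_x + j, (int) ids_dst[col_low + jt*mmq_x + j]);
    }
}
// ---------------------------------------------------------------------------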
-#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { - const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - - if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { - break; - } - - ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; - } - __syncthreads(); - } - - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); - offset_dst += it*mmq_y; - - const int tile_x_max_i = nrows_x - it*mmq_y - 1; - const int tile_y_max_j = col_diff - jt*mmq_x - 1; - - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; - - constexpr bool fixup = false; - mul_mat_q_process_tile_id - (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); - return; - } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - - const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = MMQ_ITER_K / qk; - - // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; - - // kb0 == k index when doing the matrix multiplication for an output tile. - int kb0_start = kbc % blocks_per_ne00; - int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); - while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; - - // Defaults for regular matrix multiplication: - int col_low = 0; - int col_high = ncols_dst; - int col_diff = ncols_dst; - int offset_y = wt*stride_sample_y + zt*stride_channel_y; - int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; - - if (ids_dst) { - col_low = expert_bounds[zt + 0]; - col_high = expert_bounds[zt + 1]; - col_diff = col_high - col_low; - - offset_y = 0; - offset_dst = 0; - - if (jt*mmq_x >= col_diff) { - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; - - kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); - - continue; - } - - __syncthreads(); -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { - const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - - if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { - break; - } - - ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; - } - __syncthreads(); - } - - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); - offset_dst += it*mmq_y; - - const int tile_x_max_i = nrows_x - it*mmq_y - 1; - const int tile_y_max_j = col_diff - jt*mmq_x - 1; - - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; - - constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. 
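// ---------------------------------------------------------------------------
// Illustration: the stream-k slice -> tile decomposition used above. Each block
// owns a contiguous range [kbc, kbc_stop) of the flattened
// (row tile, sample, channel, column tile, k block) space; tiles it finishes
// inside the loop are written straight to dst (fixup = false), and only the
// final, possibly partial tile after the loop goes through tmp_fixup for
// mul_mat_q_stream_k_fixup_id to add in. A host-side sketch of the same index
// math (struct/function names are made up for illustration):
#include <cstdint>
struct ex_tile_coords { int it, wt, zt, jt, kb0; };
static ex_tile_coords ex_decompose_kbc(int64_t kbc, int nsamples_y, int nchannels_y,
                                       int ntx, int64_t blocks_per_ne00) {
    ex_tile_coords c;
    int64_t tmp = kbc;
    c.it  = (int) (tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00)); // row tile    (it*mmq_y)
    tmp  -= c.it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    c.wt  = (int) (tmp / (nchannels_y*ntx*blocks_per_ne00));            // sample
    tmp  -= c.wt * (nchannels_y*ntx*blocks_per_ne00);
    c.zt  = (int) (tmp / (ntx*blocks_per_ne00));                        // channel/expert
    tmp  -= c.zt * (ntx*blocks_per_ne00);
    c.jt  = (int) (tmp / blocks_per_ne00);                              // column tile (jt*mmq_x)
    c.kb0 = (int) (tmp - c.jt*blocks_per_ne00);                         // == kbc % blocks_per_ne00
    return c;
}
// ---------------------------------------------------------------------------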
- mul_mat_q_process_tile_id - (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); - - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; - - kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); - } - - if (kbc >= kbc_stop) { - return; - } - - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; - - // Defaults for regular matrix multiplication: - int col_low = 0; - int col_high = ncols_dst; - int col_diff = ncols_dst; - int offset_y = wt*stride_sample_y + zt*stride_channel_y; - int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; - - if (ids_dst) { - col_low = expert_bounds[zt + 0]; - col_high = expert_bounds[zt + 1]; - col_diff = col_high - col_low; - - offset_y = 0; - offset_dst = 0; - - if (jt*mmq_x >= col_diff) { - return; - } - - // The memory layout for the fixup buffer is always contiguous, therefore reset ids: - __syncthreads(); -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { - const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - - if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { - break; - } - - ids_dst_shared[j] = j; - } - __syncthreads(); - } - - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); - offset_dst += it*mmq_y; - - const int tile_x_max_i = nrows_x - it*mmq_y - 1; - const int tile_y_max_j = col_diff - jt*mmq_x - 1; - - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; - - constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. - mul_mat_q_process_tile_id - (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); -} - - -template -static __global__ void mul_mat_q_stream_k_fixup_id( - const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, - const int ncols_max) { - constexpr int mmq_y = get_mmq_y_device(); - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int blocks_per_iter = MMQ_ITER_K / qk; - const int64_t blocks_per_ne00 = ncols_x / qk; - - constexpr int nwarps = mmq_get_nwarps_device(); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; - - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; - const int nty = (nrows_x + mmq_y - 1) / mmq_y; - - const int bidx0 = blockIdx.x; - - // kbc == k block continuous, current index in continuous ijk space. 
- int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - - kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; - kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; - - const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; - const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; - if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { - return; - } - - bool any_fixup = false; - - // Iterate over previous blocks and sum up partial sums written to fixup buffer. - // All CUDA blocks that get here must have a previous block that needs a fixup. - int64_t bidx = bidx0 - 1; - int64_t kbc_stop = kbc0; - while(true) { - int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - - if (kbc == kbc_stop) { // Did not have any data. - bidx--; - kbc_stop = kbc; - continue; - } - - any_fixup = true; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; - } - } - - // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { - break; - } - bidx--; - kbc_stop = kbc; - } - - if (!any_fixup) { - return; - } - - int tmp = kbc0; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; - - if (!ids_dst) { - const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; - dst += offset_dst; - - const int i_max = nrows_x - it*mmq_y - 1; - const int j_max = ncols_dst - jt*mmq_x - 1; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j > j_max) { - return; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } - } - return; - } - - __shared__ int ids_dst_shared[mmq_x]; - const int col_low = expert_bounds[zt + 0]; - const int col_high = expert_bounds[zt + 1]; - const int col_diff = col_high - col_low; - - for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { - ids_dst_shared[j] = ids_dst[col_low + j]; - } - __syncthreads(); - - const int offset_dst = it*mmq_y; - dst += offset_dst; - - const int i_max = nrows_x - it*mmq_y - 1; - const int j_max = col_diff - jt*mmq_x - 1; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j > j_max) { - return; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } 
- - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } - } -} - -struct mmq_args_id { - const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; - int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; - int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; - int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; - bool use_stream_k; int64_t ncols_max; -}; - -template -static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { - const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); - const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); - const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); - return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); -} - -template -static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int nsm = ggml_cuda_info().devices[id].nsm; - const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; - const int nwarps = mmq_get_nwarps_host(cc, warp_size); - const int mmq_y = get_mmq_y_host(cc); - - const dim3 block_dims(warp_size, nwarps, 1); - - const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); - - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); - - const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; - const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; - const int ntzw = args.nchannels_y * args.nsamples_y; - const dim3 block_nums_xy_tiling(nty, ntx, ntzw); - - if (args.nchannels_y % args.nchannels_x) { - printf("Oops: args.nchannels_y = %d, args.nchannels_x = %d\n", args.nchannels_y, args.nchannels_x); - } - GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); - GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); - const int channel_ratio = args.nchannels_y / args.nchannels_x; - const int sample_ratio = args.nsamples_y / args.nsamples_x; - - if (!args.use_stream_k) { - if (args.nrows_x % mmq_y == 0) { - constexpr bool need_check = false; - mul_mat_q_id<<>> - (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); - } else { - constexpr bool need_check = true; - mul_mat_q_id<<>> - (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, 
args.stride_sample_dst, - args.ncols_max); - } - return; - } - - const dim3 block_nums_stream_k(nsm, 1, 1); - const bool fixup_needed = ntx*nty*ntzw % nsm != 0; - - ggml_cuda_pool & pool = ctx.pool(id); - ggml_cuda_pool_alloc tmp_fixup(pool); - if (fixup_needed) { - tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); - } - - if (args.nrows_x % mmq_y == 0) { - constexpr bool need_check = false; - mul_mat_q_id<<>> - (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); - - if (!fixup_needed) { - return; - } - - mul_mat_q_stream_k_fixup_id<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); - } else { - constexpr bool need_check = true; - mul_mat_q_id<<>> - (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); - - if (!fixup_needed) { - return; - } - - mul_mat_q_stream_k_fixup_id<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); - } -} - -template -void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; - const int nwarps = mmq_get_nwarps_host(cc, warp_size); - - const int mmq_x_max = get_mmq_x_max_host(cc); - const int mmq_y = get_mmq_y_host(cc); - - int mmq_x_best = 0; - int ntiles_x_best = INT_MAX; - - for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { - const int granularity = mmq_get_granularity_host(mmq_x, cc); - - if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { - continue; - } - - const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; - - if (ntiles_x < ntiles_x_best) { - mmq_x_best = mmq_x; - ntiles_x_best = ntiles_x; - } - } - - switch (mmq_x_best) { - case 8: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 16: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 24: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 32: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 40: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 48: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 56: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 64: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 72: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 80: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 88: - 
launch_mul_mat_q_id(ctx, args, stream); - break; - case 96: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 104: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 112: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 120: - launch_mul_mat_q_id(ctx, args, stream); - break; - case 128: - launch_mul_mat_q_id(ctx, args, stream); - break; - default: - fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); - GGML_ABORT("fatal error"); - break; - } -} - -#define DECL_MMQ_CASE(type) \ - template void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) \ - -DECL_MMQ_CASE(GGML_TYPE_Q4_0); -DECL_MMQ_CASE(GGML_TYPE_Q4_1); -DECL_MMQ_CASE(GGML_TYPE_Q5_0); -DECL_MMQ_CASE(GGML_TYPE_Q5_1); -DECL_MMQ_CASE(GGML_TYPE_Q8_0); -DECL_MMQ_CASE(GGML_TYPE_MXFP4); -DECL_MMQ_CASE(GGML_TYPE_Q2_K); -DECL_MMQ_CASE(GGML_TYPE_Q3_K); -DECL_MMQ_CASE(GGML_TYPE_Q4_K); -DECL_MMQ_CASE(GGML_TYPE_Q5_K); -DECL_MMQ_CASE(GGML_TYPE_Q6_K); -DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); -DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); -DECL_MMQ_CASE(GGML_TYPE_IQ2_S); -DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); -DECL_MMQ_CASE(GGML_TYPE_IQ3_S); -DECL_MMQ_CASE(GGML_TYPE_IQ1_S); -DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); -DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); - -// ------------------------------------------------------------------------------------------------------------------------- - -static bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); +#include +#include +#include // To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. struct mmq_ids_helper_store { @@ -4326,76 +399,3 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * ggml_cuda_mul_mat_q_switch_type_id(ctx, args, stream); } - -bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { -#ifdef GGML_CUDA_FORCE_CUBLAS - return false; -#endif // GGML_CUDA_FORCE_CUBLAS - - bool mmq_supported; - - switch (type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - mmq_supported = true; - break; - default: - mmq_supported = false; - break; - } - - if (!mmq_supported) { - return false; - } - - if (turing_mma_available(cc)) { - return true; - } - - if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { - return false; - } - -#ifdef GGML_CUDA_FORCE_MMQ - return true; -#endif //GGML_CUDA_FORCE_MMQ - - if (GGML_CUDA_CC_IS_NVIDIA(cc)) { - return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; - } - - if (amd_mfma_available(cc)) { - // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT) - // performs better but is currently suffering from a crash on this architecture. 
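The ggml_cuda_should_use_mmq body being removed in this hunk gates MMQ on the GPU generation and on the batch size ne11. A condensed, self-contained sketch of that decision shape follows; it is illustrative only (the real function also checks the quantization type, CDNA-specific thresholds, and the GGML_CUDA_FORCE_* compile flags), and gpu_caps / should_use_mmq_sketch are made-up names, not ggml API.

#include <cstdint>
#include <cstdio>

// Hypothetical capability flags standing in for the cc-based checks in the patch.
struct gpu_caps {
    bool turing_mma;   // NVIDIA int8 tensor cores (Turing or newer)
    bool fp16_mma;     // FP16 tensor cores usable by the dequantize + cuBLAS path
    bool amd_mfma;     // CDNA matrix cores
};

// Condensed decision: prefer MMQ whenever int8 MMA exists; otherwise fall back to dp4a MMQ
// only for small batches, where dequantize + cuBLAS would not win anyway.
static bool should_use_mmq_sketch(const gpu_caps & caps, int64_t ne11) {
    const int64_t max_dp4a_batch = 64; // MMQ_DP4A_MAX_BATCH_SIZE in the patch
    if (caps.turing_mma || caps.amd_mfma) {
        return true;
    }
    return !caps.fp16_mma || ne11 < max_dp4a_batch;
}

int main() {
    const gpu_caps ampere = {true,  true,  false};
    const gpu_caps pascal = {false, false, false};
    printf("ampere, ne11=512 -> %d\n", (int) should_use_mmq_sketch(ampere, 512)); // 1
    printf("pascal, ne11=512 -> %d\n", (int) should_use_mmq_sketch(pascal, 512)); // 1 (no fp16 MMA)
    return 0;
}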
- // TODO: Revisit when hipblaslt is fixed on CDNA3 - if (GGML_CUDA_CC_IS_CDNA3(cc)) { - return true; - } - if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { - return true; - } - if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { - return true; - } - return false; - } - - return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; -} diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh new file mode 100644 index 000000000..a7ae241ee --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -0,0 +1,3932 @@ +#pragma once + +#include "common.cuh" +#include "mma_new.cuh" +#include "vecdotq.cuh" + +#include +#include +#include + +using namespace ggml_cuda_mma; + +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. +#define MMQ_ITER_K 256 +#define MMQ_NWARPS 8 + +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, + float * __restrict__ dst, const int stride, const int i_max, const int j_max); + +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + +struct block_q8_1_mmq { + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. 
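As a concrete reference for the layout described in the comment above, here is a minimal host-side sketch of how 128 consecutive y values could be packed in the simplest (D4) case: four independent groups of 32 values, one float scale each, plus 128 int8 quants. It only illustrates the data layout; it is not the CUDA quantize kernel from the patch, and block_q8_1_mmq_sketch deliberately omits the union and padding details of the struct that follows.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Simplified mirror of the D4 variant of the union that follows: 4 scales + 128 int8 quants.
struct block_q8_1_mmq_sketch {
    float  d4[4];
    int8_t qs[128];
};

// Quantize 128 consecutive y values: one absmax-derived scale per group of 32.
static void quantize_block_d4(const float * y, block_q8_1_mmq_sketch & dst) {
    for (int g = 0; g < 4; ++g) {
        float amax = 0.0f;
        for (int i = 0; i < 32; ++i) {
            amax = fmaxf(amax, fabsf(y[32*g + i]));
        }
        const float d = amax / 127.0f;   // scale such that quant*d recovers the original value
        dst.d4[g] = d;
        for (int i = 0; i < 32; ++i) {
            dst.qs[32*g + i] = (int8_t) roundf(d > 0.0f ? y[32*g + i] / d : 0.0f);
        }
    }
}

int main() {
    float y[128];
    for (int i = 0; i < 128; ++i) {
        y[i] = sinf(0.1f * (float) i);
    }
    block_q8_1_mmq_sketch blk;
    quantize_block_d4(y, blk);
    printf("group 0: d=%g q[0]=%d q[31]=%d\n", blk.d4[0], blk.qs[0], blk.qs[31]);
    return 0;
}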
+ union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each +}; +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); + +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_MXFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ABORT("fatal error"); + break; + } +} + +struct tile_x_sizes { + int qs; + int dm; + int sc; +}; + +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 +#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 + +// AMD +// GCN/CDNA, wave size is 64 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 + +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 + +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc 
>= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) + +#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 +#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 +#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD + +#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) +#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) +#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) + +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#define GGML_CUDA_ASSUME(x) __builtin_assume(x) +#else +#define GGML_CUDA_ASSUME(x) +#endif // CUDART_VERSION >= 11010 + +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define FP16_MMA_AVAILABLE +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) +#define FP16_MMA_AVAILABLE +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) + +#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#define TURING_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define AMPERE_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#ifdef __CUDACC__ +template +__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {} +#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__) +#else +#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0) +#endif // __CUDACC__ + +static bool amd_mfma_available(const int cc) { +#if !defined(GGML_HIP_NO_MMQ_MFMA) + return GGML_CUDA_CC_IS_CDNA(cc); +#else + return false; +#endif //!defined(GGML_HIP_NO_MMQ_MFMA) +} +static bool turing_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING; +} + +static int get_mmq_x_max_host(const int cc) { + return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : + GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_VOLTA ? 
+#ifdef GGML_CUDA_FORCE_MMQ + 128 : 64; +#else + MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ +} +static constexpr __device__ int ggml_cuda_get_physical_warp_size() { +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; +#else + return 32; +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) +} +static constexpr int ggml_cuda_get_physical_warp_size_host() { +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; +#else + return 32; +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) +} +static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { +#if CUDART_VERSION >= 12080 + const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); + return (float) e; +#else + uint32_t bits; + if (x == 0) { + bits = 0x00400000; + } else { + bits = (uint32_t) x << 23; + } + + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +#endif // CUDART_VERSION >= 12050 +} +template + +static __device__ __forceinline__ int warp_reduce_any(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __any_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) || x; + } + return x; + } +} +template +static __device__ __forceinline__ int warp_reduce_sum(int x) { +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + return __reduce_add_sync(0xffffffff, x); +#else +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); + } + return x; +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +} +template +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); + a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); + } + return a; +} +template +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#ifdef FP16_AVAILABLE +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + +static bool fp16_mma_hardware_available(const int cc) { + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); +} + + +static constexpr __device__ int get_mmq_x_max_device() { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + return 128; +#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + +#if defined(GGML_USE_HIP) + return 64; +#else // defined(GGML_USE_HIP) + +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#ifdef GGML_CUDA_FORCE_MMQ + return 128; +#else // GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 64; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + +#endif // defined(GGML_USE_HIP) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +static int get_mmq_y_host(const int cc) { + return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : + ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 
128 : 64); +} + +static constexpr __device__ int get_mmq_y_device() { +#if defined(GGML_USE_HIP) +#if defined(RDNA1) + return 64; +#else + return 128; +#endif // defined RDNA1 +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 128; +#else + return 64; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +} + +// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. +// The K dimension of the tiles has either, +// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), +// 32 bit elements for the quantized data (does not include scales). +// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. +// The final tile size in K direction is padded to avoid shared memory bank conflicts, +// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. +#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} + +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } +} + +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K 
(2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } +} + +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + if (amd_mfma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (turing_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } +} + +#if defined(AMD_MFMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(TURING_MMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { + return 8; +} +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) +static int mmq_get_nwarps_host(const int cc, const int warp_size) { + return amd_mfma_available(cc) ? 
8 : 256/warp_size; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { + return 256/warp_size; +} +#endif // (GGML_USE_HIP) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(AMD_MFMA_AVAILABLE) + return 8; +#else + return 256/ggml_cuda_get_physical_warp_size(); +#endif // AMD_MFMA_AVAILABLE +} + +// ------------------------------------------------------------ + +template static __device__ __forceinline__ void load_tiles_q4_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for 
(int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + 
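The load_tiles_q4_0/q4_1 loaders above unpack four 4-bit quants per 32-bit word with a mask ((qs >> shift) & 0x0F0F0F0F) and, for Q4_0 only, re-center them by subtracting 8 from every byte via __vsubss4. A host-side sketch of the same byte arithmetic, for illustration; the pairing of low/high nibbles with quants i and i+16 follows the Q4_0/Q4_1 block layout, and sub8_per_byte is an illustrative name.

#include <cstdint>
#include <cstdio>

// Plain C++ stand-in for __vsubss4(x, 0x08080808): subtract 8 from every byte. Saturation does
// not matter here because each byte holds a nibble in [0, 15].
static uint32_t sub8_per_byte(uint32_t x) {
    uint32_t r = 0;
    for (int b = 0; b < 4; ++b) {
        const int8_t v = (int8_t) ((x >> (8*b)) & 0xFF) - 8; // Q4_0 stores quants offset by +8
        r |= (uint32_t) (uint8_t) v << (8*b);
    }
    return r;
}

int main() {
    const uint32_t packed = 0xF0A31C75u;              // one 32-bit word of packed 4-bit quants
    const uint32_t lo = (packed >> 0) & 0x0F0F0F0F;   // low nibbles: quants i .. i+3
    const uint32_t hi = (packed >> 4) & 0x0F0F0F0F;   // high nibbles: quants i+16 .. i+19
    printf("lo: %08x -> centered %08x\n", lo, sub8_per_byte(lo));
    printf("hi: %08x -> centered %08x\n", hi, sub8_per_byte(hi));
    return 0;
}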
const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_b4(bxi->qs, kqsx); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_mxfp4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI_MXFP4; + const int kqsx = txi % QI_MXFP4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b1(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + const int k0 = kbx * (2 * QI_MXFP4) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#else + x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
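Independent of the dp4a/MMA plumbing, the quantity each vec_dot_q8_0_q8_1 step accumulates is an int8 dot product rescaled by the two block scales. A scalar reference of one 32-value term is sketched below (dp4a performs the same sum four bytes at a time on the GPU); q8_dot_ref is an illustrative name, not part of the patch.

#include <cstdint>
#include <cstdio>

// One Q8_0 block (scale d_x, 32 int8 quants) dotted with one block of quantized activations.
static float q8_dot_ref(float d_x, const int8_t * qx, float d_y, const int8_t * qy, int n = 32) {
    int32_t sumi = 0;                 // exact integer accumulation, as __dp4a does
    for (int i = 0; i < n; ++i) {
        sumi += (int32_t) qx[i] * (int32_t) qy[i];
    }
    return d_x * d_y * (float) sumi;  // rescale once per block, not per element
}

int main() {
    int8_t qx[32], qy[32];
    for (int i = 0; i < 32; ++i) {
        qx[i] = (int8_t) (i - 16);
        qy[i] = (int8_t) (2*i - 31);
    }
    printf("dot = %f\n", q8_dot_ref(0.02f, qx, 0.01f, qy));
    return 0;
}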
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + tile_B B; + float dB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * 
__restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
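The dmA.y*dsB.y correction terms in these Q8_1 paths exist because Q4_1/Q5_1/Q4_K/Q5_K tiles are affine (value = d*q + m): the constant m factors out of the dot product and only needs the per-block partial sum s = sum(y_i) that the ds4 layout precomputes. A scalar sketch of that identity follows; affine_dot is an illustrative name, not ggml API.

#include <cstdint>
#include <cstdio>

// x_i = d_x*qx_i + m_x (affine block), y_i = d_y*qy_i with partial sum s_y = sum(y_i).
// Then sum_i x_i*y_i = d_x*d_y*sum_i(qx_i*qy_i) + m_x*s_y, which is what the kernel evaluates.
static float affine_dot(float d_x, float m_x, const int8_t * qx,
                        float d_y, const int8_t * qy, int n) {
    int32_t sumi = 0;
    float   s_y  = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumi += (int32_t) qx[i] * (int32_t) qy[i];
        s_y  += d_y * (float) qy[i];  // in the kernel this sum is precomputed during quantization
    }
    return d_x * d_y * (float) sumi + m_x * s_y;
}

int main() {
    int8_t qx[32], qy[32];
    for (int i = 0; i < 32; ++i) {
        qx[i] = (int8_t) (i % 16);
        qy[i] = (int8_t) (i - 15);
    }
    // Naive float dot product for comparison; both results agree up to float rounding.
    float ref = 0.0f;
    for (int i = 0; i < 32; ++i) {
        ref += (0.1f*qx[i] + 0.5f) * (0.02f*qy[i]);
    }
    printf("identity: %f, naive: %f\n", affine_dot(0.1f, 0.5f, qx, 0.02f, qy, 32), ref);
    return 0;
}

Because the identity holds per block, the kernel only pays one extra fused multiply per accumulator instead of touching the block minimum per element.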
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B; + float2 dsB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) +} + +// Used for Q3_K, IQ2_S, and IQ2_XS +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +// Used for Q3_K, IQ2_S, and IQ2_XS: +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
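The _16 variants used for Q3_K, IQ2_S, and IQ2_XS differ from the plain Q8_0 path in that the x tile stores one scale per 16 values instead of one per 32, so the integer dot product is rescaled in 16-wide chunks. A scalar sketch of the quantity being evaluated, under that reading of the tile layout; q8_16_dot_ref is an illustrative name, not part of the patch.

#include <cstdint>
#include <cstdio>

// 32 quants of x with two sub-scales (one per 16 values), one scale for the 32 y quants.
static float q8_16_dot_ref(const float * d_x, const int8_t * qx, float d_y, const int8_t * qy) {
    float acc = 0.0f;
    for (int sub = 0; sub < 2; ++sub) {           // two 16-value sub-blocks
        int32_t sumi = 0;
        for (int i = 0; i < 16; ++i) {
            sumi += (int32_t) qx[16*sub + i] * (int32_t) qy[16*sub + i];
        }
        acc += d_x[sub] * (float) sumi;           // per-sub-block x scale
    }
    return acc * d_y;                             // single y scale for the whole 32 values
}

int main() {
    int8_t qx[32], qy[32];
    const float d_x[2] = {0.05f, 0.07f};
    for (int i = 0; i < 32; ++i) {
        qx[i] = (int8_t) (i - 10);
        qy[i] = (int8_t) (20 - i);
    }
    printf("dot = %f\n", q8_16_dot_ref(d_x, qx, 0.01f, qy));
    return 0;
}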
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. 
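load_tiles_q2_K, a little further down, pulls four 2-bit quants at a time out of a 32-bit word with (x >> 2*l) & 0x03030303 and folds each scale byte into the pair (d*(sc & 0xF), dmin*(sc >> 4)) so that the dot-product loop only needs one fused multiply per sub-block. A host-side sketch of those two bit tricks in isolation; extract_2bit_lane and fold_scale are illustrative names, not ggml API.

#include <cstdint>
#include <cstdio>

// Extract the l-th 2-bit field of each of the four bytes packed in x (l in 0..3).
static uint32_t extract_2bit_lane(uint32_t x, int l) {
    return (x >> (2*l)) & 0x03030303u;
}

// Fold one Q2_K scale byte into effective (scale, min) factors, as the tile loader does.
static void fold_scale(float d, float dmin, uint8_t sc, float * d_eff, float * m_eff) {
    *d_eff = d    * (float) (sc & 0x0F);
    *m_eff = dmin * (float) (sc >> 4);
}

int main() {
    const uint32_t packed = 0xB4E39C27u;
    for (int l = 0; l < 4; ++l) {
        printf("lane %d: %08x\n", l, extract_2bit_lane(packed, l));
    }
    float d_eff, m_eff;
    fold_scale(0.25f, 0.0625f, 0xA3, &d_eff, &m_eff);
    printf("d_eff=%g m_eff=%g\n", d_eff, m_eff);
    return 0;
}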
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum), GGML_UNUSED(k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + +#pragma unroll + for (int l = 0; l < QR2_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int sc_m = bxi->scales[kqsx]; +#ifdef FAST_FP16_AVAILABLE + const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); +#else + const float2 bxi_dmf = __half22float2(bxi->dm); + const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); +#endif // FAST_FP16_AVAILABLE + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; +#else + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + float2 y_df[mmq_x/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + y_df[j0/nwarps] = 
__half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } + + // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. + // As a workaround 2 separate loops are used instead. +#pragma unroll + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? 
__half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + float mA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { + const int k0 = k00 + k01; + + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float2 dB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B[2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + + tile_C Cm[2]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm[0], A1, B[0]); + mma(Cm[1], A1, B[1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd[2]; + + mma(Cd[0], A[n][k01/4 + 0], B[0]); + mma(Cd[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y); + } + } + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { + float2 sB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + } + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); + +#pragma unroll + for (int l = 0; l < QR3_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; + const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; + + const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + const int ksc = threadIdx.x % 4; + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + const int8_t * sc8 = (const int8_t *) ≻ + const float d = bxi->d; + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; + } +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + x_df[i] = bxi->d; + } +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +} + +template +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int qs0 = get_int_b4(bxi->qs, txi); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += 
nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const int ky = QR5_K*txi; + + const int ql = get_int_b4(bxi->qs, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm 
= bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + + const int ql = get_int_b2(bxi->ql, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; + + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K 
+ k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
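+    // (added note) Warps form teams of ntx: threadIdx.y / ntx picks the block of rows_per_warp x rows
+    // a team works on (i0 below), while threadIdx.y % ntx offsets the y pointer so that the warps of
+    // one team sweep different tile_C::J-wide column sub-tiles of the same x rows.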
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + int scA[ntx][tile_C::ne/2][8]; + float dA[ntx][tile_C::ne/2]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { + const int k0 = k00 + k01; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float tmp[ntx][tile_C::ne] = {{0.0f}}; + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; + } + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_iq4_nl( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b2(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#else + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + 
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void 
load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 
2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); 
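+        // (added note) delta is -1 + IQ1S_DELTA when the block sign bit (qh & 0x8000) is clear and
+        // -1 - IQ1S_DELTA when it is set; (d1q, d1q*delta) is stored below in the same (scale, offset)
+        // half2 slot as the q8_1 tiles, so the per-block shift is presumably applied through the
+        // y-side partial sums of the q8_1 dot product (IQ1_S is mapped to vec_dot_q8_1_q8_1 below).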
+ +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq4_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 8; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const float d = __half2float(bxi->d); + + const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) + | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void mmq_write_back_dp4a_id( + const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +template +static __device__ __forceinline__ void mmq_write_back_mma_id( + const 
float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; + constexpr int rows_per_warp = 2 * granularity; +#endif // defined(AMD_MFMA_AVAILABLE) + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l); + + if (j > j_max) { + continue; + } + + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; + } + } + } +} + +// ------------------------------------------------------------------------------------------------------------------------------------- + +template +struct mmq_type_traits_id; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = 
load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id 
{ + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +static __device__ __forceinline__ void mul_mat_q_process_tile_id( + const char * __restrict__ x, const int offset_x, const int * __restrict__ y, + const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits_id::load_tiles; + + extern __shared__ int data_mul_mat_q[]; + int * tile_y = data_mul_mat_q + mmq_x; + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma_id; +#else + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a_id; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { + load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); + + { + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); + + __syncthreads(); + } + + if (fixup) { + write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + } else { + write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); + } +} + + +// The mul_mat_q_id kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 + +template +#if defined(GGML_USE_HIP) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) +#else + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +static __global__ void mul_mat_q_id( + const 
char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ncols_max) { + + // Skip unused template specializations for faster compilation: + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { + NO_DEVICE_CODE; + return; + } + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + + // Initialize the ids for writing back data with just the index. + // For regular matrix multiplications this is never changed. + // For MoE the correct indices are loaded from ids_dst. + extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + + // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: +#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + { + const int wt = blockIdx.z / nchannels_y; + const int zt = blockIdx.z - wt*nchannels_y; + const int jt = blockIdx.y; + const int it = blockIdx.x; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // __syncthreads(); // There is no previous tile that could cause a race condition. 
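+            // For MoE, load the destination row indices of this tile from ids_dst into shared memory
+            // so the write-back can scatter its rows to the expert-sorted positions in dst.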
+#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = false; + mul_mat_q_process_tile_id + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + return; + } +#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + + const int64_t blocks_per_ne00 = ncols_x / qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + + // kb0 == k index when doing the matrix multiplication for an output tile. + int kb0_start = kbc % blocks_per_ne00; + int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + + continue; + } + + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. 
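+        // Process the k-range [kb0_start, kb0_stop) assigned to this stream-k unit for the current output tile.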
+ mul_mat_q_process_tile_id + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); + + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // The memory layout for the fixup buffer is always contiguous, therefore reset ids: + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + mul_mat_q_process_tile_id + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); +} + + +template +static __global__ void mul_mat_q_stream_k_fixup_id( + const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, + const int ncols_max) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; + + const int bidx0 = blockIdx.x; + + // kbc == k block continuous, current index in continuous ijk space. 
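+    // Recompute the same [kbc0, kbc0_stop) range that the main kernel assigned to this CUDA block,
+    // so the fixup can tell whether its first tile was only partially accumulated there.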
+ int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; + const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + bool any_fixup = false; + + // Iterate over previous blocks and sum up partial sums written to fixup buffer. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int64_t bidx = bidx0 - 1; + int64_t kbc_stop = kbc0; + while(true) { + int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + any_fixup = true; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + } + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + break; + } + bidx--; + kbc_stop = kbc; + } + + if (!any_fixup) { + return; + } + + int tmp = kbc0; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + if (!ids_dst) { + const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = ncols_dst - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } + return; + } + + __shared__ int ids_dst_shared[mmq_x]; + const int col_low = expert_bounds[zt + 0]; + const int col_high = expert_bounds[zt + 1]; + const int col_diff = col_high - col_low; + + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { + ids_dst_shared[j] = ids_dst[col_low + j]; + } + __syncthreads(); + + const int offset_dst = it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = col_diff - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } 
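+            // Add the collected partial sums on top of the value already written to the expert-mapped destination row.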
+ + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +struct mmq_args_id { + const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; + int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; + int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; + int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; + bool use_stream_k; int64_t ncols_max; +}; + +template +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const size_t nbs_ids = mmq_x*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); +} + +template +static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + const int mmq_y = get_mmq_y_host(cc); + + const dim3 block_dims(warp_size, nwarps, 1); + + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); + + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); + + const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; + const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; + const int ntzw = args.nchannels_y * args.nsamples_y; + const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + + if (args.nchannels_y % args.nchannels_x) { + printf("Oops: args.nchannels_y = %d, args.nchannels_x = %d\n", args.nchannels_y, args.nchannels_x); + } + GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); + GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); + const int channel_ratio = args.nchannels_y / args.nchannels_x; + const int sample_ratio = args.nsamples_y / args.nsamples_x; + + if (!args.use_stream_k) { + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, 
args.stride_sample_dst, + args.ncols_max); + } + return; + } + + const dim3 block_nums_stream_k(nsm, 1, 1); + const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + + ggml_cuda_pool & pool = ctx.pool(id); + ggml_cuda_pool_alloc tmp_fixup(pool); + if (fixup_needed) { + tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); + } + + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup_id<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup_id<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } +} + +template +void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + + const int mmq_x_max = get_mmq_x_max_host(cc); + const int mmq_y = get_mmq_y_host(cc); + + int mmq_x_best = 0; + int ntiles_x_best = INT_MAX; + + for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); + + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { + continue; + } + + const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; + + if (ntiles_x < ntiles_x_best) { + mmq_x_best = mmq_x; + ntiles_x_best = ntiles_x; + } + } + + switch (mmq_x_best) { + case 8: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 16: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 24: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 32: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 40: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 48: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 56: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 64: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 72: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 80: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 88: + 
launch_mul_mat_q_id(ctx, args, stream); + break; + case 96: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 104: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 112: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 120: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 128: + launch_mul_mat_q_id(ctx, args, stream); + break; + default: + fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); + GGML_ABORT("fatal error"); + break; + } +} + +#define DECL_MMQ_CASE(type) \ + template void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) \ + +extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); + +// ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/mmq_id_kernels.cu b/ggml/src/ggml-cuda/mmq_id_kernels.cu new file mode 100644 index 000000000..722320adf --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id_kernels.cu @@ -0,0 +1,22 @@ +#include "mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_0); +DECL_MMQ_CASE(GGML_TYPE_Q4_1); +DECL_MMQ_CASE(GGML_TYPE_Q5_0); +DECL_MMQ_CASE(GGML_TYPE_Q5_1); +DECL_MMQ_CASE(GGML_TYPE_Q8_0); +DECL_MMQ_CASE(GGML_TYPE_MXFP4); +DECL_MMQ_CASE(GGML_TYPE_Q2_K); +DECL_MMQ_CASE(GGML_TYPE_Q3_K); +DECL_MMQ_CASE(GGML_TYPE_Q4_K); +DECL_MMQ_CASE(GGML_TYPE_Q5_K); +DECL_MMQ_CASE(GGML_TYPE_Q6_K); +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); + From 26c0dfdbfa7d9be339288d285d8214cf2622928e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 24 Aug 2025 20:03:55 +0300 Subject: [PATCH 05/23] WIP --- ggml/src/ggml-cuda/mmq_id_common.cuh | 3 --- ggml/src/ggml-cuda/mmq_id_kernels.cu | 22 ------------------- .../mmq-instance-iq1_s_id.cu | 5 +++++ .../mmq-instance-iq2_s_id.cu | 5 +++++ .../mmq-instance-iq2_xs_id.cu | 5 +++++ .../mmq-instance-iq2_xxs_id.cu | 5 +++++ .../mmq-instance-iq3_s_id.cu | 5 +++++ .../mmq-instance-iq3_xxs_id.cu | 5 +++++ .../mmq-instance-iq4_nl_id.cu | 4 ++++ .../mmq-instance-iq4_xs_id.cu | 5 +++++ .../mmq-instance-q2_k_id.cu | 5 +++++ .../mmq-instance-q3_k_id.cu | 5 +++++ .../mmq-instance-q4_0_id.cu | 5 +++++ .../mmq-instance-q4_1_id.cu | 5 +++++ .../mmq-instance-q4_k_id.cu | 5 +++++ .../mmq-instance-q5_0_id.cu | 5 +++++ .../mmq-instance-q5_1_id.cu | 5 +++++ .../mmq-instance-q5_k_id.cu | 5 +++++ .../mmq-instance-q6_k_id.cu | 5 +++++ .../mmq-instance-q8_0_id.cu | 5 +++++ 20 files changed, 89 insertions(+), 25 deletions(-) delete mode 100644 ggml/src/ggml-cuda/mmq_id_kernels.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu create mode 100644 
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index a7ae241ee..931d144b3 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -3746,9 +3746,6 @@ static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_ const int ntzw = args.nchannels_y * args.nsamples_y; const dim3 block_nums_xy_tiling(nty, ntx, ntzw); - if (args.nchannels_y % args.nchannels_x) { - printf("Oops: args.nchannels_y = %d, args.nchannels_x = %d\n", args.nchannels_y, args.nchannels_x); - } GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); const int channel_ratio = args.nchannels_y / args.nchannels_x; diff --git a/ggml/src/ggml-cuda/mmq_id_kernels.cu b/ggml/src/ggml-cuda/mmq_id_kernels.cu deleted file mode 100644 index 722320adf..000000000 --- a/ggml/src/ggml-cuda/mmq_id_kernels.cu +++ /dev/null @@ -1,22 +0,0 @@ -#include "mmq_id_common.cuh" - -DECL_MMQ_CASE(GGML_TYPE_Q4_0); -DECL_MMQ_CASE(GGML_TYPE_Q4_1); -DECL_MMQ_CASE(GGML_TYPE_Q5_0); -DECL_MMQ_CASE(GGML_TYPE_Q5_1); -DECL_MMQ_CASE(GGML_TYPE_Q8_0); -DECL_MMQ_CASE(GGML_TYPE_MXFP4); -DECL_MMQ_CASE(GGML_TYPE_Q2_K); -DECL_MMQ_CASE(GGML_TYPE_Q3_K); -DECL_MMQ_CASE(GGML_TYPE_Q4_K); -DECL_MMQ_CASE(GGML_TYPE_Q5_K); -DECL_MMQ_CASE(GGML_TYPE_Q6_K); -DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); -DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); -DECL_MMQ_CASE(GGML_TYPE_IQ2_S); -DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); -DECL_MMQ_CASE(GGML_TYPE_IQ3_S); -DECL_MMQ_CASE(GGML_TYPE_IQ1_S); -DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); -DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); - diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu new file mode 100644 index 000000000..9c04a0205 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu new file mode 100644 index 000000000..6f54f24e6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu new file mode 100644 index 000000000..529586678 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu new file mode 100644 index 000000000..c735ab977 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu new file mode 100644 index 000000000..5c501ed36 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu new file mode 100644 index 000000000..387e6d274 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu new file mode 100644 index 000000000..6ef74f1d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu @@ -0,0 +1,4 @@ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +DECL_MMQ_CASE(GGML_TYPE_MXFP4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu new file mode 100644 index 000000000..988f5da1a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu new file mode 100644 index 000000000..14d78d839 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q2_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu new file mode 100644 index 000000000..4262f49aa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q3_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu new file mode 100644 index 000000000..4e747a8ca --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu new file mode 100644 index 000000000..a29a943ed --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu new file mode 100644 index 000000000..c16f76138 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu new file mode 100644 index 000000000..3afa037ad --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu new file mode 100644 index 000000000..1c161297f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu new file mode 100644 index 000000000..36c7a9fae --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu new file mode 100644 index 000000000..2a02f965a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q6_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu new file mode 100644 index 000000000..3b126bed9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q8_0); From c4f2af9ccc42ee77fd9ae1ce29c8e0d5d69a75b3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 06:52:33 +0300 Subject: [PATCH 06/23] WIP --- ggml/src/ggml-cuda.cu | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index c234ec07b..b0d5620f9 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2718,13 +2718,34 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dst_row.data = dst_up_contiguous.get(); ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + if (dst->src[4]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } dst_row.data = dst_gate_contiguous.get(); ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + if (dst->src[5]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst->data, ggml_nelements(dst), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + } + CUDA_CHECK(cudaGetLastError()); + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { From e9afb0b8fc8acba2b6f6c6ee1b2bdb323546e144 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 07:46:20 +0300 Subject: [PATCH 07/23] This works for mainline supported quants --- ggml/src/ggml-cuda.cu | 31 +++++++--------- ggml/src/ggml-cuda/mmq_id.cu | 70 +++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmq_id.cuh | 1 + 3 files changed, 85 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b0d5620f9..e0eea7704 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2394,13 +2394,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } - //printf("src0(%s): %ld x %ld x %ld, src1: %ld x %ld x %ld 
dst: ids: %ld x %ld x %ld, %ld x %ld x %ld\n", - // src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], - // ids->ne[0], ids->ne[1], ids->ne[2], dst->ne[0], dst->ne[1], dst->ne[2]); - - ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); - return false; - + if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); + return false; + } GGML_TENSOR_BINARY_OP_LOCALS @@ -2679,19 +2676,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor cudaStream_t stream = ctx.stream(); - ggml_tensor src0_1_row = *src0_1; - ggml_tensor src0_2_row = *src0_2; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; - ggml_tensor final_dst; - ggml_tensor final_src; - const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; - if (src1->ne[2] <= 2048 && + ggml_tensor dst_row = *dst; + + if (src1->ne[2] <= 2048 && // TODO: this depends on number of total vs number of active experts -> need to find optimum threshod ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 && - ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { const int64_t ne_get_rows = ne12 * n_ids; ggml_cuda_pool_alloc ids_device(ctx.pool(), ne_get_rows + ne_get_rows + n_as + 1); @@ -2746,7 +2738,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } CUDA_CHECK(cudaGetLastError()); - if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { //ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr); @@ -2762,6 +2753,12 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); + ggml_tensor src0_1_row = *src0_1; + ggml_tensor src0_2_row = *src0_2; + ggml_tensor src1_row = *src1; + ggml_tensor final_dst; + ggml_tensor final_src; + char * src0_1_original = (char *) src0_1->data; char * src0_2_original = (char *) src0_2->data; char * src1_original = (char *) src1->data; diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 17e6798bf..8689394ec 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -399,3 +399,73 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * ggml_cuda_mul_mat_q_switch_type_id(ctx, args, stream); } + +bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { + bool mmq_supported; + + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + mmq_supported = true; + break; + default: + mmq_supported = false; + break; + } + + if 
(!mmq_supported) { + return false; + } + + if (turing_mma_available(cc)) { + return true; + } + + if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + if (amd_mfma_available(cc)) { + // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT) + // performs better but is currently suffering from a crash on this architecture. + // TODO: Revisit when hipblaslt is fixed on CDNA3 + if (GGML_CUDA_CC_IS_CDNA3(cc)) { + return true; + } + if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + return true; + } + if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { + return true; + } + return false; + } + + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + +} diff --git a/ggml/src/ggml-cuda/mmq_id.cuh b/ggml/src/ggml-cuda/mmq_id.cuh index c85c468fe..567393074 100644 --- a/ggml/src/ggml-cuda/mmq_id.cuh +++ b/ggml/src/ggml-cuda/mmq_id.cuh @@ -9,3 +9,4 @@ void ggml_cuda_mul_mat_q_id( void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream); +bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11); From 16e477a9451ca7f503a4ecf8f6d638e19dfcf1b6 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 08:52:32 +0300 Subject: [PATCH 08/23] mmq_id: add iq2_k, iq2_k_r4 --- ggml/src/ggml-cuda/mmq_id.cu | 8 + ggml/src/ggml-cuda/mmq_id_common.cuh | 59 ++++-- .../mmq-instance-iq2_k_id.cu | 200 ++++++++++++++++++ 3 files changed, 248 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 8689394ec..971adf0c5 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -198,6 +198,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ4_NL: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ2_K: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ2_K_R4: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -423,6 +429,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 931d144b3..714010f54 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -3,6 +3,7 @@ #include "common.cuh" #include "mma_new.cuh" #include "vecdotq.cuh" +#include "iqk_cuda_common.h" #include #include @@ -79,6 +80,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -368,6 +371,9 @@ static constexpr 
__host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + // ================= ik_llama.cpp quants + case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; } } @@ -405,6 +411,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + // ================= ik_llama.cpp quants + case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; } } @@ -3071,6 +3080,15 @@ static __device__ __forceinline__ void mmq_write_back_mma_id( } } +// ===================================== ik_llama.cpp quants ============================================================= + +// +// Strictly speaking, we should bite the bullet and change WARP_SIZE to warp_size or MMQ_TILE_NE_K. +// But as we basically don't support anything but Nvidia in the CUDA backend, we alays have +// WARP_SIZE = MMQ_TILE_NE_K = 32 +// + + // ------------------------------------------------------------------------------------------------------------------------------------- template @@ -3078,7 +3096,7 @@ struct mmq_type_traits_id; template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; @@ -3086,7 +3104,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; @@ -3094,7 +3112,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3102,7 +3120,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; @@ -3110,7 +3128,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3118,7 +3136,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; + //static constexpr int vdr = 
VDR_MXFP4_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3126,7 +3144,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; @@ -3134,7 +3152,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; @@ -3142,7 +3160,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; @@ -3150,7 +3168,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; @@ -3158,7 +3176,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + //static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; @@ -3166,7 +3184,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3174,7 +3192,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3182,7 +3200,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3190,7 +3208,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; static 
constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3198,7 +3216,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3206,7 +3224,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; @@ -3214,7 +3232,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3222,7 +3240,7 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + //static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3925,5 +3943,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); +// =================== ik_llama.cpp quants +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu new file mode 100644 index 000000000..fef36049f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu @@ -0,0 +1,200 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
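// Editorial sketch, not part of the patch: a scalar reference for what load_tiles_iq2_k
// below packs into the shared-memory tile. Each IQ2_K weight is a 2-bit index plus a
// per-sub-block "extra" bit that selects the upper half of the 8-entry non-linear table
// iq2nl_values; sub-block scales are 4-bit values biased by 8 and combined with the
// per-block scale d. The helper name and the plain int8 table are assumptions made for
// illustration only.
static inline float dequant_iq2_k_ref(float d_block, int scale4, int q2bits, int extra_bit,
                                      const signed char * values /* 8 non-linear levels */) {
    const int idx = (q2bits & 3) | (extra_bit << 2); // 3-bit index: 2 data bits + 1 extra bit
    return d_block * (scale4 - 8) * values[idx];     // block scale * sub-block scale * level
}
// The kernel below does the same four values at a time via get_int_from_table_8 /
// int_from_table_4 and stores the resulting int8 values in the Q3_K-layout tile consumed
// by the vec_dot_q8_0_16_q8_1 MMA/dp4a kernels.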
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + //constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + #pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_k * bxi = (const block_iq2_k *)x + i*stride + kbx0; + + const float d = bxi->d; + uint16_t extra = bxi->extra >> (kqsx/4); + +#ifdef __CUDA_ARCH__ + + uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 }; + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 2) & 0x44444444); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 0) & 0x44444444); + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else + + auto all_values = (const int *)iq2k_table; + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2)); +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2)); +#endif // INT8_MMA_AVAILABLE + + extra >>= 8; + } +#endif // __CUDA_ARCH__ + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * (((bxi->scales[kqsx] >> 4) & 0xf) - 8); 
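    // Each scales[] byte packs two 4-bit sub-block scales (low/high nibble); subtracting 8
    // recenters them to a signed factor that is folded into the per-block scale d here, so the
    // dp4a/MMA dot products only need one float per sub-block of the tile.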
+#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * (((bxi->scales[kqsx] >> 4) & 0xf) - 8); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq2_k_r4 * bxi = (const block_iq2_k_r4 *)x + 4*i4*stride + kbx0; + + const float d = __half2float(bxi->d[ir]); + +#ifdef __CUDA_ARCH__ + #pragma unroll + for (int l = 0; l < 2; ++l) { + + uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404; + extra = extra | (extra << 4); + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | extra; + uint32_t val2 = ((ql >> 2) & 0x33333333) | extra; + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8); + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); +#endif // INT8_MMA_AVAILABLE + } +#endif // __CUDA_ARCH__ + + int is = 8*kqsx + ir; + float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8); + is += 4; + float dl2 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + 
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_K); +DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); From 9031898cfdc3521e6536461f4ce9549b11d7c478 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 13:42:09 +0300 Subject: [PATCH 09/23] mmiq_id: don't assume row size is multiple of type size (per row scales) --- ggml/src/ggml-cuda/mmq_id.cu | 10 +++--- ggml/src/ggml-cuda/mmq_id_common.cuh | 35 ++++++++++--------- .../mmq-instance-iq2_k_id.cu | 4 +-- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 971adf0c5..e95ee04c9 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -261,11 +261,11 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * cudaStream_t stream = ctx.stream(); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const size_t ts_src0 = ggml_type_size(src0->type); + //const size_t ts_src0 = ggml_type_size(src0->type); const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT( nb00 == ts_src0); + //GGML_ASSERT( nb00 == ts_src0); GGML_ASSERT( nb10 == ts_src1); GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(ids_tensor->nb[0] == ggml_type_size(ids_tensor->type)); @@ -291,11 +291,11 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); - const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s01 = src0->nb[1];// / ts_src0; const int64_t s1 = dst->nb[1] / ts_dst; - const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s02 = src0->nb[2];// / ts_src0; const int64_t s2 = dst->nb[2] / ts_dst; - const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s03 = src0->nb[3];// / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 714010f54..52e2dc870 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -493,7 +493,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -516,7 +516,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + 
kbxd] = bxi->d; @@ -854,7 +854,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbx; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); @@ -877,7 +877,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -2243,7 +2243,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0; const int ql = get_int_b2(bxi->ql, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; @@ -2273,7 +2273,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; @@ -2525,7 +2525,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbx; const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); @@ -2552,7 +2552,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); @@ -3248,7 +3248,7 @@ struct mmq_type_traits_id { template static __device__ __forceinline__ void mul_mat_q_process_tile_id( - const char * __restrict__ x, const int offset_x, const int * __restrict__ y, + const char * __restrict__ x, const int * __restrict__ y, const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int stride_row_x, const int ncols_y, const int stride_col_dst, const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { @@ -3276,7 +3276,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile_id( float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); + load_tiles(x, tile_x, kb0, tile_x_max_i, stride_row_x); { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); @@ -3419,11 +3419,12 @@ static __global__ void mul_mat_q_id( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x) + + (zt/channel_ratio)*int64_t(stride_channel_x) + 
it*mmq_y*int64_t(stride_row_x); constexpr bool fixup = false; mul_mat_q_process_tile_id - (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + (x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } @@ -3497,11 +3498,12 @@ static __global__ void mul_mat_q_id( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x) + + (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x); constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile_id - (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + (x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); kbc += blocks_per_ne00; @@ -3564,11 +3566,12 @@ static __global__ void mul_mat_q_id( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x) + + (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x); constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. 
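// Note on patches 09/10 (editorial sketch, not part of the diff): stride_row_x, stride_channel_x
// and stride_sample_x are now byte strides, because quants with a per-row scale make the row size
// not a whole number of blocks. Block pointers therefore add the byte offset first and cast
// afterwards, e.g.
//     old: (const block_q4_0 *) x + kbx0 + i*stride + kbx      // stride counted in blocks
//     new: (const block_q4_0 *)(x + i*stride) + kbx0 + kbx     // stride counted in bytes
// and offset_x is widened to int64_t, presumably so large byte offsets cannot overflow a 32-bit
// int. The row layout used by the iq2_ks/iq2_kl/iq3_ks instances added later in this series is
// one fp16 scale followed by the block data, which is exactly the case these changes accommodate.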
mul_mat_q_process_tile_id - (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + (x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu index fef36049f..3c3fd4c9d 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu @@ -27,7 +27,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq2_k * bxi = (const block_iq2_k *)x + i*stride + kbx0; + const block_iq2_k * bxi = (const block_iq2_k *)(x + i*stride) + kbx0; const float d = bxi->d; uint16_t extra = bxi->extra >> (kqsx/4); @@ -115,7 +115,7 @@ template static __device__ __forceinline__ void loa int i4 = i/4; int ir = i%4; - const block_iq2_k_r4 * bxi = (const block_iq2_k_r4 *)x + 4*i4*stride + kbx0; + const block_iq2_k_r4 * bxi = (const block_iq2_k_r4 *)(x + 4*i4*stride) + kbx0; const float d = __half2float(bxi->d[ir]); From d9114301c04256cf1f2a375d996f5e0a0eb92330 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 13:53:51 +0300 Subject: [PATCH 10/23] mmiq_id: don't assume row size is multiple of type size --- ggml/src/ggml-cuda/mmq_id_common.cuh | 58 ++++++++++++++-------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 52e2dc870..3a6b668c5 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -596,7 +596,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -619,7 +619,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; @@ -699,7 +699,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b2(bxi->qs, kqsx); const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); @@ -739,7 +739,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -777,7 +777,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b4(bxi->qs, kqsx); const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); @@ -815,7 +815,7 @@ template static 
__device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; @@ -915,7 +915,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; + const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbx; const int aux_q4 = get_int_b1(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); @@ -942,7 +942,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; + const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; @@ -1476,7 +1476,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; + const block_q2_K * bxi = (const block_q2_K *)(x + i*stride) + kbx0; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); @@ -1795,7 +1795,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); @@ -1826,7 +1826,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; const int ksc = threadIdx.x % 4; @@ -1862,7 +1862,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; x_df[i] = bxi->d; } @@ -1941,7 +1941,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; const int qs0 = get_int_b4(bxi->qs, txi); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -1970,7 +1970,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % 2; @@ -1998,7 +1998,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; x_dm[i] = bxi->dm; } @@ -2011,7 +2011,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0 + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; @@ -2085,7 +2085,7 @@ 
template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int ky = QR5_K*txi; const int ql = get_int_b4(bxi->qs, txi); @@ -2126,7 +2126,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % 2; @@ -2154,7 +2154,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; x_dm[i] = bxi->dm; } @@ -2168,7 +2168,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; @@ -2291,7 +2291,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); @@ -2588,7 +2588,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + const block_iq2_xxs * bxi = (const block_iq2_xxs *)(x + i*stride) + kbx0; const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); const uint8_t * aux8 = (const uint8_t *) &q2; @@ -2650,7 +2650,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + const block_iq2_xs * bxi = (const block_iq2_xs *)(x + i*stride) + kbx0; const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint16_t * q2 = (const uint16_t *) &q2_packed; @@ -2710,7 +2710,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + const block_iq2_s * bxi = (const block_iq2_s *)(x + i*stride) + kbx0; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2777,7 +2777,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + const block_iq3_xxs * bxi = (const block_iq3_xxs *)(x + i*stride) + kbx0; const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * q3 = (const uint8_t *) &q3_packed; @@ -2837,7 +2837,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + const block_iq3_s * bxi = (const block_iq3_s *)(x + i*stride) + kbx0; const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2904,7 +2904,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); 
} - const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + const block_iq1_s * bxi = (const block_iq1_s *)(x + i*stride) + kbx0; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2964,7 +2964,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); @@ -2988,7 +2988,7 @@ template static __device__ __forceinline__ void loa i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0; const float d = __half2float(bxi->d); From 4ac23588c158ba50ea510620809cde42c240dadc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 14:22:24 +0300 Subject: [PATCH 11/23] mmq_id: add iq2_ks So we are sure it works with per row scales --- ggml/src/ggml-cuda/mmq_id.cu | 4 + ggml/src/ggml-cuda/mmq_id_common.cuh | 4 + .../mmq-instance-iq2_ks_id.cu | 114 ++++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index e95ee04c9..14d6d030a 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -198,6 +198,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ4_NL: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ2_KS: + mul_mat_q_case_id(ctx, args, stream); + break; case GGML_TYPE_IQ2_K: mul_mat_q_case_id(ctx, args, stream); break; @@ -429,6 +432,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: mmq_supported = true; diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 3a6b668c5..0e9285d25 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -80,6 +80,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: return MMQ_Q8_1_DS_LAYOUT_D4; @@ -372,6 +373,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; // ================= ik_llama.cpp quants + case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; @@ -412,6 +414,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; // ================= ik_llama.cpp quants + case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; @@ -3947,6 +3950,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern 
DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); // =================== ik_llama.cpp quants +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu new file mode 100644 index 000000000..87759b593 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu @@ -0,0 +1,114 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x%16; + +#ifdef __CUDA_ARCH__ + #pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) { + int i = i0 + 2*threadIdx.y + threadIdx.x/16; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0; + + uint16_t extra = bxi->extra >> 4*(kqsx/8); + int q2 = get_int_b2(bxi->qs, kqsx); + + uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101; + uint32_t val1 = ((q2 >> 0) & 0x33333333) | ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040); + uint32_t val2 = ((q2 >> 2) & 0x33333333) | ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040); + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else // __CUDA_ARCH__ + + + const int * all_values = (const int *)iq2k_table; + #pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) { + int i = i0 + 2*threadIdx.y + threadIdx.x/16; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0; + + uint16_t extra = bxi->extra >> 4*(kqsx/8); + int q2 = get_int_b2(bxi->qs, kqsx); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5)); +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra 
& 1) << 8)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5)); +#endif // INT8_MMA_AVAILABLE + } +#endif // __CUDA_ARCH__ + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = i0 + threadIdx.y * 8 + threadIdx.x / 4; + + if (need_check) { + i = min(i, i_max); + } + + const half * dptr = (const half *)(x + i*stride); + const float d = dptr[0]; + const block_iq2_ks * bxi = (const block_iq2_ks *)(dptr + 1) + kbx0; + const int ls1 = ((bxi->scales[threadIdx.x % 4] >> 0) & 0xf) | ((bxi->extra >> (4 + 2*(threadIdx.x % 4))) & 0x10); + const int ls2 = ((bxi->scales[threadIdx.x % 4] >> 4) & 0xf) | ((bxi->extra >> (5 + 2*(threadIdx.x % 4))) & 0x10); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + 2*(threadIdx.x % 4) + 0] = d * (ls1 - 16); + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + 2*(threadIdx.x % 4) + 1] = d * (ls2 - 16); +#else + x_df[i*(WARP_SIZE/4) + i/4 + 2*(threadIdx.x % 4) + 0] = d * (ls1 - 16); + x_df[i*(WARP_SIZE/4) + i/4 + 2*(threadIdx.x % 4) + 1] = d * (ls2 - 16); +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); From 45497a0209dc72c1949d6a8594276ec1a8db270b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 14:36:49 +0300 Subject: [PATCH 12/23] mmq_id: add iq2_kl --- ggml/src/ggml-cuda/mmq_id.cu | 4 ++ ggml/src/ggml-cuda/mmq_id_common.cuh | 4 ++ .../mmq-instance-iq2_kl_id.cu | 72 +++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 14d6d030a..79c73231e 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -201,6 +201,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ2_KS: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ2_KL: + mul_mat_q_case_id(ctx, args, stream); + break; case GGML_TYPE_IQ2_K: mul_mat_q_case_id(ctx, args, stream); break; @@ -433,6 +436,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: mmq_supported = true; diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 0e9285d25..d2d78890a 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -81,6 +81,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: return MMQ_Q8_1_DS_LAYOUT_D4; @@ -374,6 +375,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; // ================= ik_llama.cpp quants case GGML_TYPE_IQ2_KS : return 
MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; @@ -415,6 +417,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; // ================= ik_llama.cpp quants case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; @@ -3951,6 +3954,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); // =================== ik_llama.cpp quants extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu new file mode 100644 index 000000000..66e51797e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu @@ -0,0 +1,72 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_kl( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)aux32; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + + const half * dptr = (const half *)(x + i*stride); + const float d = *dptr; + const block_iq2_kl * bxi = (const block_iq2_kl *)(dptr + 1) + kbx0; + + #pragma unroll + for (int j = 0; j < 2; ++j) { + auto ql = get_int_b2(bxi->qs, 4*(kqsx/2) + 2*(kqsx%2) + j); + auto qh = get_int_b2(bxi->qh, 2*(kqsx%2) + j) >> 2*(kqsx/2); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + #pragma unroll + for (int l = 0; l < 2; ++l) { + int val1 = iq2kl_values[a8[2*l+0]] | (iq2kl_values[a8[2*l+1]] << 16); + int val2 = iq2kl_values[a8[2*l+4]] | (iq2kl_values[a8[2*l+5]] << 16); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2; +#else + x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1; + x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2; +#endif + } + } + + int ls = int(((bxi->scales_l[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->scales_h >> 2*kqsx) & 3) << 4)) - 32; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = d * ls; +#endif + } + +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static 
constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); From 66c31e17b3584de00e548079dd053eb8872e8c9b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 14:48:04 +0300 Subject: [PATCH 13/23] mmq_id: add iq3_ks --- ggml/src/ggml-cuda/mmq_id.cu | 4 + ggml/src/ggml-cuda/mmq_id_common.cuh | 4 + .../mmq-instance-iq3_ks_id.cu | 79 +++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 79c73231e..d4ed30361 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -210,6 +210,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ2_K_R4: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ3_KS: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -439,6 +442,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_KS: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index d2d78890a..5435c9f9f 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -84,6 +84,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_KS: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -378,6 +379,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0; default: return tile_x_sizes{0, 0, 0}; } } @@ -420,6 +422,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0; default: return 0; } } @@ -3957,5 +3960,6 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu new file mode 100644 index 000000000..1ecb748f3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu @@ -0,0 +1,79 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq3_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // 
INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const half * dptr = (const half *)(x + i*stride); + const float d = __half2float(dptr[0]); + const block_iq3_ks * bxi = (const block_iq3_ks *)(dptr + 1) + kbx0; + + //uint16_t extra = bxi->extra >> 8; + int qh = get_int_b2(bxi->qh, kqsx); + + uint32_t extra32 = uint32_t(bxi->extra >> 8) * 0x01010101; + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((qh << 2) & 0x04040404) | ((extra32 << 3) & 0x08080808) + | ((qh << 4) & 0x40404040) | ((extra32 << 5) & 0x80808080); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((qh << 1) & 0x04040404) | ((extra32 << 2) & 0x08080808) + | ((qh << 3) & 0x40404040) | ((extra32 << 4) & 0x80808080); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + + extra32 >>= 4; + qh >>= 4; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * (int(((bxi->scales[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->extra >> kqsx) & 1) << 4)) - 16); +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = d * (int(((bxi->scales[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->extra >> kqsx) & 1) << 4)) - 16); +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); From 951971080d2e929bafc99040823efa01e15f6526 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 15:03:58 +0300 Subject: [PATCH 14/23] mmq_id: adding iq3_k, iq3_k_r4 --- ggml/src/ggml-cuda/mmq_id.cu | 8 + ggml/src/ggml-cuda/mmq_id_common.cuh | 8 + .../mmq-instance-iq3_k_id.cu | 164 ++++++++++++++++++ 3 files changed, 180 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index d4ed30361..be807aea5 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -210,6 +210,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ2_K_R4: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ3_K: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ3_K_R4: + mul_mat_q_case_id(ctx, args, stream); + break; case GGML_TYPE_IQ3_KS: mul_mat_q_case_id(ctx, args, stream); break; @@ -443,6 +449,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_KS: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: 
mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 5435c9f9f..845737be7 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -85,6 +85,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_KS: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -380,6 +382,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_K_R4: return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; } } @@ -423,6 +427,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; } } @@ -3961,5 +3967,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu new file mode 100644 index 000000000..7e7ed12a2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu @@ -0,0 +1,164 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq3_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_k * bxi = (const block_iq3_k *)(x + i*stride) + kbx0; + + const float d = bxi->d; + + uint16_t extra = bxi->extra >> (kqsx/4); + uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 }; + int qh = get_int_b2(bxi->qh, kqsx); + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + //extra << 3, extra << 1, extra >> 1, extra >> 3 + const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 3) & 0x88888888) + | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 1) & 0x88888888) + | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040); + int2 v1 = 
get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + + qh >>= 4; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + + uint16_t sh = bxi->scales_h >> 2*kqsx; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1)); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1)); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq3_k_r4 * bxi = (const block_iq3_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + + int qh = get_int_b4(bxi->qh, 4*kqsx+ir); + + #pragma unroll + for (int l = 0; l < 2; ++l) { + + //auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6); + uint32_t extra32 = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x88888888; + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | extra32 | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040); + uint32_t val2 = ((ql >> 2) & 0x33333333) | extra32 | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y; +#endif // INT8_MMA_AVAILABLE + + qh >>= 4; + } + + int is = 8*kqsx + ir; + float dl1 = d * (2*((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((bxi->scales_h[is%8] >> (is/8)) & 1 ? -1 : 1); + is += 4; + float dl2 = d * (2*((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((bxi->scales_h[is%8] >> (is/8)) & 1 ? 
-1 : 1); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + + +DECL_MMQ_CASE(GGML_TYPE_IQ3_K); +DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4); From b0afe8dc20f265ea72e405fb6abc4062634f2a05 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 15:31:15 +0300 Subject: [PATCH 15/23] mmq_id: add iq4_kss, iq4_ks, iq4_ks_r4 --- ggml/src/ggml-cuda/mmq_id.cu | 12 ++ ggml/src/ggml-cuda/mmq_id_common.cuh | 12 ++ .../mmq-instance-iq4_ks_id.cu | 187 ++++++++++++++++++ 3 files changed, 211 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index be807aea5..edc086e61 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -219,6 +219,15 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ3_KS: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ4_KSS: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ4_KS: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ4_KS_R4: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -451,6 +460,9 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 845737be7..0ac6346bf 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -87,6 +87,9 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -384,6 +387,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ3_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS_R4: return MMQ_DP4A_TXS_Q8_0; default: return tile_x_sizes{0, 0, 0}; } } @@ -429,6 +435,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ3_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0; + case 
GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; default: return 0; } } @@ -3969,5 +3978,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu new file mode 100644 index 000000000..ae4b4aa08 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu @@ -0,0 +1,187 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq4_kss( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x / 4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kss * bxi = (const block_iq4_kss *)(dptr + 1) + kbx0; + const uint32_t * q4 = bxi->qs + 4*kqsx; + uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); + uint8_t ls = (s32 | (s32 >> 15)) & 0xff; + + auto values = iq4k_values + ((ls & 1) << 4); + + #pragma unroll + for (int j = 0; j < 4; ++j) { + uint32_t val = q4[j] & 0xfffefffe; + val = val ^ (val >> 1); + auto v = get_int_from_table_16(val, values); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ((ls & 254) - 127); +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ((ls & 254) - 127); +#endif // INT8_MMA_AVAILABLE + } + +} + +template static __device__ __forceinline__ void load_tiles_iq4_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x / 4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0; + const int ls 
= (bxi->scales[kqsx] & 254) - 127; + + auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4); + + #pragma unroll + for (int j = 0; j < 4; ++j) { + const int q4 = get_int_b4(bxi->qs, 4*kqsx+j); + const int2 v = get_int_from_table_16(q4, values); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls; +#endif // INT8_MMA_AVAILABLE + } + +} + +template static __device__ __forceinline__ void load_tiles_iq4_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4); + +#pragma unroll + for (int j = 0; j < 4; ++j) { + const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir); + const int2 v = get_int_from_table_16(q4, values); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kss; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS); +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); + From 20ff71642854d6e3cbd5c7809642d564a714806f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 15:48:29 +0300 Subject: [PATCH 16/23] mmq_id: adding iq4_k, iq4_k_r4 --- ggml/src/ggml-cuda/mmq_id.cu | 8 + ggml/src/ggml-cuda/mmq_id_common.cuh | 8 + .../mmq-instance-iq4_k_id.cu 
| 162 ++++++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index edc086e61..159018fe0 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -228,6 +228,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ4_KS_R4: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ4_K: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ4_K_R4: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -463,6 +469,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 0ac6346bf..51dd00fc9 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -90,6 +90,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -390,6 +392,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS_R4: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; } } @@ -438,6 +442,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; } } @@ -3981,5 +3987,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu new file mode 100644 index 000000000..f2d267fac --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu @@ -0,0 +1,162 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
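+// Tile loading for IQ4_K and IQ4_K_R4 (rows interleaved in groups of four) in the indirect (MoE) MMQ path.
+// The packed 4-bit indices are expanded through the non-linear iq4k_table, with the per-block "extra" bits
+// selecting which section of the table is used. Block scales are 6 bit (low nibble in scales_l, two high
+// bits in scales_h), stored per 16 values, which is why the tiles use the Q3_K layout and the
+// vec_dot_q8_0_16 kernels named in the type traits at the end of this file.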
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq4_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_k * bxi = (const block_iq4_k *)(x + i*stride) + kbx0; + const uint16_t extra = bxi->extra >> 2*kqsx; + + auto values_l = iq4k_table + ((extra & 1) << 8); + auto values_h = iq4k_table + ((extra & 2) << 7); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int q4 = get_int_b4(bxi->qs, (qstep/2)*kqsx + l); + + aux32[0] = (q4 >> 0) & 0x0f0f0f0f; + aux32[1] = (q4 >> 4) & 0x0f0f0f0f; + + int val0 = int_from_table_x(aux8+0, values_l); + int val1 = int_from_table_x(aux8+4, values_h); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 0] = val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 4] = val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 0] = val0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 4] = val1; +#endif // INT8_MMA_AVAILABLE + } + + const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2); + const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32; + + const float d = bxi->d; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq4_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq4_k_r4 * bxi = (const block_iq4_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = iq4k_table + ((bxi->extra[ir+4*l] << (8 - kqsx)) & 0x100); + + const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0); + const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8); + aux32[0] = (ql1 >> 0) & 0x0f0f0f0f; + aux32[1] = (ql1 >> 4) & 0x0f0f0f0f; + aux32[2] = (ql2 >> 
0) & 0x0f0f0f0f; + aux32[3] = (ql2 >> 4) & 0x0f0f0f0f; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_x(aux8+ 0, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_x(aux8+ 4, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_x(aux8+ 8, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_x(aux8+12, values_l); +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_x(aux8+ 0, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_x(aux8+ 4, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_x(aux8+ 8, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_x(aux8+12, values_l); +#endif // INT8_MMA_AVAILABLE + + } + + int is = 8*kqsx + ir; + float dl1 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ4_K); +DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4); From 8f3c813ab084d4c407156294919913c59c30991b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 16:07:38 +0300 Subject: [PATCH 17/23] mmq_id: adding iq5_ks, iq5_ks_r4 --- ggml/src/ggml-cuda/mmq_id.cu | 8 + ggml/src/ggml-cuda/mmq_id_common.cuh | 8 + .../mmq-instance-iq5_ks_id.cu | 146 ++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 159018fe0..a33a920c3 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -234,6 +234,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ4_K_R4: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ5_KS: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ5_KS_R4: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -471,6 +477,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 51dd00fc9..86900e206 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -92,6 +92,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { 
case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -394,6 +396,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_KS_R4: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0; default: return tile_x_sizes{0, 0, 0}; } } @@ -444,6 +448,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; default: return 0; } } @@ -3989,5 +3995,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu new file mode 100644 index 000000000..3232fa085 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu @@ -0,0 +1,146 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq5_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq5nl_values; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0]; + const block_iq5_ks * bxi = (const block_iq5_ks *)(dptr + 1) + kbx0; + + int qh = get_int_b4(bxi->qh, kqsx); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((bxi->scales[2*l+0] & 1) * 0x20202020); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((bxi->scales[2*l+1] & 1) * 0x20202020); + qh >>= 2; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 0] = *(const int *)&val0; + 
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq5_ks_r4 * bxi = (const block_iq5_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq5nl_values + ((bxi->scales[4*kqsx+ir] & 1) << 5); + + int qh = *((const int *)bxi->qh + 4*kqsx + ir); + const int * ql = (const int *)bxi->qs + 16*kqsx + ir; +#pragma unroll + for (int j = 0; j < 4; ++j) { + aux32[0] = ((ql[4*j] >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql[4*j] >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + qh >>= 2; + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); +DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); From 601c6006d974d30b2531da304f93d6220a5fb720 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 16:52:47 +0300 Subject: [PATCH 18/23] mmq_id: adding iq5_k, iq5_k_r4, q6_0 --- ggml/src/ggml-cuda/mmq_id.cu | 15 +- ggml/src/ggml-cuda/mmq_id_common.cuh | 85 +++++++++ .../mmq-instance-iq5_k_id.cu | 
174 ++++++++++++++++++ .../mmq-instance-q6_0_id.cu | 5 + 4 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index a33a920c3..2de0d2c60 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -153,6 +153,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_Q5_1: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_Q6_0: + mul_mat_q_case_id(ctx, args, stream); + break; case GGML_TYPE_Q8_0: mul_mat_q_case_id(ctx, args, stream); break; @@ -240,6 +243,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ5_KS_R4: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ5_K: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ5_K_R4: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -450,6 +459,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: @@ -479,6 +489,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: mmq_supported = true; break; default: @@ -513,7 +525,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { if (GGML_CUDA_CC_IS_CDNA3(cc)) { return true; } - if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 + || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q6_0) { return true; } if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 86900e206..b428d4069 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q5_1: return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_0: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q8_0: return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_MXFP4: @@ -94,6 +96,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -368,6 +372,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; @@ -398,6 +403,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_KS : return 
MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; } } @@ -420,6 +427,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; @@ -450,6 +458,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; } } @@ -861,6 +871,71 @@ template static __device__ __forceinline__ void loa } } +template static __device__ __forceinline__ void load_tiles_q6_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / QI6_0; + const int kqsx = threadIdx.x % QI6_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx; + + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2); + + int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030); + int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030); + qs0 = __vsubss4(qs0, 0x20202020); // subtract 32 + qs1 = __vsubss4(qs1, 0x20202020); // subtract 32 + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) { + int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -3162,6 +3237,13 @@ struct mmq_type_traits_id { static constexpr vec_dot_mmq_t vec_dot_dp4a = 
vec_dot_q8_1_q8_1_dp4a; }; +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits_id { //static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; @@ -3967,6 +4049,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_0); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); @@ -3997,5 +4080,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu new file mode 100644 index 000000000..853efb4c4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu @@ -0,0 +1,174 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq5_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq5nl_values; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq5_k * bxi = (const block_iq5_k *)(x + i*stride) + kbx0; + + int qh = get_int_b4(bxi->qh, kqsx); + uint16_t extra = bxi->extra >> (kqsx/4); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) * 0x20202020); // this is very slightly faster + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) * 0x08080808); // then the version below + //aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) ? 0x20202020 : 0); + //aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) ? 
0x20202020 : 0); + qh >>= 2; + extra >>= 4; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + + const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2); + const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32; + + const float d = bxi->d; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq5_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq5_k_r4 * bxi = (const block_iq5_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + + int qh = get_int_b4(bxi->qh, 4*kqsx + ir); + + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = iq5nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 5); + + const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0); + const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8); + aux32[0] = ((ql1 >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql1 >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + aux32[2] = ((ql2 >> 0) & 0x0f0f0f0f) | ((qh >> 0) & 0x10101010); + aux32[3] = ((ql2 >> 4) & 0x0f0f0f0f) | ((qh >> 1) & 0x10101010); + + const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]); + const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]); + const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]); + const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val2; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0; + 
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val1; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val2; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3; +#endif // INT8_MMA_AVAILABLE + + qh >>= 2; + } + + int is = 8*kqsx + ir; + float dl1 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ5_K); +DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu new file mode 100644 index 000000000..7da267a42 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
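+// Q6_0 needs no dedicated loader in this file: load_tiles_q6_0 and its mmq_type_traits_id specialization
+// (q8_0 dot products) are defined in mmq_id_common.cuh, so this translation unit only instantiates the
+// MMQ kernel for GGML_TYPE_Q6_0 via DECL_MMQ_CASE below.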
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q6_0); From a6fe757cd89fd397643b40ba608438f1f36088d2 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 17:03:16 +0300 Subject: [PATCH 19/23] mmq_id: adding iq6_k --- ggml/src/ggml-cuda/mmq_id.cu | 4 + ggml/src/ggml-cuda/mmq_id_common.cuh | 4 + .../mmq-instance-iq6_k_id.cu | 80 +++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 2de0d2c60..20c9b2bcb 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -249,6 +249,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ5_K_R4: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ6_K: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -491,6 +494,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_IQ6_K: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index b428d4069..deb93e72d 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -98,6 +98,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_IQ6_K: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -405,6 +406,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; default: return tile_x_sizes{0, 0, 0}; } } @@ -460,6 +462,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; default: return 0; } } @@ -4082,5 +4085,6 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu new file mode 100644 index 000000000..af05bd768 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu @@ -0,0 +1,80 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq6_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 
8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq6nl_values; + int qh[2]; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq6_k * bxi = (const block_iq6_k *)(x + i*stride) + kbx0; + + const float d = bxi->d; + uint16_t extra = bxi->extra >> (kqsx/4); + + qh[0] = get_int_b4(bxi->qh, kqsx+0); + qh[1] = get_int_b4(bxi->qh, kqsx+8); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh[l/2] & 0x03030303) << 4) | ((extra & 1) * 0x40404040); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh[l/2] & 0x0c0c0c0c) << 2) | ((extra & 4) * 0x10101010); + qh[l/2] >>= 4; + extra >>= 4; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * bxi->scales[2*kqsx+0]; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * bxi->scales[2*kqsx+1]; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * bxi->scales[2*kqsx+0]; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * bxi->scales[2*kqsx+1]; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ6_K); From 00a0aad47df5537b3b76f117d084d700c1716797 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 17:15:49 +0300 Subject: [PATCH 20/23] mmq_id: add iq1_s_r4 --- ggml/src/ggml-cuda/mmq_id.cu | 4 + ggml/src/ggml-cuda/mmq_id_common.cuh | 85 ++++++++++++++++++- .../mmq-instance-iq1_s_id.cu | 1 + 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 20c9b2bcb..27974cf2a 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -195,6 +195,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ1_S: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ1_S_R4: + mul_mat_q_case_id(ctx, args, stream); + break; case GGML_TYPE_IQ4_XS: mul_mat_q_case_id(ctx, args, stream); break; @@ -476,6 +479,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ2_KS: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index deb93e72d..57514ca31 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -79,6 +79,7 @@ static mmq_q8_1_ds_layout 
mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ3_S: return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_S_R4: return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: @@ -387,9 +388,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S_R4:return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; - // ================= ik_llama.cpp quants case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; @@ -443,9 +444,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S_R4:return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; - // ================= ik_llama.cpp quants case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; @@ -3052,6 +3053,76 @@ template static __device__ __forceinline__ void loa } } +template static __device__ __forceinline__ void load_tiles_iq1_s_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / 4; + const int kqsx = threadIdx.x % 4; + + int32_t grid32[2]; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + const int i4 = i/4; + const int ir = i%4; + + const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(x + 4*i4*stride + 4*sizeof(half)) + kbx0 + kbx; + + grid32[0] = iq1s_grid_gpu[bxi->qs[4*kqsx+ir] | (((bxi->qh[ir] >> 3*kqsx) & 7) << 8)]; + grid32[1] = ((grid32[0] >> 4) & 0x0f0f0f0f) << 3; + grid32[0] = (grid32[0] & 0x0f0f0f0f) << 3; + const int shift = bxi->qh[ir] & 0x8000 ? 
0x09090909 : 0x07070707; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 0] = __vsubss4(grid32[0], shift); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 1] = __vsubss4(grid32[1], shift); +#else + // TODO + //x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / 4; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + const int i4 = i/4; + const int ir = i%4; + + const half * dptr = (const half *)(x + 4*i4*stride); + const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(dptr + 4) + kbx0 + kbxd; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.125f * __half2float(dptr[ir]) * (((bxi->qh[ir] >> 11) & 14) + 1); +#else + // TODO + //x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -3345,12 +3416,18 @@ struct mmq_type_traits_id { template struct mmq_type_traits_id { - //static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; +}; + template struct mmq_type_traits_id { //static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; @@ -4066,9 +4143,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); -// =================== ik_llama.cpp quants extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu index 9c04a0205..29d2e2ce2 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu @@ -3,3 +3,4 @@ #include "../mmq_id_common.cuh" DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); From 3d87c2d7b239fd60fd5768c93a97eaecc1abd9ce Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 18:09:55 +0300 Subject: [PATCH 21/23] mmq_id: adding iq1_kt, iq2_kt --- ggml/src/ggml-cuda/mmq_id.cu | 8 ++ ggml/src/ggml-cuda/mmq_id_common.cuh | 8 ++ .../mmq-instance-iq1_kt_id.cu | 83 ++++++++++++++++++ .../mmq-instance-iq2_kt_id.cu | 85 +++++++++++++++++++ 4 files changed, 184 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 27974cf2a..0cb3b5aa5 100644 --- 
a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -255,6 +255,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ6_K: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ1_KT: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ2_KT: + mul_mat_q_case_id(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -499,6 +505,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ1_KT: + case GGML_TYPE_IQ2_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 57514ca31..3329afa74 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -100,6 +100,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ1_KT: + case GGML_TYPE_IQ2_KT: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -408,6 +410,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ1_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; default: return tile_x_sizes{0, 0, 0}; } } @@ -464,6 +468,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ1_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; default: return 0; } } @@ -4163,5 +4169,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu new file mode 100644 index 000000000..bbdc750cf --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu @@ -0,0 +1,83 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
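+// IQ1_KT tiles are decoded on the fly: each 13-bit index (8 bits from ql, 4 bits from qh, 1 bit from sh,
+// biased by 4096) seeds a multiplicative recurrence with the constant 0xCBAC1FED, and every step is folded
+// into a signed byte with dp4a(val & 0x3f3f3f3f, 0x01010101, -126), giving 8 int8 values per index.
+// Scalar sketch of one step, for illustration only (the kernel below uses dp4a):
+//   val *= 0xCBAC1FED;
+//   int8_t b = int8_t((val & 0x3f) + ((val >> 8) & 0x3f) + ((val >> 16) & 0x3f) + ((val >> 24) & 0x3f) - 126);
+// The row scale is a float stored in front of the block data; the per-32 scales come from the low nibbles
+// of sh mapped through iq4k_values.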
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq1_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_kt * bxi = (const block_iq1_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + uint32_t val = bxi->ql[kqsx] + ((bxi->qh[kqsx%16] << (8 - 4*(kqsx/16))) & 0xf00) + ((bxi->sh[kqsx/4] << (8 - (kqsx%4))) & 0x1000) + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0]; + const block_iq1_kt * bxi = (const block_iq1_kt *)(dptr + 1) + kbx0; + const int ls = iq4k_values[bxi->sh[threadIdx.x % 8] & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu new file mode 100644 index 000000000..5c2d77e77 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu @@ -0,0 +1,85 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
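The scale path shared by these loaders stores one fp32 scale per row in front of the quantized blocks; for iq2_kt the per-32-value sub-block scale is then a signed 4-bit index into iq4k_values. A minimal sketch of what load_tiles_iq2_kt below writes into x_df (the helper name is illustrative, kbx0 is the superblock index as in the loader):

    // Effective scale of sub-block ib32 (0..7) of superblock kbx0 in a row that starts at `row`.
    static __device__ __forceinline__ float iq2_kt_block_scale(const char * row, int kbx0, int ib32) {
        const float * dptr = (const float *) row;                     // leading per-row scale
        const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0;
        const int ls = iq4k_values[(bxi->scales[ib32 % 4] >> 4*(ib32 / 4)) & 0xf];
        return 1.05f * dptr[0] * ls;                                  // same 1.05f factor as the loader
    }

iq1_kt follows the same pattern but takes its 4-bit index from the low nibble of sh, while iq3_kt and iq4_kt (added in the next patch) use plain integer sub-block scales.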
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_kt * bxi = (const block_iq2_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.05f; + const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = iq4k_values[(bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); From 12ae77b8cd48c52298fa26a812ba5c2b05a92eb5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 18:46:04 +0300 Subject: [PATCH 22/23] mmq_id: add iq3_kt, iq4_kt --- ggml/src/ggml-cuda/mmq_id.cu | 8 ++ ggml/src/ggml-cuda/mmq_id_common.cuh | 8 ++ .../mmq-instance-iq3_kt_id.cu | 91 +++++++++++++++++++ .../mmq-instance-iq4_kt_id.cu | 86 ++++++++++++++++++ 4 files changed, 193 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 0cb3b5aa5..230715c0e 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -261,6 +261,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, case GGML_TYPE_IQ2_KT: mul_mat_q_case_id(ctx, args, stream); break; + case GGML_TYPE_IQ3_KT: + mul_mat_q_case_id(ctx, args, stream); + break; + case GGML_TYPE_IQ4_KT: + mul_mat_q_case_id(ctx, 
args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -507,6 +513,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 3329afa74..89baa31b6 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -102,6 +102,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -412,6 +414,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ1_KT : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; default: return tile_x_sizes{0, 0, 0}; } } @@ -470,6 +474,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ1_KT : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; default: return 0; } } @@ -4171,5 +4177,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu new file mode 100644 index 000000000..6af8b9d64 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu @@ -0,0 +1,91 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
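iq3_kt stores magnitudes through the trellis and one sign bit per value in qh; the loader below applies those signs with byte-wise intrinsics. A minimal sketch of that step (the helper name is illustrative):

    // Negate exactly the int8 lanes of v whose sign bit is set in qh_word under `mask`.
    static __device__ __forceinline__ int apply_packed_signs(int v, uint32_t qh_word, uint32_t mask) {
        const uint32_t signs = __vcmpne4(qh_word & mask, 0);  // 0xFF per flagged byte lane, 0x00 otherwise
        // Per byte: (b ^ 0xFF) - 0xFF == ~b + 1 == -b, while lanes masked with 0x00 pass through unchanged.
        return (int) __vsub4((uint32_t) v ^ signs, signs);
    }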
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq3_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + const auto qh = (const uint32_t *)bxi->qh; + uint32_t mask = 0x01010101 << ib32; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + auto signs = __vcmpne4(qh[2*j+0] & mask, 0); + v.x = __vsub4(v.x ^ signs, signs); + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v.y = __vsub4(v.y ^ signs, signs); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.01f; + const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu new file mode 100644 index 000000000..a27fe8a8d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu @@ -0,0 +1,86 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
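iq4_kt widens the trellis seed: the loader below assembles each seed from ql (8 bits), qh (4 bits) and 3 bits of the per-sub-block control word, on top of an offset of 4096 plus a high bit taken from that control word. Restated as a standalone sketch (the helper name is illustrative, the layout is as read from the kernel):

    // Seed for value 2*j + which (which = 0 or 1) of sub-block ib32 of one block_iq4_kt.
    static __device__ __forceinline__ uint32_t iq4_kt_seed(const block_iq4_kt * bxi, int ib32, int j, int which) {
        const auto shb = bxi->qs;                          // 8 control words, one per sub-block
        const auto ql  = (const uint8_t *)(shb + 8);       // low 8 bits of each seed
        const auto qh  = ql + 64;                          // 4 extra bits, shared between sub-block halves
        const uint32_t sh     = shb[ib32] >> (8 + 6*j);    // 3 extra bits for each value of the pair
        const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
        const int k = 2*j + which;
        const uint32_t hi4  = (qh[8*(ib32%4) + k] << (8 - 4*(ib32/4))) & 0xf00;
        const uint32_t top3 = which == 0 ? ((sh & 7) << 12) : ((sh & 56) << 9);
        return offset + ql[8*ib32 + k] + hi4 + top3;
    }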
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq4_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto shb = bxi->qs; + const auto ql = (const uint8_t *)(shb + 8); + const auto qh = ql + 64; + const uint32_t sh = shb[ib32] >> (8 + 6*j); + uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); + uint32_t val1 = offset + ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12); + uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9); + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val1 *= ka; + val2 *= ka; + v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k; + v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0; + const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); From c411d443eebf9e81552006d04ef0f3d92af16a0e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 26 Aug 2025 08:43:59 +0300 Subject: [PATCH 23/23] Add CUDA fp8 header --- ggml/src/ggml-cuda/vendors/cuda.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index 840809a15..30ddbb3c2 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,6 +6,10 @@ #include #include +#if CUDART_VERSION >= 12050 +#include +#endif // CUDART_VERSION >= 12050 + #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
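The final patch only pulls in the CUDA fp8 header (presumably <cuda_fp8.h>, per the commit title) behind a CUDART_VERSION >= 12050 guard. A hedged sketch of the kind of guarded conversion helper such a header enables, using the documented __nv_cvt_* intrinsics (the helper name and the e4m3 choice are illustrative, not taken from this series):

    #if CUDART_VERSION >= 12050
    #include <cuda_fp8.h>

    // Round-trip one value through 8-bit e4m3 storage and back to fp32.
    static __device__ float fp8_e4m3_roundtrip(float x) {
        const __nv_fp8_storage_t q = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
        const __half_raw         h = __nv_cvt_fp8_to_halfraw(q, __NV_E4M3);
        return __half2float(__half(h));
    }
    #endif // CUDART_VERSION >= 12050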