Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 43 additions & 40 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,27 @@ static constexpr __device__ int get_mmq_y_device() {
// tile_x_sizes initializer selected for GGML_TYPE_Q6_K by mmq_get_dp4a_tile_x_sizes.
// The three size expressions use an identifier named `mmq_y`, so the macro can only be
// expanded where a variable of that name is in scope.
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

// Maps a quantization type to the matching MMQ_DP4A_TXS_* tile_x_sizes value.
//
// type  - quantization type to look up.
// mmq_y - not used directly in this body; MMQ_DP4A_TXS_Q6_K (and presumably the sibling
//         MMQ_DP4A_TXS_* macros - TODO confirm) refers to it by name when expanded.
//
// Returns tile_x_sizes{0, 0, 0} for any type that has no dedicated case.
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
    switch (type) {
        case GGML_TYPE_Q4_0    : return MMQ_DP4A_TXS_Q4_0;
        case GGML_TYPE_Q4_1    : return MMQ_DP4A_TXS_Q4_1;
        case GGML_TYPE_Q5_0    : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_Q5_1    : return MMQ_DP4A_TXS_Q8_1;
        case GGML_TYPE_Q6_0    : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_Q8_0    : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_Q2_K    : return MMQ_DP4A_TXS_Q2_K;
        case GGML_TYPE_Q3_K    : return MMQ_DP4A_TXS_Q3_K;
        case GGML_TYPE_Q4_K    : return MMQ_DP4A_TXS_Q4_K;
        case GGML_TYPE_Q5_K    : return MMQ_DP4A_TXS_Q5_K;
        case GGML_TYPE_Q6_K    : return MMQ_DP4A_TXS_Q6_K;
        case GGML_TYPE_IQ2_XXS : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ2_XS  : return MMQ_DP4A_TXS_Q8_0_16;
        case GGML_TYPE_IQ2_S   : return MMQ_DP4A_TXS_Q8_0_16;
        case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ3_S   : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ1_S   : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_XS  : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_NL  : return MMQ_DP4A_TXS_Q8_0;
        default                : return tile_x_sizes{0, 0, 0};
    }
}

// Integer constant returned by mmq_get_mma_tile_x_k for GGML_TYPE_Q8_0 and the other
// types that map to it.
// NOTE(review): the trailing "+ 4" looks like padding (sibling constants are
// static_assert'ed to be 4 modulo 8) - confirm.
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
Expand All @@ -195,26 +196,28 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
// Compile-time check that MMQ_MMA_TILE_X_K_Q6_K leaves a remainder of 4 when divided by 8.
// NOTE(review): presumably a padding requirement of the tile layout - confirm the rationale.
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");

// Maps a quantization type to the MMQ_MMA_TILE_X_K_* constant used for it.
//
// Several types share a constant (e.g. Q4_K and Q5_K both use MMQ_MMA_TILE_X_K_Q8_1,
// IQ2_XS and IQ2_S both use MMQ_MMA_TILE_X_K_Q3_K).
//
// Returns 0 for any type that has no dedicated case.
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0    : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q4_1    : return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q5_0    : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q5_1    : return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q6_0    : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q8_0    : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q2_K    : return MMQ_MMA_TILE_X_K_Q2_K;
        case GGML_TYPE_Q3_K    : return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_Q4_K    : return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q5_K    : return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q6_K    : return MMQ_MMA_TILE_X_K_Q6_K;
        case GGML_TYPE_IQ2_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ2_XS  : return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_IQ2_S   : return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ3_S   : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ1_S   : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_XS  : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_NL  : return MMQ_MMA_TILE_X_K_Q8_0;
        default                : return 0;
    }
}

// Integer constant: WARP_SIZE plus WARP_SIZE divided by QI8_1.
// NOTE(review): by its name this is presumably the per-row size of the y tile - confirm against users.
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
Expand Down
84 changes: 44 additions & 40 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,53 @@
// Pointer type for the vec_dot_*_q8_1 functions handed out by get_vec_dot_q_cuda.
// Such a function takes an untyped pointer `vbq`, a pointer to block_q8_1 data `bq8_1`,
// and two int references `kbx` and `iqs`, and returns a float.
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);

// Selects the vec_dot_*_q8_1 function that corresponds to a quantization type.
//
// Returns nullptr for any type that has no dedicated case, so callers that may pass
// such a type must check the result before calling through it.
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0    : return vec_dot_q4_0_q8_1;
        case GGML_TYPE_Q4_1    : return vec_dot_q4_1_q8_1;
        case GGML_TYPE_Q5_0    : return vec_dot_q5_0_q8_1;
        case GGML_TYPE_Q5_1    : return vec_dot_q5_1_q8_1;
        case GGML_TYPE_Q6_0    : return vec_dot_q6_0_q8_1;
        case GGML_TYPE_Q8_0    : return vec_dot_q8_0_q8_1;
        case GGML_TYPE_Q2_K    : return vec_dot_q2_K_q8_1;
        case GGML_TYPE_Q3_K    : return vec_dot_q3_K_q8_1;
        case GGML_TYPE_Q4_K    : return vec_dot_q4_K_q8_1;
        case GGML_TYPE_Q5_K    : return vec_dot_q5_K_q8_1;
        case GGML_TYPE_Q6_K    : return vec_dot_q6_K_q8_1;
        case GGML_TYPE_IQ2_XXS : return vec_dot_iq2_xxs_q8_1;
        case GGML_TYPE_IQ2_XS  : return vec_dot_iq2_xs_q8_1;
        case GGML_TYPE_IQ2_S   : return vec_dot_iq2_s_q8_1;
        case GGML_TYPE_IQ3_XXS : return vec_dot_iq3_xxs_q8_1;
        case GGML_TYPE_IQ1_S   : return vec_dot_iq1_s_q8_1;
        case GGML_TYPE_IQ1_M   : return vec_dot_iq1_m_q8_1;
        case GGML_TYPE_IQ4_NL  : return vec_dot_iq4_nl_q8_1;
        case GGML_TYPE_IQ4_XS  : return vec_dot_iq4_xs_q8_1;
        case GGML_TYPE_IQ3_S   : return vec_dot_iq3_s_q8_1;
        default                : return nullptr;
    }
}

// Maps a quantization type to its VDR_*_Q8_1_MMVQ constant.
//
// Returns 1 for any type that has no dedicated case. That includes GGML_TYPE_IQ1_S and
// GGML_TYPE_IQ1_M, which get_vec_dot_q_cuda handles but which are not listed here.
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0    : return VDR_Q4_0_Q8_1_MMVQ;
        case GGML_TYPE_Q4_1    : return VDR_Q4_1_Q8_1_MMVQ;
        case GGML_TYPE_Q5_0    : return VDR_Q5_0_Q8_1_MMVQ;
        case GGML_TYPE_Q5_1    : return VDR_Q5_1_Q8_1_MMVQ;
        case GGML_TYPE_Q6_0    : return VDR_Q6_0_Q8_1_MMVQ;
        case GGML_TYPE_Q8_0    : return VDR_Q8_0_Q8_1_MMVQ;
        case GGML_TYPE_Q2_K    : return VDR_Q2_K_Q8_1_MMVQ;
        case GGML_TYPE_Q3_K    : return VDR_Q3_K_Q8_1_MMVQ;
        case GGML_TYPE_Q4_K    : return VDR_Q4_K_Q8_1_MMVQ;
        case GGML_TYPE_Q5_K    : return VDR_Q5_K_Q8_1_MMVQ;
        case GGML_TYPE_Q6_K    : return VDR_Q6_K_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_XXS : return VDR_IQ2_XXS_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_XS  : return VDR_IQ2_XS_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_S   : return VDR_IQ2_S_Q8_1_MMVQ;
        case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ;
        case GGML_TYPE_IQ3_S   : return VDR_IQ3_S_Q8_1_MMVQ;
        case GGML_TYPE_IQ4_NL  : return VDR_IQ4_NL_Q8_1_MMVQ;
        case GGML_TYPE_IQ4_XS  : return VDR_IQ4_XS_Q8_1_MMVQ;
        default                : return 1;
    }
}

template <ggml_type type, int ncols_y>
Expand Down