Skip to content

Commit dfed296

Browse files
committed
Reapply "CUDA: use switch statements in constexpr functions (ggml-org#13095)"
1 parent c84561e commit dfed296

File tree

2 files changed

+92
-84
lines changed

2 files changed

+92
-84
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -159,27 +159,29 @@ static constexpr __device__ int get_mmq_y_device() {
159159
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
160160

161161
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
162-
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
163-
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
164-
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
165-
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
166-
type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
167-
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
168-
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
169-
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
170-
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
171-
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
172-
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
173-
type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 :
174-
type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
175-
type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
176-
type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
177-
type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
178-
type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
179-
type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
180-
type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
181-
type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
182-
tile_x_sizes{0, 0, 0};
162+
switch (type) {
163+
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
164+
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
165+
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
166+
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
167+
case GGML_TYPE_Q6_0: return MMQ_DP4A_TXS_Q8_0;
168+
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
169+
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
170+
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
171+
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
172+
case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
173+
case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
174+
case GGML_TYPE_TQ2_0: return MMQ_DP4A_TXS_Q8_0;
175+
case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
176+
case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
177+
case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
178+
case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
179+
case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
180+
case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
181+
case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
182+
case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
183+
default: return tile_x_sizes{0, 0, 0};
184+
}
183185
}
184186

185187
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
@@ -195,27 +197,29 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
195197
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
196198

197199
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
198-
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
199-
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
200-
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
201-
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
202-
type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
203-
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
204-
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
205-
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
206-
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
207-
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
208-
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
209-
type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
210-
type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
211-
type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
212-
type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
213-
type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
214-
type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
215-
type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
216-
type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
217-
type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
218-
0;
200+
switch (type) {
201+
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
202+
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
203+
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
204+
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
205+
case GGML_TYPE_Q6_0: return MMQ_MMA_TILE_X_K_Q8_0;
206+
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
207+
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
208+
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
209+
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
210+
case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
211+
case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
212+
case GGML_TYPE_TQ2_0: return MMQ_MMA_TILE_X_K_Q8_0;
213+
case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
214+
case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
215+
case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
216+
case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
217+
case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
218+
case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
219+
case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
220+
case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
221+
default: return 0;
222+
}
219223
}
220224

221225
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,55 @@
55
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
66

77
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
8-
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
9-
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
10-
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
11-
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
12-
type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 :
13-
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
14-
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
15-
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
16-
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
17-
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
18-
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
19-
type == GGML_TYPE_TQ2_0 ? vec_dot_tq2_0_q8_1 :
20-
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
21-
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
22-
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
23-
type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
24-
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
25-
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
26-
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
27-
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
28-
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
29-
nullptr;
8+
switch (type) {
9+
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
10+
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
11+
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
12+
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
13+
case GGML_TYPE_Q6_0: return vec_dot_q6_0_q8_1;
14+
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
15+
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
16+
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
17+
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
18+
case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
19+
case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
20+
case GGML_TYPE_TQ2_0: return vec_dot_tq2_0_q8_1;
21+
case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
22+
case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
23+
case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
24+
case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
25+
case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1;
26+
case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
27+
case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
28+
case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
29+
case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
30+
default: return nullptr;
31+
}
3032
}
3133

3234
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
33-
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
34-
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
35-
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
36-
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
37-
type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ :
38-
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
39-
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
40-
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
41-
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
42-
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
43-
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
44-
type == GGML_TYPE_TQ2_0 ? VDR_TQ2_0_Q8_1_MMVQ :
45-
type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
46-
type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
47-
type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
48-
type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
49-
type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
50-
type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
51-
type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
52-
1;
35+
switch (type) {
36+
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
37+
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
38+
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
39+
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
40+
case GGML_TYPE_Q6_0: return VDR_Q6_0_Q8_1_MMVQ;
41+
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
42+
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
43+
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
44+
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
45+
case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
46+
case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
47+
case GGML_TYPE_TQ2_0: return VDR_TQ2_0_Q8_1_MMVQ;
48+
case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
49+
case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
50+
case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
51+
case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
52+
case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
53+
case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
54+
case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
55+
default: return 1;
56+
}
5357
}
5458

5559
enum mmvq_parameter_table_id {

0 commit comments

Comments
 (0)