Skip to content

Commit b1e19d4

Browse files
committed
Revert "CUDA: use switch statements in constexpr functions (ggml-org#13095)"
This reverts commit b10d8bf.
1 parent 74c0e68 commit b1e19d4

File tree

2 files changed

+76
-84
lines changed

2 files changed

+76
-84
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 38 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -156,27 +156,25 @@ static constexpr __device__ int get_mmq_y_device() {
156156
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
157157

158158
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
159-
switch (type) {
160-
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
161-
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
162-
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
163-
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
164-
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
165-
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
166-
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
167-
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
168-
case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
169-
case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
170-
case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
171-
case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
172-
case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
173-
case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
174-
case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
175-
case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
176-
case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
177-
case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
178-
default: return tile_x_sizes{0, 0, 0};
179-
}
159+
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
160+
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
161+
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
162+
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
163+
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
164+
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
165+
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
166+
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
167+
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
168+
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
169+
type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
170+
type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
171+
type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
172+
type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
173+
type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
174+
type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
175+
type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
176+
type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
177+
tile_x_sizes{0, 0, 0};
180178
}
181179

182180
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
@@ -192,27 +190,25 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
192190
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
193191

194192
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
195-
switch (type) {
196-
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
197-
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
198-
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
199-
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
200-
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
201-
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
202-
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
203-
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
204-
case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
205-
case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
206-
case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
207-
case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
208-
case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
209-
case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
210-
case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
211-
case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
212-
case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
213-
case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
214-
default: return 0;
215-
}
193+
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
194+
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
195+
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
196+
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
197+
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
198+
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
199+
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
200+
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201+
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
202+
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
203+
type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
204+
type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
205+
type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
206+
type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
207+
type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
208+
type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
209+
type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
210+
type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
211+
0;
216212
}
217213

218214
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 38 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -4,51 +4,47 @@
44
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
55

66
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
7-
switch (type) {
8-
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
9-
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
10-
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
11-
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
12-
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
13-
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
14-
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
15-
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
16-
case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
17-
case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
18-
case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
19-
case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
20-
case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
21-
case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
22-
case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1;
23-
case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
24-
case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
25-
case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
26-
case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
27-
default: return nullptr;
28-
}
7+
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
8+
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
9+
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
10+
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
11+
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
12+
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
13+
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
14+
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
15+
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
16+
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
17+
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
18+
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
19+
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
20+
type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
21+
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
22+
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
23+
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
24+
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
25+
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
26+
nullptr;
2927
}
3028

3129
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
32-
switch (type) {
33-
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
34-
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
35-
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
36-
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
37-
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
38-
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
39-
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
40-
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
41-
case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
42-
case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
43-
case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
44-
case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
45-
case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
46-
case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
47-
case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
48-
case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
49-
case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
50-
default: return 1;
51-
}
30+
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
31+
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
32+
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
33+
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
34+
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
35+
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
36+
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
37+
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
38+
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
39+
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
40+
type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
41+
type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
42+
type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
43+
type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
44+
type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
45+
type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
46+
type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
47+
1;
5248
}
5349

5450
enum mmvq_parameter_table_id {

0 commit comments

Comments (0)