Skip to content

Commit d1cac3d

Browse files
committed
restructure vector length selection code
1 parent: 0b43956; commit: d1cac3d

File tree

1 file changed: +40 -11 lines changed

1 file changed: +40 -11 lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 40 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5108,12 +5108,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
51085108

51095109
#elif defined __riscv_v_intrinsic
51105110

5111+
const int vector_length = __riscv_vlenb() * 8;
51115112
float sumf = 0;
51125113

5113-
if (__riscv_vlenb() >= 32) {
5114-
uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
5115-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
5114+
uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
5115+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
5116+
uint8_t atmp[16];
51165117

5118+
switch (vector_length) {
5119+
case 256:
51175120
for (int i = 0; i < nb; ++i) {
51185121
const uint8_t * q2 = x[i].qs;
51195122
const int8_t * q8 = y[i].qs;
@@ -5188,8 +5191,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
51885191

51895192
sumf += dall * isum;
51905193
}
5191-
} else if (__riscv_vlenb() == 16) {
5192-
uint8_t atmp[16];
5194+
break;
5195+
case 128:
51935196
for (int i = 0; i < nb; ++i) {
51945197
const uint8_t * q2 = x[i].qs;
51955198
const int8_t * q8 = y[i].qs;
@@ -5277,6 +5280,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
52775280

52785281
sumf += dall * isum;
52795282
}
5283+
break;
5284+
default:
5285+
assert(false && "Unsupported vector length");
5286+
break;
52805287
}
52815288

52825289
*s = sumf;
@@ -6141,8 +6148,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
61416148
uint32_t aux[3];
61426149
uint32_t utmp[4];
61436150

6151+
const int vector_length = __riscv_vlenb() * 8;
61446152
float sumf = 0;
6145-
if (__riscv_vlenb() >= 32) {
6153+
6154+
switch (vector_length) {
6155+
case 256:
61466156
for (int i = 0; i < nb; ++i) {
61476157

61486158
const uint8_t * GGML_RESTRICT q3 = x[i].qs;
@@ -6234,7 +6244,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
62346244
sumf += d*sum_t;
62356245

62366246
}
6237-
} else if (__riscv_vlenb() == 16) {
6247+
break;
6248+
case 128:
62386249
for (int i = 0; i < nb; ++i) {
62396250
const uint8_t * restrict q3 = x[i].qs;
62406251
const uint8_t * restrict qh = x[i].hmask;
@@ -6348,6 +6359,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
63486359
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
63496360
sumf += d * isum;
63506361
}
6362+
break;
6363+
default:
6364+
assert(false && "Unsupported vector length");
6365+
break;
63516366
}
63526367

63536368
*s = sumf;
@@ -7065,9 +7080,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
70657080
const uint8_t * scales = (const uint8_t*)&utmp[0];
70667081
const uint8_t * mins = (const uint8_t*)&utmp[2];
70677082

7083+
const int vector_length = __riscv_vlenb() * 8;
70687084
float sumf = 0;
70697085

7070-
if (__riscv_vlenb() >= 32) {
7086+
switch (vector_length) {
7087+
case 256:
70717088
for (int i = 0; i < nb; ++i) {
70727089

70737090
size_t vl = 8;
@@ -7130,7 +7147,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
71307147
sumf += d*(sum_1 + sum_2);
71317148

71327149
}
7133-
} else if (__riscv_vlenb() == 16) {
7150+
break;
7151+
case 128:
71347152
for (int i = 0; i < nb; ++i) {
71357153
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
71367154
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
@@ -7233,6 +7251,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
72337251

72347252
sumf += d * sumi;
72357253
}
7254+
break;
7255+
default:
7256+
assert(false && "Unsupported vector length");
7257+
break;
72367258
}
72377259

72387260
*s = sumf;
@@ -8912,9 +8934,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
89128934

89138935
#elif defined __riscv_v_intrinsic
89148936

8937+
const int vector_length = __riscv_vlenb() * 8;
89158938
float sumf = 0;
89168939

8917-
if (__riscv_vlenb() >= 32) {
8940+
switch (vector_length) {
8941+
case 256:
89188942
for (int i = 0; i < nb; ++i) {
89198943

89208944
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
@@ -8994,7 +9018,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
89949018
sumf += d * sum_t;
89959019

89969020
}
8997-
} else if (__riscv_vlenb() == 16) {
9021+
break;
9022+
case 128:
89989023
for (int i = 0; i < nb; ++i) {
89999024

90009025
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
@@ -9067,6 +9092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
90679092
sumf += d * sum_t;
90689093

90699094
}
9095+
break;
9096+
default:
9097+
assert(false && "Unsupported vector length");
9098+
break;
90709099
}
90719100

90729101
*s = sumf;

0 commit comments

Comments
 (0)