Skip to content

Commit 11c866a

Browse files
committed
cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (#10976)
1 parent ee02ad0 commit 11c866a

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "cpy.cuh"
2+
#include "dequantize.cuh"
23

34
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
45

@@ -82,13 +83,25 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
8283
}
8384

8485
// Dequantize one Q8_0 block into QK8_0 consecutive F32 values at cdsti.
// Each dequantize_q8_0 call yields a pair of values which this block stores
// at adjacent output positions j and j+1 (contiguous layout for Q8_0).
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * dst_f32 = (float *) cdsti;

    for (int i = 0; i < QK8_0; i += 2) {
        float2 v;
        dequantize_q8_0(cxi, 0, i, v);
        dst_f32[i + 0] = v.x;
        dst_f32[i + 1] = v.y;
    }
}
8795

// Generic per-block dequantization of a quantized block into qk F32 values.
// Each call to `dequant` produces a pair; the first element is stored at
// output position j and the second at j + qk/2 — the split-halves layout
// used by the 4/5-bit block formats (NOTE(review): assumes `dequant`
// follows that low/high-half convention; confirm against dequantize.cuh).
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * dst_f32 = (float *) cdsti;

    const int half = qk/2;
    for (int j = 0; j < half; ++j) {
        float2 v;
        dequant(cxi, 0, j, v);
        dst_f32[j]        = v.x;
        dst_f32[j + half] = v.y;
    }
}
94107

@@ -225,7 +238,6 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
225238
memcpy(dsti->qh, &qh, sizeof(qh));
226239
}
227240

228-
229241
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
230242
if (x <= val[0]) return 0;
231243
if (x >= val[n-1]) return n-1;
@@ -420,6 +432,58 @@ static void ggml_cpy_f32_q5_1_cuda(
420432
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
421433
}
422434

435+
// Launch the Q5_1 -> F32 copy/dequantization kernel on `stream`.
// Launch geometry: ne blocks of 1 thread each, matching the other
// quantized-copy wrappers in this file.
static void ggml_cpy_q5_1_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<ne, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
447+
448+
// Launch the Q5_0 -> F32 copy/dequantization kernel on `stream`.
// Launch geometry: ne blocks of 1 thread each, matching the other
// quantized-copy wrappers in this file.
static void ggml_cpy_q5_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<ne, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
460+
461+
// Launch the Q4_1 -> F32 copy/dequantization kernel on `stream`.
// Launch geometry: ne blocks of 1 thread each, matching the other
// quantized-copy wrappers in this file.
static void ggml_cpy_q4_1_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<ne, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
473+
474+
// Launch the Q4_0 -> F32 copy/dequantization kernel on `stream`.
// Launch geometry: ne blocks of 1 thread each, matching the other
// quantized-copy wrappers in this file.
static void ggml_cpy_q4_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<ne, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
486+
423487
static void ggml_cpy_f32_iq4_nl_cuda(
424488
const char * cx, char * cdst, const int ne,
425489
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +552,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
488552
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
489553
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
490554
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
555+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
556+
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
557+
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
491558
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
492559
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
560+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
561+
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
562+
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
493563
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
494564
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
565+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
566+
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
567+
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
495568
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
496569
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
497570
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
498571
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
572+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
573+
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
499574
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
500575
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
501576
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +599,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
524599
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
525600
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
526601
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
602+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
603+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
527604
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
528605
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
606+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
607+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
529608
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
530609
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
610+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
611+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
531612
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
532613
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
533614
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
534615
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
616+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
617+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
535618
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
536619
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
537620
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3073,15 +3073,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30733073
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
30743074
return true;
30753075
}
3076+
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
3077+
return true;
3078+
}
30763079
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
30773080
return true;
30783081
}
3082+
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
3083+
return true;
3084+
}
30793085
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
30803086
return true;
30813087
}
3088+
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
3089+
return true;
3090+
}
30823091
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
30833092
return true;
30843093
}
3094+
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
3095+
return true;
3096+
}
30853097
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
30863098
return true;
30873099
}

0 commit comments

Comments
 (0)