11#include " cpy.cuh"
2+ #include " dequantize.cuh"
23
34typedef void (*cpy_kernel_t )(const char * cx, char * cdst);
45
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
8283}
8384
// Dequantize one q8_0 block (QK8_0 int8 quants + shared scale) into QK8_0 f32 values.
// dequantize_q8_0 yields two adjacent values per call, so each pair is stored
// contiguously at positions 2*i and 2*i + 1 of the destination.
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * dst = (float *) cdsti;

#pragma unroll
    for (int i = 0; i < QK8_0/2; ++i) {
        dfloat2 v;
        dequantize_q8_0(cxi, 0, 2*i, v);
        dst[2*i + 0] = v.x;
        dst[2*i + 1] = v.y;
    }
}
9496
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
225227 memcpy (dsti->qh , &qh, sizeof (qh));
226228}
227229
// Generic quantized-block -> f32 copy: dequantizes one block of qk values.
// Each dequant call produces a value pair whose second element the quant
// layout places qk/2 positions after the first, hence the split stores.
template <dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * dst = (float *) cdsti;

#pragma unroll
    for (int i = 0; i < qk/2; ++i) {
        dfloat2 v;
        dequant(cxi, 0, i, v);
        dst[i]        = v.x;
        dst[i + qk/2] = v.y;
    }
}
228242
229243static __device__ __forceinline__ int best_index_int8 (int n, const int8_t * val, float x) {
230244 if (x <= val[0 ]) return 0 ;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
387401 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
388402}
389403
// Host wrapper: copy/dequantize a Q4_0 tensor into an f32 tensor on `stream`.
// Launches `ne` single-thread blocks, matching the other quant->f32 wrappers
// in this file; out-of-range work is expected to be guarded inside cpy_q_f32.
static void ggml_cpy_q4_0_f32_cuda(
        const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02,
        const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12,
        const int nb10, const int nb11, const int nb12, const int nb13,
        cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>
        <<<num_blocks, 1, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02,
                                       nb00, nb01, nb02, nb03,
                                       ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
416+
390417static void ggml_cpy_f32_q4_1_cuda (
391418 const char * cx, char * cdst, const int ne,
392419 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
398425 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
399426}
400427
// Host wrapper: copy/dequantize a Q4_1 tensor into an f32 tensor on `stream`.
// Launches `ne` single-thread blocks, matching the other quant->f32 wrappers
// in this file; out-of-range work is expected to be guarded inside cpy_q_f32.
static void ggml_cpy_q4_1_f32_cuda(
        const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02,
        const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12,
        const int nb10, const int nb11, const int nb12, const int nb13,
        cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>
        <<<num_blocks, 1, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02,
                                       nb00, nb01, nb02, nb03,
                                       ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
440+
401441static void ggml_cpy_f32_q5_0_cuda (
402442 const char * cx, char * cdst, const int ne,
403443 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
409449 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
410450}
411451
// Host wrapper: copy/dequantize a Q5_0 tensor into an f32 tensor on `stream`.
// Launches `ne` single-thread blocks, matching the other quant->f32 wrappers
// in this file; out-of-range work is expected to be guarded inside cpy_q_f32.
static void ggml_cpy_q5_0_f32_cuda(
        const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02,
        const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12,
        const int nb10, const int nb11, const int nb12, const int nb13,
        cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>
        <<<num_blocks, 1, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02,
                                       nb00, nb01, nb02, nb03,
                                       ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
464+
412465static void ggml_cpy_f32_q5_1_cuda (
413466 const char * cx, char * cdst, const int ne,
414467 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
420473 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
421474}
422475
// Host wrapper: copy/dequantize a Q5_1 tensor into an f32 tensor on `stream`.
// Launches `ne` single-thread blocks, matching the other quant->f32 wrappers
// in this file; out-of-range work is expected to be guarded inside cpy_q_f32.
static void ggml_cpy_q5_1_f32_cuda(
        const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02,
        const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12,
        const int nb10, const int nb11, const int nb12, const int nb13,
        cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>
        <<<num_blocks, 1, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02,
                                       nb00, nb01, nb02, nb03,
                                       ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
488+
423489static void ggml_cpy_f32_iq4_nl_cuda (
424490 const char * cx, char * cdst, const int ne,
425491 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
488554 ggml_cpy_q8_0_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
489555 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
490556 ggml_cpy_f32_q4_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
557+ } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
558+ ggml_cpy_q4_0_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
559+ nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
491560 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
492561 ggml_cpy_f32_q4_1_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
562+ } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
563+ ggml_cpy_q4_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
564+ nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
493565 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
494566 ggml_cpy_f32_q5_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
567+ } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
568+ ggml_cpy_q5_0_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
569+ nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
495570 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
496571 ggml_cpy_f32_iq4_nl_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
497572 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
498573 ggml_cpy_f32_q5_1_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
574+ } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
575+ ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
499576 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
500577 ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
501578 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
524601 return (void *) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
525602 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
526603 return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
604+ } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
605+ return (void *) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
527606 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
528607 return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
608+ } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
609+ return (void *) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
529610 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
530611 return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
612+ } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
613+ return (void *) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
531614 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
532615 return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
533616 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
534617 return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
618+ } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
619+ return (void *) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
535620 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
536621 return (void *) cpy_f32_f16<cpy_1_f32_f16>;
537622 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
0 commit comments