@@ -88,6 +88,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
8888 }
8989}
9090
// Dequantize a single q8_0 block: multiply each of its QK8_0 int8 quants by
// the block's scale factor d (stored as half, widened to float once) and
// write the results as 32-bit floats.
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    const block_q8_0 * src = (const block_q8_0 *) cxi;
    float            * dst = (float *) cdsti;

    const float scale = (float) src->d;

#pragma unroll
    for (int k = 0; k < QK8_0; k++) {
        dst[k] = scale * src->qs[k];
    }
}
101+
91102static __device__ void cpy_blck_f32_q4_0 (const char * cxi, char * cdsti) {
92103 const float * xi = (const float *) cxi;
93104 block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -337,6 +348,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
337348 cpy_blck (cx + x_offset, cdst + dst_offset);
338349}
339350
// Generic quantized -> f32 copy kernel.
//
// Each thread handles one quantized block of qk elements: it unravels its
// flat element index into 4D coordinates for both the (quantized) source and
// the (float) destination, then hands the two byte offsets to cpy_blck.
// Threads whose starting index falls past ne exit at the bounds check, so
// any launch providing at least ceil(ne/qk) threads is valid.
//
// cpy_blck — per-block dequantization routine (e.g. cpy_blck_q8_0_f32)
// qk       — number of elements per quantized block
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                 const int nb12, const int nb13) {
    // Flat index of the first element this thread is responsible for.
    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

    // Unravel i into source coordinates (i00, i01, i02, i03), peeling off one
    // dimension at a time.
    int rs = i;
    const int i03 = rs / (ne00*ne01*ne02); rs -= i03*ne00*ne01*ne02;
    const int i02 = rs / (ne00*ne01);      rs -= i02*ne00*ne01;
    const int i01 = rs / ne00;
    const int i00 = rs - i01*ne00;
    // nb00 is the byte stride per quantized block, hence the division by qk.
    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    // Same unravelling against the destination extents; nb10 is the byte
    // stride per destination element.
    int rd = i;
    const int i13 = rd / (ne10*ne11*ne12); rd -= i13*ne10*ne11*ne12;
    const int i12 = rd / (ne10*ne11);      rd -= i12*ne10*ne11;
    const int i11 = rd / ne10;
    const int i10 = rd - i11*ne10;
    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}
376+
340377static void ggml_cpy_f16_f32_cuda (
341378 const char * cx, char * cdst, const int ne,
342379 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -388,6 +425,16 @@ static void ggml_cpy_f32_q8_0_cuda(
388425 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
389426}
390427
// Host-side launcher for the q8_0 -> f32 copy.
//
// Each cpy_q_f32 thread dequantizes one whole q8_0 block (QK8_0 elements),
// so only ceil(ne/QK8_0) threads are required.  Launching ne blocks would
// spawn QK8_0x more threads than needed, every excess one bailing out at the
// kernel's `i >= ne` guard — correct, but a 32x over-launch.
static void ggml_cpy_q8_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    const int num_blocks = (ne + QK8_0 - 1) / QK8_0; // ceil-div: one thread per q8_0 block
    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
437+
391438static void ggml_cpy_f32_q4_0_cuda (
392439 const char * cx, char * cdst, const int ne,
393440 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -509,6 +556,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
509556 ggml_cpy_f32_bf16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
510557 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
511558 ggml_cpy_f32_q8_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
559+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
560+ ggml_cpy_q8_0_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
512561 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
513562 ggml_cpy_f32_q4_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
514563 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -547,6 +596,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
547596 return (void *) cpy_f32_f16<cpy_1_f32_bf16>;
548597 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
549598 return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
599+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
600+ return (void *) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
550601 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
551602 return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
552603 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
0 commit comments