@@ -225,6 +225,91 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+static __device__ void cpy_blck_q5_0_f32(const char * cxi, char * cdsti) {
+    const block_q5_0 * xi = (const block_q5_0 *) cxi;
+    float * dst = (float *) cdsti;
+    const float d = xi->d; // scale factor (computed as vmax / -16)
+    const float shift = 16.0f;
+
+    // Safely copy the 32-bit qh value to avoid misaligned access.
+    unsigned int qh;
+    memcpy(&qh, xi->qh, sizeof(qh));
+
+    // First half: lower nibble stores element j.
+    for (int j = 0; j < QK5_0/2; j++) {
+        const uint8_t lower = xi->qs[j] & 0xF;
+        const uint8_t high  = (qh >> j) & 1;
+        const uint8_t q     = (high << 4) | lower;
+        dst[j] = ((float) q - shift) * d;
+    }
+    // Second half: upper nibble stores element j + QK5_0/2.
+    for (int j = QK5_0/2; j < QK5_0; j++) {
+        const int k = j - QK5_0/2;
+        const uint8_t lower = (xi->qs[k] >> 4) & 0xF;
+        const uint8_t high  = (qh >> j) & 1;
+        const uint8_t q     = (high << 4) | lower;
+        dst[j] = ((float) q - shift) * d;
+    }
+}
+
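The two loops above walk ggml's standard Q5_0 layout: 32 elements per block, low nibbles packed two per byte in qs (element j in the low nibble of qs[j], element j + 16 in the high nibble of the same byte), and one extra high bit per element in the 32-bit qh. For reference, here is a minimal host-side sketch of the same decode; the struct is a simplified stand-in for block_q5_0 (fp32 scale here, where ggml stores d as fp16):

#include <stdint.h>
#include <string.h>

#define QK5_0 32

// Simplified stand-in for block_q5_0 (illustration only).
typedef struct {
    float   d;             // scale (ggml uses fp16 here)
    uint8_t qh[4];         // one high bit per element
    uint8_t qs[QK5_0 / 2]; // two low nibbles per byte
} ref_block_q5_0;

static void ref_dequant_q5_0(const ref_block_q5_0 * b, float * dst) {
    uint32_t qh;
    memcpy(&qh, b->qh, sizeof(qh));
    for (int j = 0; j < QK5_0 / 2; ++j) {
        const uint8_t lo0 = b->qs[j] & 0x0F;        // element j
        const uint8_t lo1 = (b->qs[j] >> 4) & 0x0F; // element j + 16
        const uint8_t q0  = (uint8_t)(((qh >> j) & 1) << 4) | lo0;
        const uint8_t q1  = (uint8_t)(((qh >> (j + QK5_0/2)) & 1) << 4) | lo1;
        dst[j]           = ((float) q0 - 16.0f) * b->d;
        dst[j + QK5_0/2] = ((float) q1 - 16.0f) * b->d;
    }
}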
+static __device__ void cpy_blck_q5_1_f32(const char * cxi, char * cdsti) {
+    const block_q5_1 * xi = (const block_q5_1 *) cxi;
+    float * dst = (float *) cdsti;
+    const float d       = xi->dm.x; // scale
+    const float min_val = xi->dm.y; // minimum value
+
+    // Safely copy the 32-bit qh value to avoid misaligned access.
+    unsigned int qh;
+    memcpy(&qh, xi->qh, sizeof(qh));
+
+    // Decode first half: lower nibble of xi->qs[j] holds element j.
+    for (int j = 0; j < QK5_1/2; j++) {
+        const uint8_t lower = xi->qs[j] & 0xF;
+        const uint8_t high  = (qh >> j) & 1;
+        const uint8_t q     = (high << 4) | lower;
+        dst[j] = min_val + d * (float) q;
+    }
+    // Decode second half: upper nibble of xi->qs[j] holds element j + QK5_1/2.
+    for (int j = QK5_1/2; j < QK5_1; j++) {
+        const int k = j - QK5_1/2;
+        const uint8_t lower = (xi->qs[k] >> 4) & 0xF;
+        const uint8_t high  = (qh >> j) & 1;
+        const uint8_t q     = (high << 4) | lower;
+        dst[j] = min_val + d * (float) q;
+    }
+}
+
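A worked example of the Q5_1 affine decode, with illustrative numbers: high = 1 and lower = 0x5 give q = (1 << 4) | 5 = 21, so a block with d = 0.1 and min_val = -1.6 decodes that element to -1.6 + 0.1 * 21 = 0.5. Q5_0 above has no per-block minimum; it centers the 5-bit range by subtracting 16 instead.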
+static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
+    const block_q4_0 * xi = (const block_q4_0 *) cxi;
+    float * dst = (float *) cdsti;
+    const float d = xi->d;
+    const float shift = 8.0f;
+
+    // Each byte packs two 4-bit quantized values.
+    for (int j = 0; j < QK4_0/2; j++) {
+        const uint8_t q_val = xi->qs[j];
+        const uint8_t q0 = q_val & 0x0F;
+        const uint8_t q1 = (q_val >> 4) & 0x0F;
+        dst[j]           = ((float) q0 - shift) * d;
+        dst[j + QK4_0/2] = ((float) q1 - shift) * d;
+    }
+}
+
+static __device__ void cpy_blck_q4_1_f32(const char * cxi, char * cdsti) {
+    const block_q4_1 * xi = (const block_q4_1 *) cxi;
+    float * dst = (float *) cdsti;
+    const float d    = xi->dm.x;
+    const float vmin = xi->dm.y;
+
+    // Each byte packs two 4-bit quantized values.
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const uint8_t byte_val = xi->qs[j];
+        const uint8_t q0 = byte_val & 0x0F;
+        const uint8_t q1 = (byte_val >> 4) & 0x0F;
+        dst[j]           = vmin + d * (float) q0;
+        dst[j + QK4_1/2] = vmin + d * (float) q1;
+    }
+}
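To make the nibble packing concrete: a byte value of 0xA3 splits into q0 = 0x3 = 3 (low nibble, element j) and q1 = 0xA = 10 (high nibble, element j + QK4_0/2). With an illustrative Q4_0 scale d = 0.5 this decodes to dst[j] = (3 - 8) * 0.5 = -2.5 and dst[j + 16] = (10 - 8) * 0.5 = 1.0; the Q4_1 path with d = 0.5 and vmin = -4.0 produces the same two values via vmin + d * q.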
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -420,6 +505,58 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q5_1_f32, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q5_0_f32, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_1_f32, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_0_f32, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
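All four launchers mirror the existing ggml_cpy_q8_0_f32_cuda path: one launch per copy, the grid sized from the element count ne, and the per-type behavior injected into the shared cpy_q_f32 kernel through template arguments (the block-decode function plus its quantization block size). The sketch below shows the shape of that pattern for the contiguous case; it is illustrative only, not ggml's actual cpy_q_f32, and omits the ne0x/nb0x stride bookkeeping the real kernel performs:

// Illustrative sketch: a copy kernel templated on a device block-decode
// functor and its quant block size.
typedef void (*dequant_blck_t)(const char * src_block, char * dst_elems);

template <dequant_blck_t dequant_blck, int qk, int block_bytes>
static __global__ void dequant_copy_sketch(const char * cx, char * cdst, const int ne) {
    // each thread decodes one quant block of qk elements
    const int i = (blockIdx.x * blockDim.x + threadIdx.x) * qk;
    if (i >= ne) {
        return;
    }
    dequant_blck(cx + (i / qk) * block_bytes, cdst + i * sizeof(float));
}

// usage, e.g.:
//   dequant_copy_sketch<cpy_blck_q5_0_f32, QK5_0, sizeof(block_q5_0)>
//       <<<num_blocks, 1, 0, stream>>>(cx, cdst, ne);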
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +625,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
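For context, a hypothetical usage sketch (the ctx handle and the 1-D shape are illustrative, not from this PR): with these branches in place, a plain ggml_cpy from a quantized tensor to an F32 tensor of the same shape dequantizes on the GPU through the kernels above.

// Hypothetical example: graph op that copies (and dequantizes) Q4_0 -> F32.
// On the CUDA backend this now dispatches to ggml_cpy_q4_0_f32_cuda().
struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, 4096);
struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  4096);
struct ggml_tensor * out = ggml_cpy(ctx, src, dst);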
@@ -524,14 +672,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_1_f32, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q5_0_f32, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q5_1_f32, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {