@@ -1,4 +1,5 @@
 #include "cpy.cuh"
+#include "dequantize.cuh"
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -82,13 +83,25 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
 }
 
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
-    const block_q8_0 * xi = (const block_q8_0 *) cxi;
-    float * dsti = (float *) cdsti;
+    float* cdstf = (float*)(cdsti);
+
+    for (int j = 0; j < QK8_0; j+=2) {
+        float2 dq;
+        dequantize_q8_0(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + 1) = dq.y;
+    }
+}
 
-    const float d = (float)xi->d;
+template<dequantize_kernel_t dequant, int qk>
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+    float* cdstf = (float*)(cdsti);
 
-    for (int j = 0; j < QK8_0; j++) {
-        dsti[j] = xi->qs[j] * d;
+    for (int j = 0; j < qk/2; j++) {
+        float2 dq;
+        dequant(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + qk/2) = dq.y;
     }
 }
 
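A note on the pair ordering in the new cpy_blck_q_f32 template: each dequantize_* callback yields two values per packed byte, and the template stores dq.x at offset j and dq.y at offset j + qk/2, matching how the 4- and 5-bit block formats keep a block's low nibbles in the first half and its high nibbles in the second half (the q8_0 path above writes consecutive pairs instead). As a rough reference, here is a standalone sketch of the equivalent q4_0 dequantization, assuming the usual ggml block layout of an fp16 scale followed by QK4_0/2 packed nibble bytes; the struct and function names below are illustrative and not part of this change:

    #include <cuda_fp16.h>
    #include <stdint.h>

    #define QK4_0 32

    // Illustrative stand-in for ggml's block_q4_0: fp16 scale + 16 bytes of packed nibbles.
    typedef struct {
        half    d;
        uint8_t qs[QK4_0/2];
    } block_q4_0_ref;

    // Dequantize one block into QK4_0 floats. Byte j carries element j in its low
    // nibble and element j + QK4_0/2 in its high nibble, which is the same ordering
    // cpy_blck_q_f32 reproduces via dq.x / dq.y.
    __host__ __device__ static void dequantize_block_q4_0_ref(const block_q4_0_ref * x, float * y) {
        const float d = __half2float(x->d);
        for (int j = 0; j < QK4_0/2; ++j) {
            const int x0 = (x->qs[j] & 0x0F) - 8; // low nibble, first half of the block
            const int x1 = (x->qs[j] >>   4) - 8; // high nibble, second half of the block
            y[j]           = x0*d;
            y[j + QK4_0/2] = x1*d;
        }
    }

The same split-halves ordering holds for q4_1, q5_0, and q5_1 (with the fifth bit of the 5-bit formats taken from their separate qh field), which is why a single template covers all four formats below.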
@@ -225,7 +238,6 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
-
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;
@@ -420,6 +432,58 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +552,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +599,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
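For context on how these new branches get exercised: ggml lowers a graph-level copy between tensors of different types onto this file, so a quantized-to-F32 copy originates from an ordinary ggml_cpy node. Below is a rough host-side sketch, assuming the usual ggml C API; the buffer size and tensor length are arbitrary illustration values, and it only builds the node (running it still needs the usual backend and graph-allocation plumbing, omitted here):

    #include "ggml.h"

    int main(void) {
        // Small scratch context; 16 MiB is an arbitrary illustration value.
        struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, 256); // quantized source
        struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  256); // float destination

        // When computed on the CUDA backend, this copy is the Q4_0 -> F32 case
        // that ggml_cpy_q4_0_f32_cuda above now handles.
        struct ggml_tensor * out = ggml_cpy(ctx, src, dst);
        (void) out;

        ggml_free(ctx);
        return 0;
    }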