@@ -120,36 +120,13 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
120120}
121121
122122/* quantized type same copy */
123- static void cpy_block_q8_0_q8_0 (const char * cxi, char * cdsti) {
124- const block_q8_0 * xi = (const block_q8_0 *) cxi;
125- block_q8_0 * dsti = (block_q8_0 *) cdsti;
123+ template <typename T>
124+ static void cpy_blck_q_q (const char * cxi, char * cdsti) {
125+ const T * xi = (const T *) cxi;
126+ T * dsti = (T *) cdsti;
126127 *dsti = *xi;
127128}
128129
129- static void cpy_block_q5_0_q5_0 (const char * cxi, char * cdsti) {
130- const block_q5_0 * xi = (const block_q5_0 *) cxi;
131- block_q5_0 * dsti = (block_q5_0 *) cdsti;
132- *dsti = *xi;
133- }
134-
135-
136- static void cpy_block_q5_1_q5_1 (const char * cxi, char * cdsti) {
137- const block_q5_1 * xi = (const block_q5_1 *) cxi;
138- block_q5_1 * dsti = (block_q5_1 *) cdsti;
139- *dsti = *xi;
140- }
141-
142- static void cpy_block_q4_0_q4_0 (const char * cxi, char * cdsti) {
143- const block_q4_0 * xi = (const block_q4_0 *) cxi;
144- block_q4_0 * dsti = (block_q4_0 *) cdsti;
145- *dsti = *xi;
146- }
147-
148- static void cpy_block_q4_1_q4_1 (const char * cxi, char * cdsti) {
149- const block_q4_1 * xi = (const block_q4_1 *) cxi;
150- block_q4_1 * dsti = (block_q4_1 *) cdsti;
151- *dsti = *xi;
152- }
153130
154131static void cpy_blck_q8_0_f32 (const char * cxi, char * cdsti) {
155132 float * cdstf = (float *) (cdsti);
@@ -347,7 +324,7 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
347324}
348325
349326
350- template <cpy_kernel_t cpy_blck , int qk>
327+ template <typename T , int qk>
351328static void cpy_q_q (const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
352329 const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
353330 const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
@@ -371,7 +348,7 @@ static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00,
371348 const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
372349 const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
373350
374- cpy_blck (cx + x_offset, cdst + dst_offset);
351+ cpy_blck_q_q<T> (cx + x_offset, cdst + dst_offset);
375352}
376353
377354template <cpy_kernel_t cpy_blck, int qk>
@@ -687,7 +664,7 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
687664 const int num_blocks = ne;
688665 stream->parallel_for (
689666 sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
690- cpy_q_q<cpy_block_q8_0_q8_0 , QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
667+ cpy_q_q<block_q8_0 , QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
691668 });
692669}
693670
@@ -700,7 +677,7 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
700677 const int num_blocks = ne;
701678 stream->parallel_for (
702679 sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
703- cpy_q_q<cpy_block_q5_0_q5_0 , QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
680+ cpy_q_q<block_q5_0 , QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
704681 });
705682}
706683
@@ -713,7 +690,7 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
713690 const int num_blocks = ne;
714691 stream->parallel_for (
715692 sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
716- cpy_q_q<cpy_block_q5_1_q5_1 , QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
693+ cpy_q_q<block_q5_1 , QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
717694 });
718695}
719696
@@ -726,7 +703,7 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
726703 const int num_blocks = ne;
727704 stream->parallel_for (
728705 sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
729- cpy_q_q<cpy_block_q4_0_q4_0 , QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
706+ cpy_q_q<block_q4_0 , QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
730707 });
731708}
732709
@@ -739,7 +716,7 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
739716 const int num_blocks = ne;
740717 stream->parallel_for (
741718 sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
742- cpy_q_q<cpy_block_q4_1_q4_1 , QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
719+ cpy_q_q<block_q4_1 , QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
743720 });
744721}
745722
0 commit comments