11#include  " cpy.hpp" 
22
33#include  < float.h> 
4+ #include  < string> 
45
56#include  " dequantize.hpp" 
7+ #include  " ggml-sycl/common.hpp" 
8+ #include  " ggml-sycl/presets.hpp" 
9+ #include  " ggml.h" 
610
711static  __dpct_inline__ int  best_index_int8 (int  n, const  int8_t  * val, float  x) {
812    if  (x <= val[0 ]) {
@@ -116,6 +120,15 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
116120    }
117121}
118122
/*
 * Same-type block copy: reinterprets both byte pointers as a single T and
 * assigns the source object to the destination object.
 *
 *   cxi   - address of the source block, read as one `const T`
 *   cdsti - address of the destination block, written as one `T`
 *
 * No conversion is performed; exactly one T is copied.
 */
template <typename T>
static void cpy_blck_q_q(const char * cxi, char * cdsti) {
    const T & src_block = *reinterpret_cast<const T *>(cxi);
    T &       dst_block = *reinterpret_cast<T *>(cdsti);

    dst_block = src_block;
}
130+ 
131+ 
119132static  void  cpy_blck_q8_0_f32 (const  char  * cxi, char  * cdsti) {
120133    float  * cdstf = (float  *) (cdsti);
121134
@@ -311,6 +324,34 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
311324    }
312325}
313326
327+ 
328+ template  <typename  T, int  qk>
329+ static  void  cpy_q_q (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01, const  int  ne02,
330+                       const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03, const  int  ne10, const  int  ne11,
331+                       const  int  ne12, const  int  nb10, const  int  nb11, const  int  nb12, const  int  nb13,
332+                       const  sycl::nd_item<3 > & item_ct1) {
333+     const  int  i = (item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 )) * qk;
334+ 
335+     if  (i >= ne) {
336+         return ;
337+     }
338+ 
339+     const  int  i03      = i / (ne00 * ne01 * ne02);
340+     const  int  i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
341+     const  int  i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
342+     const  int  i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
343+     const  int  x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
344+ 
345+ 
346+     const  int  i13        = i / (ne10 * ne11 * ne12);
347+     const  int  i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
348+     const  int  i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
349+     const  int  i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
350+     const  int  dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
351+ 
352+     cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
353+ }
354+ 
314355template  <cpy_kernel_t  cpy_blck, int  qk>
315356static  void  cpy_f32_q (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01, const  int  ne02,
316357                      const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03, const  int  ne10, const  int  ne11,
@@ -322,6 +363,7 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00
322363        return ;
323364    }
324365
366+ 
325367    const  int  i03      = i / (ne00 * ne01 * ne02);
326368    const  int  i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
327369    const  int  i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
@@ -615,6 +657,70 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
615657    }
616658}
617659
660+ static  void  ggml_cpy_q8_0_q8_0 (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01,
661+                                    const  int  ne02, const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03,
662+                                    const  int  ne10, const  int  ne11, const  int  ne12, const  int  nb10, const  int  nb11,
663+                                    const  int  nb12, const  int  nb13, queue_ptr stream) {
664+     const  int  num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
665+     stream->parallel_for (
666+         sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE),
667+                               sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
668+             cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
669+         });
670+ }
671+ 
672+ 
673+ static  void  ggml_cpy_q5_0_q5_0 (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01,
674+                                    const  int  ne02, const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03,
675+                                    const  int  ne10, const  int  ne11, const  int  ne12, const  int  nb10, const  int  nb11,
676+                                    const  int  nb12, const  int  nb13, queue_ptr stream) {
677+     const  int  num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
678+     stream->parallel_for (
679+         sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE),
680+                               sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
681+             cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
682+         });
683+ }
684+ 
685+ 
686+ static  void  ggml_cpy_q5_1_q5_1 (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01,
687+                                    const  int  ne02, const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03,
688+                                    const  int  ne10, const  int  ne11, const  int  ne12, const  int  nb10, const  int  nb11,
689+                                    const  int  nb12, const  int  nb13, queue_ptr stream) {
690+     const  int  num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
691+ 
692+     stream->parallel_for (
693+         sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE),
694+                               sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
695+             cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
696+         });
697+ }
698+ 
699+ 
700+ static  void  ggml_cpy_q4_0_q4_0 (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01,
701+                                    const  int  ne02, const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03,
702+                                    const  int  ne10, const  int  ne11, const  int  ne12, const  int  nb10, const  int  nb11,
703+                                    const  int  nb12, const  int  nb13, queue_ptr stream) {
704+     const  int  num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
705+     stream->parallel_for (
706+         sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE), sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
707+             cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
708+         });
709+ }
710+ 
711+ 
712+ static  void  ggml_cpy_q4_1_q4_1 (const  char  * cx, char  * cdst, const  int  ne, const  int  ne00, const  int  ne01,
713+                                    const  int  ne02, const  int  nb00, const  int  nb01, const  int  nb02, const  int  nb03,
714+                                    const  int  ne10, const  int  ne11, const  int  ne12, const  int  nb10, const  int  nb11,
715+                                    const  int  nb12, const  int  nb13, queue_ptr stream) {
716+ 
717+    const  int  num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
718+    stream->parallel_for (
719+         sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE), sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
720+             cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
721+         });
722+ }
723+ 
618724void  ggml_sycl_cpy (ggml_backend_sycl_context & ctx, const  ggml_tensor * src0, const  ggml_tensor * src1) try {
619725    //  Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
620726    scope_op_debug_print scope_dbg_print (__func__, src1, /* num_src=*/ 0 ,
@@ -632,8 +738,10 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
632738
633739    char  * src0_ddc = (char  *) src0->data ;
634740    char  * src1_ddc = (char  *) src1->data ;
635- 
636-     if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F32) {
741+     if  ((src0->type  == src1->type ) && (ggml_is_contiguous (src0) && ggml_is_contiguous (src1))) {
742+         GGML_SYCL_DEBUG (" %s: memcpy path\n " 
743+         main_stream->memcpy (src1_ddc, src0_ddc, ggml_nbytes (src0));
744+     } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F32) {
637745        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
638746                              nb11, nb12, nb13, main_stream);
639747    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F16) {
@@ -684,6 +792,16 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
684792    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_IQ4_NL) {
685793        ggml_cpy_f32_iq4_nl_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
686794                                 nb10, nb11, nb12, nb13, main_stream);
795+     } else  if  (src0->type  == GGML_TYPE_Q8_0 && src1->type  == GGML_TYPE_Q8_0) {
796+         ggml_cpy_q8_0_q8_0 (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
797+     } else  if  (src0->type  == GGML_TYPE_Q5_0 && src1->type  == GGML_TYPE_Q5_0) {
798+         ggml_cpy_q5_0_q5_0 (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
799+     } else  if  (src0->type  == GGML_TYPE_Q5_1 && src1->type  == GGML_TYPE_Q5_1) {
800+         ggml_cpy_q5_1_q5_1 (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
801+     } else  if  (src0->type  == GGML_TYPE_Q4_0 && src1->type  == GGML_TYPE_Q4_0) {
802+         ggml_cpy_q4_0_q4_0 (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
803+     } else  if  (src0->type  == GGML_TYPE_Q4_1 && src1->type  == GGML_TYPE_Q4_1) {
804+         ggml_cpy_q4_1_q4_1 (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
687805    } else  {
688806        GGML_LOG_ERROR (" %s: unsupported type combination (%s to %s)\n " ggml_type_name (src0->type ),
689807                       ggml_type_name (src1->type ));
0 commit comments