55
66#include " dequantize.hpp"
77#include " ggml-sycl/common.hpp"
8+ #include " ggml-sycl/presets.hpp"
89#include " ggml.h"
910
1011static __dpct_inline__ int best_index_int8 (int n, const int8_t * val, float x) {
@@ -660,10 +661,10 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
660661 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
661662 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
662663 const int nb12, const int nb13, queue_ptr stream) {
663-
664- const int num_blocks = ne;
664+ const int num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
665665 stream->parallel_for (
666- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
666+ sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE),
667+ sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
667668 cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
668669 });
669670}
@@ -673,10 +674,10 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
673674 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
674675 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
675676 const int nb12, const int nb13, queue_ptr stream) {
676-
677- const int num_blocks = ne;
677+ const int num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
678678 stream->parallel_for (
679- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
679+ sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE),
680+ sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
680681 cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
681682 });
682683}
@@ -686,10 +687,11 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
686687 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
687688 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
688689 const int nb12, const int nb13, queue_ptr stream) {
690+ const int num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
689691
690- const int num_blocks = ne;
691692 stream->parallel_for (
692- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
693+ sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE),
694+ sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3 > item_ct1) {
693695 cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
694696 });
695697}
@@ -699,10 +701,9 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
699701 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
700702 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
701703 const int nb12, const int nb13, queue_ptr stream) {
702-
703- const int num_blocks = ne;
704+ const int num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE);
704705 stream->parallel_for (
705- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
706+ sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE), sycl::range< 3 >( 1 , 1 , SYCL_CPY_BLOCK_SIZE )), [=](sycl::nd_item<3 > item_ct1) {
706707 cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
707708 });
708709}
@@ -713,9 +714,9 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
713714 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
714715 const int nb12, const int nb13, queue_ptr stream) {
715716
716- const int num_blocks = ne ;
717- stream->parallel_for (
718- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
717+ const int num_blocks = ceil_div (ne, SYCL_CPY_BLOCK_SIZE) ;
718+ stream->parallel_for (
719+ sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_CPY_BLOCK_SIZE), sycl::range< 3 >( 1 , 1 , SYCL_CPY_BLOCK_SIZE )), [=](sycl::nd_item<3 > item_ct1) {
719720 cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
720721 });
721722}
0 commit comments