1111//
1212
1313#include " concat.hpp"
14- #include " common.hpp"
1514
1615static inline size_t elem_size (ggml_type t) {
1716 return ggml_type_size (t) / ggml_blck_size (t);
@@ -97,21 +96,18 @@ static void concat_T_sycl(const T *x, const T *y, T *dst,
9796 sycl::range<3 > gridDim (ne2, ne1, num_blocks);
9897 switch (dim) {
9998 case 0 :
100- sycl_parallel_for (stream,
101- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
99+ stream->parallel_for (sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
102100 sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
103101 [=](sycl::nd_item<3 > item_ct1) { concat_T_dim0<T>(x, y, dst, ne0, ne00, item_ct1); });
104102 break ;
105103 case 1 :
106- sycl_parallel_for (stream,
107- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
104+ stream->parallel_for (sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
108105 sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
109106 [=](sycl::nd_item<3 > item_ct1) { concat_T_dim1<T>(x, y, dst, ne0, ne01, item_ct1); });
110107 break ;
111108 // dim >=2 will be dispatched to the default path
112109 default :
113- sycl_parallel_for (stream,
114- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
110+ stream->parallel_for (sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
115111 sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
116112 [=](sycl::nd_item<3 > item_ct1) { concat_T_dim2<T>(x, y, dst, ne0, ne02, item_ct1); });
117113 break ;
@@ -129,7 +125,7 @@ static void concat_T_sycl_non_cont(
129125 int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130126 uint64_t nb3, int32_t dim) {
131127 sycl::range<3 > gridDim (ne3, ne2, ne1);
132- sycl_parallel_for ( stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
128+ stream-> parallel_for ( sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
133129 int64_t i3 = item_ct1.get_group (0 );
134130 int64_t i2 = item_ct1.get_group (1 );
135131 int64_t i1 = item_ct1.get_group (2 );
0 commit comments