1111//
1212
1313#include " concat.hpp"
14- #include " common.hpp"
1514
1615static inline size_t elem_size (ggml_type t) {
1716 return ggml_type_size (t) / ggml_blck_size (t);
@@ -96,21 +95,18 @@ static void concat_T_sycl(const T *x, const T *y, T *dst,
9695 sycl::range<3 > gridDim (ne2, ne1, num_blocks);
9796 switch (dim) {
9897 case 0 :
99- sycl_parallel_for (stream,
100- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
98+ stream->parallel_for (sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
10199 sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
102100 [=](sycl::nd_item<3 > item_ct1) { concat_T_dim0<T>(x, y, dst, ne0, ne00, item_ct1); });
103101 break ;
104102 case 1 :
105- sycl_parallel_for (stream,
106- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
103+ stream->parallel_for (sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
107104 sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
108105 [=](sycl::nd_item<3 > item_ct1) { concat_T_dim1<T>(x, y, dst, ne0, ne01, item_ct1); });
109106 break ;
110107 // dim >=2 will be dispatched to the default path
111108 default :
112- sycl_parallel_for (stream,
113- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
109+ stream->parallel_for (sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114110 sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115111 [=](sycl::nd_item<3 > item_ct1) { concat_T_dim2<T>(x, y, dst, ne0, ne02, item_ct1); });
116112 break ;
@@ -128,7 +124,7 @@ static void concat_T_sycl_non_cont(
128124 int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
129125 uint64_t nb3, int32_t dim) {
130126 sycl::range<3 > gridDim (ne3, ne2, ne1);
131- sycl_parallel_for ( stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
127+ stream-> parallel_for ( sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
132128 int64_t i3 = item_ct1.get_group (0 );
133129 int64_t i2 = item_ct1.get_group (1 );
134130 int64_t i1 = item_ct1.get_group (2 );
0 commit comments