@@ -89,24 +89,33 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989 sycl::range<3 > gridDim (ne2, ne1, num_blocks);
9090 switch (dim) {
9191 case 0 :
92- sycl_parallel_for (stream,
93- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
94- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
95- [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1); });
96- break ;
92+ stream->parallel_for (
93+ sycl::nd_range<3 >(gridDim *
94+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
95+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
96+ [=](sycl::nd_item<3 > item_ct1) {
97+ concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1);
98+ });
99+ break ;
97100 case 1 :
98- sycl_parallel_for (stream,
99- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
100- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
101- [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1); });
102- break ;
101+ stream->parallel_for (
102+ sycl::nd_range<3 >(gridDim *
103+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
104+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
105+ [=](sycl::nd_item<3 > item_ct1) {
106+ concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1);
107+ });
108+ break ;
103109 // dim >=2 will be dispatched to the default path
104110 default :
105- sycl_parallel_for (stream,
106- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
107- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
108- [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1); });
109- break ;
111+ stream->parallel_for (
112+ sycl::nd_range<3 >(gridDim *
113+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115+ [=](sycl::nd_item<3 > item_ct1) {
116+ concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1);
117+ });
118+ break ;
110119 }
111120}
112121
@@ -120,7 +129,7 @@ static void concat_f32_sycl_non_cont(
120129 int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
121130 uint64_t nb3, int32_t dim) {
122131 sycl::range<3 > gridDim (ne3, ne2, ne1);
123- sycl_parallel_for ( stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
132+ stream-> parallel_for ( sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
124133 int64_t i3 = item_ct1.get_group (0 );
125134 int64_t i2 = item_ct1.get_group (1 );
126135 int64_t i1 = item_ct1.get_group (2 );
0 commit comments