@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989 sycl::range<3 > gridDim (ne2, ne1, num_blocks);
9090 switch (dim) {
9191 case 0 :
92- sycl_parallel_for (stream,
93- sycl::nd_range<3 >(gridDim *
94- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
95- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
96- [=](sycl::nd_item<3 > item_ct1) {
97- concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1);
98- });
99- break ;
92+ sycl_parallel_for (stream,
93+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
94+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
95+ [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1); });
96+ break ;
10097 case 1 :
101- sycl_parallel_for (stream,
102- sycl::nd_range<3 >(gridDim *
103- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
104- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
105- [=](sycl::nd_item<3 > item_ct1) {
106- concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1);
107- });
108- break ;
98+ sycl_parallel_for (stream,
99+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
100+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
101+ [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1); });
102+ break ;
109103 // dim >=2 will be dispatched to the default path
110104 default :
111- sycl_parallel_for (stream,
112- sycl::nd_range<3 >(gridDim *
113- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115- [=](sycl::nd_item<3 > item_ct1) {
116- concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1);
117- });
118- break ;
105+ sycl_parallel_for (stream,
106+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
107+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
108+ [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1); });
109+ break ;
119110 }
120111}
121112
@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
129120 int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130121 uint64_t nb3, int32_t dim) {
131122 sycl::range<3 > gridDim (ne3, ne2, ne1);
132- sycl_parallel_for (stream,
133- sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )),
134- [=](sycl::nd_item<3 > item_ct1) {
135- int64_t i3 = item_ct1.get_group (0 );
136- int64_t i2 = item_ct1.get_group (1 );
137- int64_t i1 = item_ct1.get_group (2 );
123+ sycl_parallel_for (stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
124+ int64_t i3 = item_ct1.get_group (0 );
125+ int64_t i2 = item_ct1.get_group (1 );
126+ int64_t i1 = item_ct1.get_group (2 );
138127
139- int64_t o[4 ] = {0 , 0 , 0 , 0 };
140- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
128+ int64_t o[4 ] = { 0 , 0 , 0 , 0 };
129+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
141130
142- const float *x;
131+ const float * x;
143132
144- for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0;
145- i0 += item_ct1.get_local_range (2 )) {
133+ for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0; i0 += item_ct1.get_local_range (2 )) {
146134 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
147- x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
148- (i0)*nb00);
135+ x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
149136 } else {
150- x = (const float *)(src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 +
151- (i1 - o[ 1 ]) * nb11 + (i0 - o[0 ]) * nb10);
137+ x = (const float *) (src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 + (i1 - o[ 1 ]) * nb11 +
138+ (i0 - o[0 ]) * nb10);
152139 }
153140
154141 float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
155142
156143 *y = *x;
157- }
158- });
144+ }
145+ });
159146}
160147
161148void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
0 commit comments