 #include "norm.hpp"
+#include "ggml-sycl/common.hpp"
+#include "ggml-sycl/presets.hpp"

-static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
+static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+    const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
+
+    const int nrows = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);

     const int nthreads = item_ct1.get_local_range(2);
+    const int sample = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row = item_ct1.get_group(2);
+
+    const int tid = item_ct1.get_local_id(2);
     const int nwarps = nthreads / WARP_SIZE;
+
+    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+    x += strided_offset;
+    dst += packed_offset;
+
     sycl::float2 mean_var = sycl::float2(0.f, 0.f);

     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         mean_var.x() += xi;
         mean_var.y() += xi * xi;
     }

     // sum up partial sums
     mean_var = warp_reduce_sum(mean_var, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
+    if (block_size > WARP_SIZE) {
+        const auto sub_group = item_ct1.get_sub_group();
+        const auto sg_id = sub_group.get_group_linear_id();
+        const auto wi_in_sg = sub_group.get_local_linear_id();
+        if (wi_in_sg == 0) {
+            s_sum[sg_id] = mean_var;
         }
-        /*
-        DPCT1118:0: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
         item_ct1.barrier(sycl::access::fence_space::local_space);
         mean_var = 0.f;
-        size_t nreduce = nwarps / WARP_SIZE;
+        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
         for (size_t i = 0; i < nreduce; i += 1)
         {
-            mean_var += s_sum[lane_id + i * WARP_SIZE];
+            mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
         }
         mean_var = warp_reduce_sum(mean_var, item_ct1);
     }
@@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
     const float inv_std = sycl::rsqrt(var + eps);

     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
+        dst[col] = (x[col] - mean) * inv_std;
     }
 }

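The kernel routes its loads and stores through a calculate_offset<3> helper that this commit pulls in via ggml-sycl/common.hpp but does not show in the diff. A minimal sketch of what such a helper plausibly computes, assuming it simply dot-products strides with indices (the name and call shape mirror the call sites above; the actual definition may differ):

#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical sketch of calculate_offset<N>; the helper actually used lives
// in ggml-sycl/common.hpp.
template <std::size_t N>
static inline int64_t calculate_offset(const std::array<int64_t, N>& strides,
                                       const std::array<int64_t, N>& indices) {
    int64_t offset = 0;
    for (std::size_t i = 0; i < N; ++i) {
        offset += strides[i] * indices[i]; // element offset contributed by dimension i
    }
    return offset;
}

Read this way, strided_offset follows the source tensor's (possibly padded) element strides, while packed_offset uses the dense strides {nchannels * nrows * ncols, nrows * ncols, ncols}, so the destination is always written contiguously.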
@@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     }
 }

-static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
+static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+    const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+
+    const int nrows = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);
+
+    const int sample = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row = item_ct1.get_group(2);
+
     const int nthreads = item_ct1.get_local_range(2);
+
+    const int tid = item_ct1.get_local_id(2);
     const int nwarps = nthreads / WARP_SIZE;
+
+    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+    x += strided_offset;
+    dst += packed_offset;
+
+
     float tmp = 0.0f; // partial sum for thread in warp

     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }

     // sum up partial sums
     tmp = warp_reduce_sum(tmp, item_ct1);
     if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
+        const auto sub_group = item_ct1.get_sub_group();
+        const auto sg_id = sub_group.get_group_linear_id();
+        const auto wi_in_sg = sub_group.get_local_linear_id();
+        if (wi_in_sg == 0) {
+            s_sum[sg_id] = tmp;
         }
-        /*
-        DPCT1118:3: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
+
         item_ct1.barrier(sycl::access::fence_space::local_space);
-        size_t nreduce = nwarps / WARP_SIZE;
+        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
         tmp = 0.f;
         for (size_t i = 0; i < nreduce; i += 1)
         {
-            tmp += s_sum[lane_id + i * WARP_SIZE];
+            tmp += s_sum[wi_in_sg + i * WARP_SIZE];
         }
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
@@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     const float scale = sycl::rsqrt(mean + eps);

     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = scale * x[row * ncols + col];
+        dst[col] = scale * x[col];
     }
 }

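Both rewritten reductions above share two details. First, the hand-rolled warp_id/lane_id arithmetic becomes SYCL sub-group queries (get_sub_group(), get_group_linear_id(), get_local_linear_id()), which express the same warp/lane decomposition through the standard API. Second, nwarps / WARP_SIZE becomes ceil_div(nwarps, WARP_SIZE), so the iteration count rounds up instead of truncating to zero whenever nwarps is smaller than WARP_SIZE. ceil_div is also not shown in the diff; a sketch assuming the usual rounding-up idiom:

// Integer division rounding up: ceil_div(16, 32) == 1, whereas 16 / 32 == 0.
// Assumed definition; the helper actually used comes from the ggml-sycl headers.
template <typename T>
static constexpr T ceil_div(const T a, const T b) {
    return (a + b - 1) / b;
}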
@@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
     }
 }

-static void norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
+static void norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+    const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+    const float eps, queue_ptr stream, int device) {
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                            nullptr, WARP_SIZE);
+                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                 });
         });
     }
@@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
         */
         stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-
+                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                            get_pointer(s_sum_acc_ct1), work_group_size);
+                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                 });
         });
     }
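The host-side launch mirrors the kernel change: global_dims packs (nsamples, nchannels, nrows) into the three nd_range group dimensions, so each work-group normalizes exactly one row and the kernel recovers its (sample, channel, row) triple via get_group(0/1/2). As a worked example, with nsamples = 2, nchannels = 4, nrows = 8 and WARP_SIZE = 32, global_dims * block_dims in the small-ncols branch yields sycl::nd_range<3>({2, 4, 256}, {1, 1, 32}), i.e. 64 single-warp work-groups.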
@@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
 }

-static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
+static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+    const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                            nullptr, WARP_SIZE);
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                 });
         });
     }
@@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                 cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                            get_pointer(s_sum_acc_ct1), work_group_size);
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                 });
         });
     }
@@ -398,21 +414,27 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
 }

 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+    const ggml_tensor * src0 = dst->src[0];

     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
+    GGML_TENSOR_UNARY_OP_LOCALS
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     const float* src0_dd = static_cast<const float*>(dst->src[0]->data);
     float* dst_dd = static_cast<float*>(dst->data);

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
-
-    norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_ASSERT(eps >= 0.0f);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }

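GGML_TENSOR_UNARY_OP_LOCALS is the stock ggml macro that declares local copies of the operand's shape and byte strides; roughly (a simplified sketch of the upstream expansion, which also emits the matching locals for dst):

// const int64_t ne00 = src0->ne[0]; ... const int64_t ne03 = src0->ne[3];
// const size_t  nb00 = src0->nb[0]; ... const size_t  nb03 = src0->nb[3];

Because the nb0x values are byte strides while the kernels index in float elements, the wrapper divides them by ggml_type_size(src0->type); the GGML_ASSERT(nb00 == ts0) guard makes that division sound by requiring the innermost dimension to be contiguous.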
 void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

 void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

+    const ggml_tensor * src0 = dst->src[0];
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));

@@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

-    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_TENSOR_UNARY_OP_LOCALS
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }

 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {