#include "norm.hpp"
#include "ggml-sycl/presets.hpp"
23
3- static void norm_f32 (const float * x, float * dst, const int ncols, const float eps,
4- const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
5- const int row = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) +
6- item_ct1.get_local_id (1 );
7- const int tid = item_ct1.get_local_id (2 );
4+ static void norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
5+ const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
86
9- const int nthreads = item_ct1.get_local_range (2 );
10- const int nwarps = nthreads / WARP_SIZE;
11- sycl::float2 mean_var = sycl::float2 (0 .f , 0 .f );
7+ const int nrows = item_ct1.get_group_range (2 );
8+ const int nchannels = item_ct1.get_group_range (1 );
9+ int sample = item_ct1.get_group (0 );
10+ int channel = item_ct1.get_group (1 );
11+ int row = item_ct1.get_group (2 );
1212
13+ int tid = item_ct1.get_local_id (2 );
14+
15+ x += sample * stride_sample + channel * stride_channel + row * stride_row;
16+ dst += ((sample * nchannels + channel) * nrows + row) * ncols;
17+
18+ sycl::float2 mean_var{0 .f , 0 .f };
1319 for (int col = tid; col < ncols; col += block_size) {
14- const float xi = x[row * ncols + col];
20+ const float xi = x[col];
1521 mean_var.x () += xi;
1622 mean_var.y () += xi * xi;
1723 }
1824
1925 // sum up partial sums
2026 mean_var = warp_reduce_sum (mean_var, item_ct1);
21- if (block_size > WARP_SIZE) {
22-
23- int warp_id = item_ct1.get_local_id (2 ) / WARP_SIZE;
24- int lane_id = item_ct1.get_local_id (2 ) % WARP_SIZE;
27+ if (block_size > WARP_SIZE) {
28+ int warp_id = tid / WARP_SIZE;
29+ int lane_id = tid % WARP_SIZE;
2530 if (lane_id == 0 ) {
2631 s_sum[warp_id] = mean_var;
2732 }
28- /*
29- DPCT1118:0: SYCL group functions and algorithms must be encountered in
30- converged control flow. You may need to adjust the code.
31- */
3233 item_ct1.barrier (sycl::access::fence_space::local_space);
33- mean_var = 0 .f ;
34- size_t nreduce = nwarps / WARP_SIZE;
35- for (size_t i = 0 ; i < nreduce; i += 1 )
36- {
37- mean_var += s_sum[lane_id + i * WARP_SIZE];
38- }
34+
35+ mean_var = s_sum[lane_id];
3936 mean_var = warp_reduce_sum (mean_var, item_ct1);
4037 }
4138
4239 const float mean = mean_var.x () / ncols;
43- const float var = mean_var.y () / ncols - mean * mean;
40+ const float var = mean_var.y () / ncols - mean * mean;
4441 const float inv_std = sycl::rsqrt (var + eps);
4542
4643 for (int col = tid; col < ncols; col += block_size) {
47- dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
44+ dst[col] = (x[col] - mean) * inv_std;
4845 }
4946}
5047
@@ -224,20 +221,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
224221 }
225222}
226223
227- static void norm_f32_sycl (const float * x, float * dst, const int ncols,
228- const int nrows, const float eps,
229- queue_ptr stream, int device) {
224+ static void norm_f32_sycl (const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
225+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
226+ const float eps, queue_ptr stream, int device) {
227+
228+ const sycl::range<3 > global_dims (nsamples, nchannels, nrows);
230229 GGML_ASSERT (ncols % WARP_SIZE == 0 );
231230 if (ncols < 1024 ) {
232- const sycl::range<3 > block_dims (1 , 1 , WARP_SIZE);
231+ const sycl::range<3 > block_dims (1 , 1 , WARP_SIZE); // Equivalent to CUDA's (WARP_SIZE, 1, 1)
233232 stream->submit ([&](sycl::handler& cgh) {
234233 cgh.parallel_for (
235- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , nrows) * block_dims,
236- block_dims),
234+ sycl::nd_range<3 >(global_dims * block_dims, block_dims),
237235 [=](sycl::nd_item<3 > item_ct1)
238236 [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
239- norm_f32 (x, dst, ncols, eps, item_ct1,
240- nullptr , WARP_SIZE);
237+ norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr , WARP_SIZE);
241238 });
242239 });
243240 }
@@ -251,16 +248,13 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
251248 info::device::max_work_group_size. Adjust the work-group size if needed.
252249 */
253250 stream->submit ([&](sycl::handler& cgh) {
254- sycl::local_accessor<sycl::float2, 1 > s_sum_acc_ct1 (
255- sycl::range<1 >(work_group_size / WARP_SIZE), cgh);
251+ auto s_sum_acc_ct1 = sycl::local_accessor<sycl::float2, 1 >(sycl::range<1 >(work_group_size / WARP_SIZE), cgh);
256252
257253 cgh.parallel_for (
258- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , nrows) * block_dims,
259- block_dims),
254+ sycl::nd_range<3 >(global_dims * block_dims, block_dims),
260255 [=](sycl::nd_item<3 > item_ct1)
261256 [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
262- norm_f32 (x, dst, ncols, eps, item_ct1,
263- get_pointer (s_sum_acc_ct1), work_group_size);
257+ norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
264258 });
265259 });
266260 }
@@ -398,21 +392,27 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
398392}
399393
400394void ggml_sycl_op_norm (ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
395+ const ggml_tensor * src0 = dst->src [0 ];
401396
402397 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
403398 GGML_ASSERT (dst->type == GGML_TYPE_F32);
404399
405- const int64_t ne00 = dst->src [0 ]->ne [0 ];
406- const int64_t nrows = ggml_nrows (dst->src [0 ]);
400+ GGML_TENSOR_UNARY_OP_LOCALS
407401 dpct::queue_ptr main_stream = ctx.stream ();
408402 SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
409403 const float * src0_dd = static_cast <const float *>(dst->src [0 ]->data );
410404 float * dst_dd = static_cast <float *>(dst->data );
411405
412406 float eps;
413407 memcpy (&eps, dst->op_params , sizeof (float ));
414-
415- norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device );
408+ GGML_ASSERT (eps >= 0 .0f );
409+ const size_t ts0 = ggml_type_size (src0->type );
410+ GGML_ASSERT (nb00 == ts0);
411+ const int64_t s01 = nb01 / ts0;
412+ const int64_t s02 = nb02 / ts0;
413+ const int64_t s03 = nb03 / ts0;
414+
415+ norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device );
416416}
417417
418418void ggml_sycl_op_group_norm (ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
0 commit comments