11#include " norm.hpp"
2- #include " ggml-sycl/presets.hpp"
32
43static void norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
54 const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
65
76 const int nrows = item_ct1.get_group_range (2 );
87 const int nchannels = item_ct1.get_group_range (1 );
9- int sample = item_ct1.get_group (0 );
10- int channel = item_ct1.get_group (1 );
11- int row = item_ct1.get_group (2 );
8+ const int sample = item_ct1.get_group (0 );
9+ const int channel = item_ct1.get_group (1 );
10+ const int row = item_ct1.get_group (2 );
1211
13- int tid = item_ct1.get_local_id (2 );
12+ const int tid = item_ct1.get_local_id (2 );
1413
15- x += sample * stride_sample + channel * stride_channel + row * stride_row;
14+ x += sample * stride_sample + channel * stride_channel + row * stride_row;
1615 dst += ((sample * nchannels + channel) * nrows + row) * ncols;
1716
1817 sycl::float2 mean_var{0 .f , 0 .f };
@@ -132,17 +131,25 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
132131 }
133132}
134133
135- static void rms_norm_f32 (const float * x, float * dst, const int ncols, const float eps,
136- const sycl::nd_item<3 >& item_ct1, float * s_sum, int block_size) {
137- const int row = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) +
138- item_ct1.get_local_id (1 );
134+ static void rms_norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
135+ const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, float * s_sum, int block_size) {
136+
137+ const int nrows = item_ct1.get_group_range (2 );
138+ const int nchannels = item_ct1.get_group_range (1 );
139+ const int sample = item_ct1.get_group (0 );
140+ const int channel = item_ct1.get_group (1 );
141+ const int row = item_ct1.get_group (2 );
142+
139143 const int tid = item_ct1.get_local_id (2 );
140- const int nthreads = item_ct1.get_local_range (2 );
141- const int nwarps = nthreads / WARP_SIZE;
144+
145+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
146+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
147+
148+
142149 float tmp = 0 .0f ; // partial sum for thread in warp
143150
144151 for (int col = tid; col < ncols; col += block_size) {
145- const float xi = x[row * ncols + col];
152+ const float xi = x[col];
146153 tmp += xi * xi;
147154 }
148155
@@ -155,25 +162,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
155162 if (lane_id == 0 ) {
156163 s_sum[warp_id] = tmp;
157164 }
158- /*
159- DPCT1118:3: SYCL group functions and algorithms must be encountered in
160- converged control flow. You may need to adjust the code.
161- */
165+
162166 item_ct1.barrier (sycl::access::fence_space::local_space);
163- size_t nreduce = nwarps / WARP_SIZE;
164- tmp = 0 .f ;
165- for (size_t i = 0 ; i < nreduce; i += 1 )
166- {
167- tmp += s_sum[lane_id + i * WARP_SIZE];
168- }
167+ tmp = s_sum[lane_id];
169168 tmp = warp_reduce_sum (tmp, item_ct1);
170169 }
171170
172171 const float mean = tmp / ncols;
173172 const float scale = sycl::rsqrt (mean + eps);
174173
175174 for (int col = tid; col < ncols; col += block_size) {
176- dst[row * ncols + col] = scale * x[row * ncols + col];
175+ dst[col] = scale * x[col];
177176 }
178177}
179178
@@ -307,21 +306,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
307306 }
308307}
309308
310- static void rms_norm_f32_sycl (const float * x, float * dst, const int ncols,
311- const int nrows, const float eps,
312- queue_ptr stream, int device) {
309+ static void rms_norm_f32_sycl (const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
310+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
313311 GGML_ASSERT (ncols % WARP_SIZE == 0 );
314312 // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
313+
314+ const sycl::range<3 > global_dims (nsamples, nchannels, nrows);
315315 if (ncols < 1024 ) {
316316 const sycl::range<3 > block_dims (1 , 1 , WARP_SIZE);
317317 stream->submit ([&](sycl::handler& cgh) {
318318 cgh.parallel_for (
319- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , nrows) * block_dims,
320- block_dims),
319+ sycl::nd_range<3 >(global_dims * block_dims, block_dims),
321320 [=](sycl::nd_item<3 > item_ct1)
322321 [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
323- rms_norm_f32 (x, dst, ncols, eps, item_ct1,
324- nullptr , WARP_SIZE);
322+ rms_norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr , WARP_SIZE);
325323 });
326324 });
327325 }
@@ -338,12 +336,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
338336 sycl::local_accessor<float , 1 > s_sum_acc_ct1 (sycl::range<1 >(work_group_size / WARP_SIZE),
339337 cgh);
340338 cgh.parallel_for (
341- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , nrows) * block_dims,
342- block_dims),
339+ sycl::nd_range<3 >(global_dims * block_dims, block_dims),
343340 [=](sycl::nd_item<3 > item_ct1)
344341 [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
345- rms_norm_f32 (x, dst, ncols, eps, item_ct1,
346- get_pointer (s_sum_acc_ct1), work_group_size);
342+ rms_norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
347343 });
348344 });
349345 }
@@ -436,11 +432,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
436432
437433void ggml_sycl_op_rms_norm (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
438434
435+ const ggml_tensor * src0 = dst->src [0 ];
439436 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
440437 GGML_ASSERT (dst->type == GGML_TYPE_F32);
441438
442- const int64_t ne00 = dst->src [0 ]->ne [0 ];
443- const int64_t nrows = ggml_nrows (dst->src [0 ]);
444439 dpct::queue_ptr main_stream = ctx.stream ();
445440 SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
446441
@@ -450,7 +445,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
450445 float eps;
451446 memcpy (&eps, dst->op_params , sizeof (float ));
452447
453- rms_norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device );
448+ GGML_TENSOR_UNARY_OP_LOCALS
449+ const size_t ts0 = ggml_type_size (src0->type );
450+ GGML_ASSERT (nb00 == ts0);
451+ const int64_t s01 = nb01 / ts0;
452+ const int64_t s02 = nb02 / ts0;
453+ const int64_t s03 = nb03 / ts0;
454+ rms_norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device );
454455}
455456
456457void ggml_sycl_op_l2_norm (ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
0 commit comments