Skip to content

Commit dfaadfa

Browse files
committed
SYCL: Add non-contiguous input support to norm kernel
1 parent 3398305 commit dfaadfa

File tree

2 files changed

+44
-43
lines changed

2 files changed

+44
-43
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4160,6 +4160,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
41604160
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
41614161
#endif
41624162
case GGML_OP_NORM:
4163+
return true;
41634164
case GGML_OP_RMS_NORM:
41644165
case GGML_OP_L2_NORM:
41654166
case GGML_OP_GROUP_NORM:

ggml/src/ggml-sycl/norm.cpp

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,47 @@
11
#include "norm.hpp"
2+
#include "ggml-sycl/presets.hpp"
23

3-
static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
4-
const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
5-
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
6-
item_ct1.get_local_id(1);
7-
const int tid = item_ct1.get_local_id(2);
4+
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
5+
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
86

9-
const int nthreads = item_ct1.get_local_range(2);
10-
const int nwarps = nthreads / WARP_SIZE;
11-
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
7+
const int nrows = item_ct1.get_group_range(2);
8+
const int nchannels = item_ct1.get_group_range(1);
9+
int sample = item_ct1.get_group(0);
10+
int channel = item_ct1.get_group(1);
11+
int row = item_ct1.get_group(2);
1212

13+
int tid = item_ct1.get_local_id(2);
14+
15+
x += sample * stride_sample + channel * stride_channel + row * stride_row;
16+
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
17+
18+
sycl::float2 mean_var{0.f, 0.f};
1319
for (int col = tid; col < ncols; col += block_size) {
14-
const float xi = x[row * ncols + col];
20+
const float xi = x[col];
1521
mean_var.x() += xi;
1622
mean_var.y() += xi * xi;
1723
}
1824

1925
// sum up partial sums
2026
mean_var = warp_reduce_sum(mean_var, item_ct1);
21-
if (block_size > WARP_SIZE) {
22-
23-
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
24-
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
27+
if (block_size > WARP_SIZE) {
28+
int warp_id = tid / WARP_SIZE;
29+
int lane_id = tid % WARP_SIZE;
2530
if (lane_id == 0) {
2631
s_sum[warp_id] = mean_var;
2732
}
28-
/*
29-
DPCT1118:0: SYCL group functions and algorithms must be encountered in
30-
converged control flow. You may need to adjust the code.
31-
*/
3233
item_ct1.barrier(sycl::access::fence_space::local_space);
33-
mean_var = 0.f;
34-
size_t nreduce = nwarps / WARP_SIZE;
35-
for (size_t i = 0; i < nreduce; i += 1)
36-
{
37-
mean_var += s_sum[lane_id + i * WARP_SIZE];
38-
}
34+
35+
mean_var = s_sum[lane_id];
3936
mean_var = warp_reduce_sum(mean_var, item_ct1);
4037
}
4138

4239
const float mean = mean_var.x() / ncols;
43-
const float var = mean_var.y() / ncols - mean * mean;
40+
const float var = mean_var.y() / ncols - mean * mean;
4441
const float inv_std = sycl::rsqrt(var + eps);
4542

4643
for (int col = tid; col < ncols; col += block_size) {
47-
dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
44+
dst[col] = (x[col] - mean) * inv_std;
4845
}
4946
}
5047

@@ -224,20 +221,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
224221
}
225222
}
226223

227-
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
228-
const int nrows, const float eps,
229-
queue_ptr stream, int device) {
224+
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
225+
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
226+
const float eps, queue_ptr stream, int device) {
227+
228+
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
230229
GGML_ASSERT(ncols % WARP_SIZE == 0);
231230
if (ncols < 1024) {
232-
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
231+
const sycl::range<3> block_dims(1, 1, WARP_SIZE); // Equivalent to CUDA's (WARP_SIZE, 1, 1)
233232
stream->submit([&](sycl::handler& cgh) {
234233
cgh.parallel_for(
235-
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
236-
block_dims),
234+
sycl::nd_range<3>(global_dims * block_dims, block_dims),
237235
[=](sycl::nd_item<3> item_ct1)
238236
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
239-
norm_f32(x, dst, ncols, eps, item_ct1,
240-
nullptr, WARP_SIZE);
237+
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
241238
});
242239
});
243240
}
@@ -251,16 +248,13 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
251248
info::device::max_work_group_size. Adjust the work-group size if needed.
252249
*/
253250
stream->submit([&](sycl::handler& cgh) {
254-
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
255-
sycl::range<1>(work_group_size / WARP_SIZE), cgh);
251+
auto s_sum_acc_ct1 = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(work_group_size / WARP_SIZE), cgh);
256252

257253
cgh.parallel_for(
258-
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
259-
block_dims),
254+
sycl::nd_range<3>(global_dims * block_dims, block_dims),
260255
[=](sycl::nd_item<3> item_ct1)
261256
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
262-
norm_f32(x, dst, ncols, eps, item_ct1,
263-
get_pointer(s_sum_acc_ct1), work_group_size);
257+
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
264258
});
265259
});
266260
}
@@ -398,21 +392,27 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
398392
}
399393

400394
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
395+
const ggml_tensor * src0 = dst->src[0];
401396

402397
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
403398
GGML_ASSERT(dst->type == GGML_TYPE_F32);
404399

405-
const int64_t ne00 = dst->src[0]->ne[0];
406-
const int64_t nrows = ggml_nrows(dst->src[0]);
400+
GGML_TENSOR_UNARY_OP_LOCALS
407401
dpct::queue_ptr main_stream = ctx.stream();
408402
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
409403
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
410404
float * dst_dd = static_cast<float *>(dst->data);
411405

412406
float eps;
413407
memcpy(&eps, dst->op_params, sizeof(float));
414-
415-
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
408+
GGML_ASSERT(eps >= 0.0f);
409+
const size_t ts0 = ggml_type_size(src0->type);
410+
GGML_ASSERT(nb00 == ts0);
411+
const int64_t s01 = nb01 / ts0;
412+
const int64_t s02 = nb02 / ts0;
413+
const int64_t s03 = nb03 / ts0;
414+
415+
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
416416
}
417417

418418
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

0 commit comments

Comments
 (0)