Skip to content

Commit 4b76ed1

Browse files
tamarPal
authored and committed
sycl: add ROLL operation support
- Implement ggml_sycl_roll function for F32 tensors
- Add multi-axis roll operation with SYCL kernel
- Support all 4 tensor dimensions with proper shift normalization
- Add roll.cpp and roll.hpp to SYCL backend
- Update backend dispatch and supports_op for GGML_OP_ROLL
- Tests: 17662/17662 pass with identical CPU reference results
1 parent 55754be commit 4b76ed1

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "pad.hpp"
3333
#include "quantize.hpp"
3434
#include "quants.hpp"
35+
#include "roll.hpp"
3536
#include "rope.hpp"
3637
#include "set_rows.hpp"
3738
#include "softmax.hpp"

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3836,6 +3836,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
38363836
case GGML_OP_GATED_LINEAR_ATTN:
38373837
ggml_sycl_op_gated_linear_attn(ctx, dst);
38383838
break;
3839+
case GGML_OP_ROLL:
3840+
ggml_sycl_roll(ctx, dst);
3841+
break;
38393842
case GGML_OP_ARANGE:
38403843
ggml_sycl_arange(ctx, dst);
38413844
break;
@@ -4491,6 +4494,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44914494
case GGML_OP_RWKV_WKV7:
44924495
case GGML_OP_GATED_LINEAR_ATTN:
44934496
return true;
4497+
case GGML_OP_ROLL:
4498+
return op->type == GGML_TYPE_F32;
44944499
case GGML_OP_ARANGE:
44954500
return op->type == GGML_TYPE_F32;
44964501
default:

ggml/src/ggml-sycl/roll.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "roll.hpp"
2+
#include "common.hpp"
3+
4+
using namespace sycl;
5+
6+
// Cyclically shift a 4-D F32 tensor along each dimension.
//
// dst[i0,i1,i2,i3] = src[(i0-sh0) mod ne0, (i1-sh1) mod ne1, ...]
//
// NOTE(review): index math is based on ne[] only, so this assumes both src and
// dst are contiguous row-major F32 buffers (nb strides are ignored) — confirm
// against how the backend schedules GGML_OP_ROLL for non-contiguous views.
//
// Throws std::runtime_error on null tensors/data, non-F32 types, or a
// src/dst shape mismatch. Blocks until the kernel completes.
static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor *dst,
                                   int shift0, int shift1, int shift2, int shift3) {
    if (!src || !dst) throw std::runtime_error("null tensor");
    if (src->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32)
        throw std::runtime_error("only F32 supported in SYCL roll");

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    if (ne0 != src->ne[0] || ne1 != src->ne[1] || ne2 != src->ne[2] || ne3 != src->ne[3])
        throw std::runtime_error("src/dst shape mismatch");

    // Normalize shifts into [0, ne): C's % keeps the sign of the dividend, so
    // add ne before the second % to map negative shifts into range.
    const int64_t sh0 = ne0 > 0 ? ((int64_t)shift0 % ne0 + ne0) % ne0 : 0;
    const int64_t sh1 = ne1 > 0 ? ((int64_t)shift1 % ne1 + ne1) % ne1 : 0;
    const int64_t sh2 = ne2 > 0 ? ((int64_t)shift2 % ne2 + ne2) % ne2 : 0;
    const int64_t sh3 = ne3 > 0 ? ((int64_t)shift3 % ne3 + ne3) % ne3 : 0;

    const float *src_d = (const float*) src->data;
    float *dst_d = (float*) dst->data;

    if (!src_d || !dst_d) throw std::runtime_error("null data pointers");

    const int64_t n = ne0 * ne1 * ne2 * ne3;
    if (n == 0) {
        return; // empty tensor: nothing to launch
    }

    // One work-item per element. The previous version parallelized only over
    // (ne3, ne2, ne1) and looped serially over ne0; since ne[0] is usually the
    // largest dimension in ggml tensors, that left most of the device idle.
    q.submit([&](handler &h) {
        h.parallel_for(range<1>((size_t) n), [=](id<1> tid) {
            const int64_t idx_dst = (int64_t) tid[0];

            // Decompose the flat destination index into 4-D coordinates.
            const int64_t i0 =  idx_dst % ne0;
            const int64_t i1 = (idx_dst / ne0) % ne1;
            const int64_t i2 = (idx_dst / (ne0 * ne1)) % ne2;
            const int64_t i3 =  idx_dst / (ne0 * ne1 * ne2);

            // Apply the (normalized) shift per dimension; + ne keeps the
            // dividend non-negative before %.
            const int64_t src_i0 = (i0 - sh0 + ne0) % ne0;
            const int64_t src_i1 = (i1 - sh1 + ne1) % ne1;
            const int64_t src_i2 = (i2 - sh2 + ne2) % ne2;
            const int64_t src_i3 = (i3 - sh3 + ne3) % ne3;

            const int64_t idx_src = src_i0 + src_i1 * ne0 +
                                    src_i2 * ne0 * ne1 + src_i3 * ne0 * ne1 * ne2;

            dst_d[idx_dst] = src_d[idx_src];
        });
    }).wait();
}
55+
56+
// Backend entry point for GGML_OP_ROLL (F32 only).
//
// Reads the four per-dimension shift amounts from dst->op_params (four int32
// values, one per dimension), then either fast-paths a plain device memcpy
// when all shifts are zero, or launches the roll kernel.
//
// ctx: SYCL backend context providing the stream to run on.
// dst: destination tensor; dst->src[0] is the input tensor.
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const ggml_tensor *src = dst->src[0];
    // src is dereferenced below (ggml_nbytes, ->data) before the kernel's own
    // null check runs, so validate it here rather than crashing later.
    GGML_ASSERT(src != nullptr);
    GGML_ASSERT(src->type == GGML_TYPE_F32);

    // op_params holds one int32 shift per tensor dimension.
    const int32_t *params = (const int32_t *)dst->op_params;
    const int shift0 = params[0];
    const int shift1 = params[1];
    const int shift2 = params[2];
    const int shift3 = params[3];

    // All-zero shifts: roll is the identity, so a raw device-to-device copy
    // suffices — valid only when both tensors occupy the same number of bytes.
    if (shift0 == 0 && shift1 == 0 && shift2 == 0 && shift3 == 0) {
        const size_t nb = ggml_nbytes(src);
        GGML_ASSERT(ggml_nbytes(dst) == nb);
        queue *q = ctx.stream();
        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb).wait()));
        return;
    }

    try {
        queue *q = ctx.stream();
        kernel_roll_multi_axis(*q, src, dst, shift0, shift1, shift2, shift3);
    } catch (const std::exception &e) {
        // Log before rethrowing so the failure is attributable to ROLL even if
        // a caller swallows the exception type.
        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
        throw;
    }
}

ggml/src/ggml-sycl/roll.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//
2+
// MIT license
3+
// Copyright (C) 2024 Intel Corporation
4+
// SPDX-License-Identifier: MIT
5+
//
6+
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
13+
#ifndef GGML_SYCL_ROLL_HPP
14+
#define GGML_SYCL_ROLL_HPP
15+
16+
#include "common.hpp"
17+
18+
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
19+
20+
#endif // GGML_SYCL_ROLL_HPP

0 commit comments

Comments
 (0)