Skip to content

Commit f7d05e7

Browse files
committed
SYCL: optimize repeat_back kernel
1 parent d3e88bc commit f7d05e7

File tree

1 file changed

+63
-39
lines changed

1 file changed

+63
-39
lines changed

ggml/src/ggml-sycl/repeat_back.cpp

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,55 +3,79 @@
33

44
#include "common.hpp"
55

6-
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
6+
#define GGML_ASSERT_TENSOR_FITS_INT(t) \
7+
GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
78

9+
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
810
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
911
GGML_ASSERT(dst->type == GGML_TYPE_F32);
1012

1113
const float * src0_dd = (const float *) dst->src[0]->data;
1214
float * dst_dd = (float *) dst->data;
1315

14-
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
15-
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
16-
ne03 = dst->src[0]->ne[3];
16+
GGML_ASSERT_TENSOR_FITS_INT(dst);
17+
GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]);
18+
19+
const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
20+
const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
21+
ne03 = dst->src[0]->ne[3];
22+
23+
const int nr0 = ne00 / ne0;
24+
const int nr1 = ne01 / ne1;
25+
const int nr2 = ne02 / ne2;
26+
const int nr3 = ne03 / ne3;
27+
28+
const int nb0 = dst->src[0]->nb[0];
29+
const int nb1 = dst->src[0]->nb[1];
30+
const int nb2 = dst->src[0]->nb[2];
31+
const int nb3 = dst->src[0]->nb[3];
32+
33+
const char * base = (const char *) src0_dd;
1734

18-
const int nr0 = (int) (ne00 / ne0);
19-
const int nr1 = (int) (ne01 / ne1);
20-
const int nr2 = (int) (ne02 / ne2);
21-
const int nr3 = (int) (ne03 / ne3);
35+
const size_t total = (size_t) ne0 * ne1 * ne2 * ne3;
36+
constexpr int BLOCK_SIZE = 256;
37+
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
2238

23-
const size_t total = ne0 * ne1 * ne2 * ne3;
24-
const int BLOCK_SIZE = 256;
25-
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
39+
// Precompute inverse sizes to replace integer divisions with multiplications
40+
const float inv_ne0 = 1.0f / ne0;
41+
const float inv_ne_01 = 1.0f / (ne0 * ne1);
42+
const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2);
43+
const int repeat_count = nr0 * nr1 * nr2 * nr3;
2644

2745
queue_ptr stream = ctx.stream();
2846

29-
stream->parallel_for(
30-
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
31-
[=](sycl::nd_item<1> item_ct1) {
32-
const size_t i = item_ct1.get_global_linear_id();
33-
if (i >= total) {
34-
return;
35-
}
36-
37-
const int i0 = i % ne0;
38-
const int i1 = (i / ne0) % ne1;
39-
const int i2 = (i / (ne0 * ne1)) % ne2;
40-
const int i3 = i / (ne0 * ne1 * ne2);
41-
42-
float acc = 0.0f;
43-
44-
for (int j3 = 0; j3 < nr3; ++j3) {
45-
for (int j2 = 0; j2 < nr2; ++j2) {
46-
for (int j1 = 0; j1 < nr1; ++j1) {
47-
for (int j0 = 0; j0 < nr0; ++j0) {
48-
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
49-
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
50-
}
51-
}
52-
}
53-
}
54-
55-
dst_dd[i] = acc;
56-
});
47+
stream->parallel_for(sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
48+
[=](sycl::nd_item<1> item_ct1) {
49+
const size_t i = item_ct1.get_global_linear_id();
50+
51+
if (i >= total) {
52+
return;
53+
}
54+
55+
// Compute multidimensional indices (i0,i1,i2,i3) from the flattened linear index i
56+
const int i3 = (int) (i * inv_ne_012);
57+
const int i2 = (int) (i * inv_ne_01) - i3 * ne2;
58+
const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;
59+
const int i0 = i - (int) (i * inv_ne0) * ne0;
60+
61+
int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
62+
float acc = 0.0f;
63+
64+
for (int j = 0; j < repeat_count; ++j) {
65+
const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +
66+
(i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);
67+
acc += *ptr;
68+
69+
// Manual carry propagation simulates nested loops efficiently
70+
int carry = (++j0 >= nr0);
71+
j0 -= carry * nr0;
72+
carry = (carry && (++j1 >= nr1));
73+
j1 -= carry * nr1;
74+
carry = (carry && (++j2 >= nr2));
75+
j2 -= carry * nr2;
76+
j3 += carry;
77+
}
78+
dst_dd[i] = acc;
79+
});
80+
5781
}

0 commit comments

Comments
 (0)