 
 #include "common.hpp"
 
-void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#define GGML_ASSERT_TENSOR_FITS_INT(t) \
+    GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
 
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const float * src0_dd = (const float *) dst->src[0]->data;
     float * dst_dd = (float *) dst->data;
 
-    const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
-    const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
-                  ne03 = dst->src[0]->ne[3];
+    GGML_ASSERT_TENSOR_FITS_INT(dst);
+    GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]);
+
+    const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
+    const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
+              ne03 = dst->src[0]->ne[3];
+
+    const int nr0 = ne00 / ne0;
+    const int nr1 = ne01 / ne1;
+    const int nr2 = ne02 / ne2;
+    const int nr3 = ne03 / ne3;
+
+    const int nb0 = dst->src[0]->nb[0];
+    const int nb1 = dst->src[0]->nb[1];
+    const int nb2 = dst->src[0]->nb[2];
+    const int nb3 = dst->src[0]->nb[3];
+
+    const char * base = (const char *) src0_dd;
 
-    const int nr0 = (int) (ne00 / ne0);
-    const int nr1 = (int) (ne01 / ne1);
-    const int nr2 = (int) (ne02 / ne2);
-    const int nr3 = (int) (ne03 / ne3);
+    const size_t total = (size_t) ne0 * ne1 * ne2 * ne3;
+    constexpr int BLOCK_SIZE = 256;
+    const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
-    const size_t total = ne0 * ne1 * ne2 * ne3;
-    const int BLOCK_SIZE = 256;
-    const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    // Precompute inverse sizes to replace integer divisions with multiplications
+    const float inv_ne0 = 1.0f / ne0;
+    const float inv_ne_01 = 1.0f / (ne0 * ne1);
+    const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2);
+    const int repeat_count = nr0 * nr1 * nr2 * nr3;
 
     queue_ptr stream = ctx.stream();
 
-    stream->parallel_for(
-        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
-        [=](sycl::nd_item<1> item_ct1) {
-            const size_t i = item_ct1.get_global_linear_id();
-            if (i >= total) {
-                return;
-            }
-
-            const int i0 = i % ne0;
-            const int i1 = (i / ne0) % ne1;
-            const int i2 = (i / (ne0 * ne1)) % ne2;
-            const int i3 = i / (ne0 * ne1 * ne2);
-
-            float acc = 0.0f;
-
-            for (int j3 = 0; j3 < nr3; ++j3) {
-                for (int j2 = 0; j2 < nr2; ++j2) {
-                    for (int j1 = 0; j1 < nr1; ++j1) {
-                        for (int j0 = 0; j0 < nr0; ++j0) {
-                            acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
-                                           (i3 + j3 * ne3) * ne00 * ne01 * ne02];
-                        }
-                    }
-                }
-            }
-
-            dst_dd[i] = acc;
-        });
+    stream->parallel_for(sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
+                         [=](sycl::nd_item<1> item_ct1) {
+                             const size_t i = item_ct1.get_global_linear_id();
+
+                             if (i >= total) {
+                                 return;
+                             }
+
+                             // Compute multidimensional indices (i0,i1,i2,i3) from the flattened linear index i
+                             const int i3 = (int) (i * inv_ne_012);
+                             const int i2 = (int) (i * inv_ne_01) - i3 * ne2;
+                             const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;
+                             const int i0 = i - (int) (i * inv_ne0) * ne0;
+
+                             int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
+                             float acc = 0.0f;
+
+                             for (int j = 0; j < repeat_count; ++j) {
+                                 const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +
+                                                                      (i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);
+                                 acc += *ptr;
+
+                                 // Manual carry propagation simulates nested loops efficiently
+                                 int carry = (++j0 >= nr0);
+                                 j0 -= carry * nr0;
+                                 carry = (carry && (++j1 >= nr1));
+                                 j1 -= carry * nr1;
+                                 carry = (carry && (++j2 >= nr2));
+                                 j2 -= carry * nr2;
+                                 j3 += carry;
+                             }
+
+                             dst_dd[i] = acc;
+                         });
 }
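
One note on the reciprocal trick: `(int) (i * inv_ne_012)` only equals `i / (ne0 * ne1 * ne2)` as long as the float products round to the correct integers, which is safe while the indices stay well within float's 24-bit mantissa. Here is a minimal standalone sketch (arbitrary small extents, not part of the patch) that checks the decomposition against exact integer div/mod:

```cpp
// Standalone check: the precomputed-reciprocal decomposition recovers the
// same (i0, i1, i2, i3) as exact integer div/mod. The extents are arbitrary
// illustrative values; the products here fit well within float precision.
#include <cassert>
#include <cstdio>

int main() {
    const int ne0 = 5, ne1 = 7, ne2 = 3, ne3 = 4;  // illustrative dst extents

    const float inv_ne0    = 1.0f / ne0;
    const float inv_ne_01  = 1.0f / (ne0 * ne1);
    const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2);

    const size_t total = (size_t) ne0 * ne1 * ne2 * ne3;
    for (size_t i = 0; i < total; ++i) {
        // Decomposition as in the kernel.
        const int i3 = (int) (i * inv_ne_012);
        const int i2 = (int) (i * inv_ne_01) - i3 * ne2;
        const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;
        const int i0 = (int) (i - (size_t) ((int) (i * inv_ne0) * ne0));

        // Exact reference via integer div/mod.
        assert(i0 == (int) (i % ne0));
        assert(i1 == (int) (i / ne0 % ne1));
        assert(i2 == (int) (i / ((size_t) ne0 * ne1) % ne2));
        assert(i3 == (int) (i / ((size_t) ne0 * ne1 * ne2)));
    }
    std::printf("reciprocal decomposition matches div/mod for %zu indices\n", total);
    return 0;
}
```

For power-of-two extents the reciprocals and products are exact for any i below 2^24, so the decomposition is bit-exact there; for other extents it relies on the products rounding to the correct integers, as they do for sizes in this range.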
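And a similar standalone sketch (again not part of the patch, with arbitrary repeat counts) confirming that the single loop with manual carry propagation enumerates exactly the same (j0, j1, j2, j3) tuples, in the same order, as the four nested loops it replaces:

```cpp
// Standalone check: the odometer-style loop with manual carry propagation
// visits the same tuples, in the same order, as four nested loops.
#include <array>
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int nr0 = 2, nr1 = 3, nr2 = 2, nr3 = 2;  // arbitrary repeat counts

    // Reference order: the nested loops from the old kernel.
    std::vector<std::array<int, 4>> nested;
    for (int j3 = 0; j3 < nr3; ++j3)
        for (int j2 = 0; j2 < nr2; ++j2)
            for (int j1 = 0; j1 < nr1; ++j1)
                for (int j0 = 0; j0 < nr0; ++j0)
                    nested.push_back({j0, j1, j2, j3});

    // Single loop with carry propagation, as in the new kernel.
    std::vector<std::array<int, 4>> odometer;
    int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
    for (int j = 0; j < nr0 * nr1 * nr2 * nr3; ++j) {
        odometer.push_back({j0, j1, j2, j3});
        int carry = (++j0 >= nr0);         // j0 overflowed: wrap and carry into j1
        j0 -= carry * nr0;
        carry = (carry && (++j1 >= nr1));  // ++j1 only runs when carry is set
        j1 -= carry * nr1;
        carry = (carry && (++j2 >= nr2));
        j2 -= carry * nr2;
        j3 += carry;                       // j3 never wraps: the loop bound stops it
    }

    assert(nested == odometer);
    std::printf("odometer matches nested loops for %zu tuples\n", nested.size());
    return 0;
}
```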