22
33#include " common.hpp"
44
5- void ggml_sycl_op_repeat_back (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
5+ #define GGML_ASSERT_TENSOR_FITS_INT (t ) \
6+ GGML_ASSERT ((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
67
8+ void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
79 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
810 GGML_ASSERT (dst->type == GGML_TYPE_F32);
911
1012 const float * src0_dd = (const float *) dst->src [0 ]->data ;
1113 float * dst_dd = (float *) dst->data ;
1214
13- const int64_t ne0 = dst->ne [0 ], ne1 = dst->ne [1 ], ne2 = dst->ne [2 ], ne3 = dst->ne [3 ];
14- const int64_t ne00 = dst->src [0 ]->ne [0 ], ne01 = dst->src [0 ]->ne [1 ], ne02 = dst->src [0 ]->ne [2 ],
15- ne03 = dst->src [0 ]->ne [3 ];
15+ GGML_ASSERT_TENSOR_FITS_INT (dst);
16+ GGML_ASSERT_TENSOR_FITS_INT (dst->src [0 ]);
17+
18+ const int ne0 = dst->ne [0 ], ne1 = dst->ne [1 ], ne2 = dst->ne [2 ], ne3 = dst->ne [3 ];
19+ const int ne00 = dst->src [0 ]->ne [0 ], ne01 = dst->src [0 ]->ne [1 ], ne02 = dst->src [0 ]->ne [2 ],
20+ ne03 = dst->src [0 ]->ne [3 ];
21+
22+ const int nr0 = ne00 / ne0;
23+ const int nr1 = ne01 / ne1;
24+ const int nr2 = ne02 / ne2;
25+ const int nr3 = ne03 / ne3;
1626
17- const int nr0 = ( int ) (ne00 / ne0) ;
18- const int nr1 = ( int ) (ne01 / ne1) ;
19- const int nr2 = ( int ) (ne02 / ne2) ;
20- const int nr3 = ( int ) (ne03 / ne3) ;
27+ const int nb0 = dst-> src [ 0 ]-> nb [ 0 ] ;
28+ const int nb1 = dst-> src [ 0 ]-> nb [ 1 ] ;
29+ const int nb2 = dst-> src [ 0 ]-> nb [ 2 ] ;
30+ const int nb3 = dst-> src [ 0 ]-> nb [ 3 ] ;
2131
22- const size_t total = ne0 * ne1 * ne2 * ne3;
23- const int BLOCK_SIZE = 256 ;
24- const int num_blocks = (total + BLOCK_SIZE - 1 ) / BLOCK_SIZE;
32+ const char * base = (const char *) src0_dd;
33+
34+ const size_t total = (size_t ) ne0 * ne1 * ne2 * ne3;
35+ constexpr int BLOCK_SIZE = 256 ;
36+ const int num_blocks = (total + BLOCK_SIZE - 1 ) / BLOCK_SIZE;
37+
38+ const float inv_ne0 = 1 .0f / ne0;
39+ const float inv_ne_01 = 1 .0f / (ne0 * ne1);
40+ const float inv_ne_012 = 1 .0f / (ne0 * ne1 * ne2);
41+ const int repeat_count = nr0 * nr1 * nr2 * nr3;
2542
2643 queue_ptr stream = ctx.stream ();
2744
@@ -33,24 +50,27 @@ void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst
3350 return ;
3451 }
3552
36- const int i0 = i % ne0 ;
37- const int i1 = (i / ne0) % ne1 ;
38- const int i2 = (i / (ne0 * ne1)) % ne2 ;
39- const int i3 = i / (ne0 * ne1 * ne2) ;
53+ const int i3 = ( int ) (i * inv_ne_012) ;
54+ const int i2 = (int ) (i * inv_ne_01) - i3 * ne2 ;
55+ const int i1 = (int ) (i * inv_ne0) - ( int ) (i * inv_ne_01) * ne1 ;
56+ const int i0 = i - ( int ) (i * inv_ne0) * ne0 ;
4057
58+ int j0 = 0 , j1 = 0 , j2 = 0 , j3 = 0 ;
4159 float acc = 0 .0f ;
4260
43- for (int j3 = 0 ; j3 < nr3; ++j3) {
44- for (int j2 = 0 ; j2 < nr2; ++j2) {
45- for (int j1 = 0 ; j1 < nr1; ++j1) {
46- for (int j0 = 0 ; j0 < nr0; ++j0) {
47- acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
48- (i3 + j3 * ne3) * ne00 * ne01 * ne02];
49- }
50- }
51- }
52- }
61+ for (int j = 0 ; j < repeat_count; ++j) {
62+ const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +
63+ (i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);
64+ acc += *ptr;
5365
66+ int carry = (++j0 >= nr0);
67+ j0 -= carry * nr0;
68+ carry = (carry && (++j1 >= nr1));
69+ j1 -= carry * nr1;
70+ carry = (carry && (++j2 >= nr2));
71+ j2 -= carry * nr2;
72+ j3 += carry;
73+ }
5474 dst_dd[i] = acc;
5575 });
5676}
0 commit comments