11#include " binbcast.hpp"
22
3+ #include < array>
34#include < cstddef>
45#include < cstdint>
6+ #include < cstdio>
57#include < sycl/sycl.hpp>
8+ #include < utility>
69
10+ #include " dpct/helper.hpp"
711#include " ggml.h"
812
9- template <float (*bin_op)(const float , const float ), typename src0_t , typename src1_t , typename dst_t >
10- static void k_bin_bcast (const src0_t * src0, const src1_t * src1, dst_t * dst,
11- int ne0, int ne1, int ne2, int ne3,
12- int ne10, int ne11, int ne12, int ne13,
13- /* int s0, */ int s1, int s2, int s3,
14- /* int s00,*/ int s01, int s02, int s03,
15- /* int s10,*/ int s11, int s12, int s13,
16- const sycl::nd_item<3 > &item_ct1) {
17- const int i0s = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
18- item_ct1.get_local_id (2 );
19- const int i1 = (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
20- item_ct1.get_local_id (1 ));
21- const int i2 = (item_ct1.get_local_range (0 ) * item_ct1.get_group (0 ) +
22- item_ct1.get_local_id (0 )) /
23- ne3;
24- const int i3 = (item_ct1.get_local_range (0 ) * item_ct1.get_group (0 ) +
25- item_ct1.get_local_id (0 )) %
26- ne3;
27-
28- if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
29- return ;
30- }
31-
32- const int i11 = i1 % ne11;
33- const int i12 = i2 % ne12;
34- const int i13 = i3 % ne13;
35-
36- const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
37- const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
38- const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
39-
40- const src0_t * src0_row = src0 + i_src0;
41- const src1_t * src1_row = src1 + i_src1;
42- dst_t * dst_row = dst + i_dst;
43-
44- for (int i0 = i0s; i0 < ne0;
45- i0 += item_ct1.get_local_range (2 ) * item_ct1.get_group_range (2 )) {
46- const int i10 = i0 % ne10;
47- dst_row[i0] = (dst_t )bin_op (src0 ? (float )src0_row[i0] : 0 .0f , (float )src1_row[i10]);
48- }
49- }
50-
51- template <float (*bin_op)(const float , const float ), typename src0_t , typename src1_t , typename dst_t >
52- static void k_bin_bcast_unravel (const src0_t * src0, const src1_t * src1, dst_t * dst,
53- int ne0, int ne1, int ne2, int ne3,
54- int ne10, int ne11, int ne12, int ne13,
55- /* int s0, */ int s1, int s2, int s3,
56- /* int s00,*/ int s01, int s02, int s03,
57- /* int s10,*/ int s11, int s12, int s13,
58- const sycl::nd_item<3 > &item_ct1) {
59-
60- const int i = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
61- item_ct1.get_local_id (2 );
62-
63- const int i3 = i/(ne2*ne1*ne0);
64- const int i2 = (i/(ne1*ne0)) % ne2;
65- const int i1 = (i/ne0) % ne1;
66- const int i0 = i % ne0;
67-
68- if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
69- return ;
13+ template <float (*bin_op)(const float , const float ), typename src0_t , typename src1_t , typename dst_t >
14+ static void k_bin_bcast (const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3,
15+ int ne10, int ne11, int ne12, int ne13,
16+ /* int s0, */ int s1, int s2, int s3,
17+ /* int s00,*/ int s01, int s02, int s03,
18+ /* int s10,*/ int s11, int s12, int s13, std::size_t num_dst_elements,
19+ const sycl::nd_item<1 > & item_ct1) {
20+ auto calculate_logical_index =
21+ [](const std::array<int , 4 > & dims, std::size_t element_id) __attribute__ ((always_inline))->std ::array<int , 4 > {
22+ std::array<int , 4 > logical_index;
23+ for (int i = 3 ; i >= 0 ; i--) {
24+ logical_index[i] = element_id % dims[i];
25+ element_id /= dims[i];
26+ }
27+ return logical_index;
28+ };
29+
30+ auto calculate_index = [](const std::array<int , 4 > & dims, const std::array<int , 4 > & strides,
31+ const std::array<int , 4 > & indices) __attribute__ ((always_inline))
32+ ->std ::size_t {
33+ std::size_t index = 0 ;
34+ for (int i = 0 ; i < 4 ; i++) {
35+ auto index_i = indices[i];
36+ if (indices[i] >= dims[i]) {
37+ index_i = indices[i] % dims[i];
38+ }
39+ index += strides[i] * index_i;
40+ }
41+ return index;
42+ };
43+
44+ auto element_id = item_ct1.get_global_id (0 );
45+ for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range (0 )) {
46+ auto logical_index = calculate_logical_index ({ ne3, ne2, ne1, ne0 }, element_id);
47+ // The inner most stride is always assumed to be 1 (as s0 is commented out).
48+ auto src_0_index = calculate_index ({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, 1 }, logical_index);
49+ auto src_1_index = calculate_index ({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, 1 }, logical_index);
50+ auto dst_index = calculate_index ({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, 1 }, logical_index);
51+ dst[dst_index] = bin_op (src0[src_0_index], src1[src_1_index]);
7052 }
71-
72- const int i11 = i1 % ne11;
73- const int i12 = i2 % ne12;
74- const int i13 = i3 % ne13;
75-
76- const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
77- const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
78- const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
79-
80- const src0_t * src0_row = src0 + i_src0;
81- const src1_t * src1_row = src1 + i_src1;
82- dst_t * dst_row = dst + i_dst;
83-
84- const int i10 = i0 % ne10;
85- dst_row[i0] = (dst_t )bin_op (src0 ? (float )src0_row[i0] : 0 .0f , (float )src1_row[i10]);
8653}
8754
88-
89- template <float (*bin_op)(const float , const float )>
90- struct bin_bcast_sycl {
55+ template <float (*bin_op)(const float , const float )> struct bin_bcast_sycl {
9156 template <typename src0_t , typename src1_t , typename dst_t >
9257 void operator ()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
9358 const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
9459 const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,
9560 const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
9661 const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
97- const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
98- const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
99- int nr0 = ne10 / ne0;
100- int nr1 = ne11/ne1;
101- int nr2 = ne12/ne2;
102- int nr3 = ne13/ne3;
103-
104- int nr[4 ] = { nr0, nr1, nr2, nr3 };
105-
106- // collapse dimensions until first broadcast dimension
107- int64_t cne[] = {ne0, ne1, ne2, ne3};
108- int64_t cne0[] = {ne00, ne01, ne02, ne03};
109- int64_t cne1[] = {ne10, ne11, ne12, ne13};
110- size_t cnb[] = {nb0, nb1, nb2, nb3};
111- size_t cnb0[] = {nb00, nb01, nb02, nb03};
112- size_t cnb1[] = {nb10, nb11, nb12, nb13};
113- auto collapse = [](int64_t cne[]) {
114- cne[0 ] *= cne[1 ];
115- cne[1 ] = cne[2 ];
116- cne[2 ] = cne[3 ];
117- cne[3 ] = 1 ;
118- };
119-
120- auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
121- cnb[1 ] *= cne[1 ];
122- cnb[2 ] *= cne[2 ];
123- cnb[3 ] *= cne[3 ];
124- };
125-
126- if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
127- for (int i = 0 ; i < 4 ; i++) {
128- if (nr[i] != 1 ) {
129- break ;
130- }
131- if (i > 0 ) {
132- collapse_nb (cnb, cne);
133- collapse_nb (cnb0, cne0);
134- collapse_nb (cnb1, cne1);
135- collapse (cne);
136- collapse (cne0);
137- collapse (cne1);
138- }
139- }
140- }
141- {
142- int64_t ne0 = cne[0 ];
143- int64_t ne1 = cne[1 ];
144- int64_t ne2 = cne[2 ];
145- int64_t ne3 = cne[3 ];
146-
147- int64_t ne10 = cne1[0 ];
148- int64_t ne11 = cne1[1 ];
149- int64_t ne12 = cne1[2 ];
150- int64_t ne13 = cne1[3 ];
151-
152- size_t nb0 = cnb[0 ];
153- size_t nb1 = cnb[1 ];
154- size_t nb2 = cnb[2 ];
155- size_t nb3 = cnb[3 ];
156-
157- size_t nb00 = cnb0[0 ];
158- size_t nb01 = cnb0[1 ];
159- size_t nb02 = cnb0[2 ];
160- size_t nb03 = cnb0[3 ];
161-
162- size_t nb10 = cnb1[0 ];
163- size_t nb11 = cnb1[1 ];
164- size_t nb12 = cnb1[2 ];
165- size_t nb13 = cnb1[3 ];
166-
167- size_t s0 = nb0 / sizeof (dst_t );
168- size_t s1 = nb1 / sizeof (dst_t );
169- size_t s2 = nb2 / sizeof (dst_t );
170- size_t s3 = nb3 / sizeof (dst_t );
171-
172- size_t s10 = nb10 / sizeof (src1_t );
173- size_t s11 = nb11 / sizeof (src1_t );
174- size_t s12 = nb12 / sizeof (src1_t );
175- size_t s13 = nb13 / sizeof (src1_t );
176-
177- size_t s00 = nb00 / sizeof (src0_t );
178- size_t s01 = nb01 / sizeof (src0_t );
179- size_t s02 = nb02 / sizeof (src0_t );
180- size_t s03 = nb03 / sizeof (src0_t );
181-
182- GGML_UNUSED (s00);
183-
184- GGML_ASSERT (nb0 % sizeof (dst_t ) == 0 );
185- GGML_ASSERT (nb1 % sizeof (dst_t ) == 0 );
186- GGML_ASSERT (nb2 % sizeof (dst_t ) == 0 );
187- GGML_ASSERT (nb3 % sizeof (dst_t ) == 0 );
188-
189- GGML_ASSERT (nb00 % sizeof (src0_t ) == 0 );
190- GGML_ASSERT (nb01 % sizeof (src0_t ) == 0 );
191- GGML_ASSERT (nb02 % sizeof (src0_t ) == 0 );
192- GGML_ASSERT (nb03 % sizeof (src0_t ) == 0 );
193-
194- GGML_ASSERT (nb10 % sizeof (src1_t ) == 0 );
195- GGML_ASSERT (nb11 % sizeof (src1_t ) == 0 );
196- GGML_ASSERT (nb12 % sizeof (src1_t ) == 0 );
197- GGML_ASSERT (nb13 % sizeof (src1_t ) == 0 );
198-
199- GGML_ASSERT (s0 == 1 );
200- GGML_ASSERT (s10 == 1 );
201-
202- const int block_size = 128 ;
203-
204- int64_t hne0 = std::max (ne0/2LL , 1LL );
205-
206- sycl::range<3 > block_dims (1 , 1 , 1 );
207- block_dims[2 ] = std::min<unsigned int >(hne0, block_size);
208- block_dims[1 ] = std::min<unsigned int >(
209- ne1, block_size / (unsigned int )block_dims[2 ]);
210- block_dims[0 ] = std::min (
211- std::min<unsigned int >(
212- ne2 * ne3, block_size / (unsigned int )block_dims[2 ] /
213- (unsigned int )block_dims[1 ]),
214- 64U );
215-
216- sycl::range<3 > block_nums (
217- (ne2 * ne3 + block_dims[0 ] - 1 ) / block_dims[0 ],
218- (ne1 + block_dims[1 ] - 1 ) / block_dims[1 ],
219- (hne0 + block_dims[2 ] - 1 ) / block_dims[2 ]);
220-
221- if (block_nums[0 ] > 65535 ) {
222- // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
223- int block_num = (ne0*ne1*ne2*ne3 + block_size - 1 ) / block_size;
224- {
225- dpct::has_capability_or_fail (stream->get_device (),
226- {sycl::aspect::fp16});
227-
228- stream->parallel_for (
229- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , block_num) *
230- sycl::range<3 >(1 , 1 , block_size),
231- sycl::range<3 >(1 , 1 , block_size)),
232- [=](sycl::nd_item<3 > item_ct1) {
233- k_bin_bcast_unravel<bin_op>(
234- src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
235- ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
236- s03, s11, s12, s13, item_ct1);
237- });
238- }
239- } else {
240- /*
241- DPCT1049:16: The work-group size passed to the SYCL kernel may
242- exceed the limit. To get the device limit, query
243- info::device::max_work_group_size. Adjust the work-group size if
244- needed.
245- */
246- dpct::has_capability_or_fail (stream->get_device (),
247- {sycl::aspect::fp16});
248-
249- stream->parallel_for (
250- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
251- [=](sycl::nd_item<3 > item_ct1) {
252- k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253- ne2, ne3, ne10, ne11, ne12, ne13,
254- s1, s2, s3, s01, s02, s03, s11, s12, s13,
255- item_ct1);
256- });
257- }
258- }
62+ const size_t nb1, const size_t nb2, const size_t nb3, queue_ptr stream) {
63+ // dst strides in number of elements
64+ size_t s0 = nb0 / sizeof (dst_t );
65+ size_t s1 = nb1 / sizeof (dst_t );
66+ size_t s2 = nb2 / sizeof (dst_t );
67+ size_t s3 = nb3 / sizeof (dst_t );
68+
69+ // src1 strides in number of elements
70+ size_t s10 = nb10 / sizeof (src1_t );
71+ size_t s11 = nb11 / sizeof (src1_t );
72+ size_t s12 = nb12 / sizeof (src1_t );
73+ size_t s13 = nb13 / sizeof (src1_t );
74+
75+ // src0 strides in number of elements
76+ size_t s00 = nb00 / sizeof (src0_t );
77+ size_t s01 = nb01 / sizeof (src0_t );
78+ size_t s02 = nb02 / sizeof (src0_t );
79+ size_t s03 = nb03 / sizeof (src0_t );
80+
81+ std::size_t num_dst_elements = static_cast <std::size_t >(ne0) * static_cast <std::size_t >(ne1) *
82+ static_cast <std::size_t >(ne2) * static_cast <std::size_t >(ne3);
83+ std::size_t local_range = 256 ;
84+ std::size_t num_elements = ne1 * ne2 * ne3 * ne3;
85+ std::size_t global_range = ((num_elements + local_range - 1 ) / local_range) * local_range;
86+ stream->submit ([&](sycl::handler & cgh) {
87+ cgh.parallel_for (sycl::nd_range<1 >({ global_range }, { local_range }), [=](sycl::nd_item<1 > it) {
88+ k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s1, s2, s3,
89+ s01, s02, s03, s11, s12, s13, num_dst_elements, it);
90+ });
91+ });
25992 }
26093};
26194
@@ -268,24 +101,23 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
268101 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
269102 op ()((const float *) src0->data , (const float *) src1->data , (float *) dst->data , ne00, ne01, ne02, ne03, ne10,
270103 ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
271- ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
104+ main_stream);
272105 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
273106 op ()((const sycl::half *) src0->data , (const sycl::half *) src1->data , (sycl::half *) dst->data , ne00, ne01,
274107 ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
275- nb0, nb1, nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst),
276- main_stream);
108+ nb0, nb1, nb2, nb3, main_stream);
277109 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
278110 op ()((const sycl::half *) src0->data , (const float *) src1->data , (sycl::half *) dst->data , ne00, ne01, ne02,
279111 ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
280- nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
112+ nb2, nb3, main_stream);
281113 } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
282114 op ()((const int32_t *) src0->data , (const int32_t *) src1->data , (int32_t *) dst->data , ne00, ne01, ne02, ne03,
283115 ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
284- nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
116+ nb3, main_stream);
285117 } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
286118 op ()((const int16_t *) src0->data , (const int16_t *) src1->data , (int16_t *) dst->data , ne00, ne01, ne02, ne03,
287119 ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
288- nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
120+ nb3, main_stream);
289121 } else {
290122 fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__, ggml_type_name (dst->type ),
291123 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
0 commit comments