#include "binbcast.hpp"
#include "common.hpp"

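// op_repeat deliberately ignores its first argument: GGML_OP_REPEAT is routed through
// the same broadcast machinery as the arithmetic ops, with src1 supplying the values
// that get tiled into the destination (see ggml_sycl_op_repeat below).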
static __dpct_inline__ float op_repeat(const float a, const float b) {
    return b;
    GGML_UNUSED(a);
}

static __dpct_inline__ float op_add(const float a, const float b) {
    return a + b;
}

static __dpct_inline__ float op_sub(const float a, const float b) {
    return a - b;
}

static __dpct_inline__ float op_mul(const float a, const float b) {
    return a * b;
}

static __dpct_inline__ float op_div(const float a, const float b) {
    return a / b;
}

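// Broadcast binary kernel for a 3D launch: work-items map to (i0, i1, i2/i3) through the
// nd_range, src1 broadcasting is done with modulo indexing (i1 % ne11, ...), and each
// work-item strides over i0 so a grid smaller than ne0 still covers the whole row.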
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3,
                        int ne10, int ne11, int ne12, int ne13,
                        /*int s0, */ int s1, int s2, int s3,
                        /*int s10,*/ int s11, int s12, int s13, const sycl::nd_item<3> & item_ct1) {
    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
    const int i1  = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
    const int i2  = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + item_ct1.get_local_id(0)) / ne3;
    const int i3  = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + item_ct1.get_local_id(0)) % ne3;

    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

    const int i11 = i1 % ne11;
    const int i12 = i2 % ne12;
    const int i13 = i3 % ne13;

    const size_t i_src0 = i3 * s3 + i2 * s2 + i1 * s1;
    const size_t i_src1 = i13 * s13 + i12 * s12 + i11 * s11;
    const size_t i_dst  = i_src0;

    const src0_t * src0_row = src0 + i_src0;
    const src1_t * src1_row = src1 + i_src1;
    dst_t *        dst_row  = dst + i_dst;

    for (int i0 = i0s; i0 < ne0; i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
        const int i10 = i0 % ne10;
        dst_row[i0]   = (dst_t) bin_op(src0 ? (float) src0_row[i0] : 0.0f, (float) src1_row[i10]);
    }
}

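// 1D-grid variant: the flat work-item index is unravelled into (i0, i1, i2, i3). Used as
// a fallback when the number of work-groups in the slowest dimension of the 3D launch
// below would exceed the 65535 limit.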
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2,
                                int ne3, int ne10, int ne11, int ne12, int ne13,
                                /*int s0, */ int s1, int s2, int s3,
                                /*int s10,*/ int s11, int s12, int s13, const sycl::nd_item<3> & item_ct1) {
    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

    const int i3 = i / (ne2 * ne1 * ne0);
    const int i2 = (i / (ne1 * ne0)) % ne2;
    const int i1 = (i / ne0) % ne1;
    const int i0 = i % ne0;

    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

    const int i11 = i1 % ne11;
    const int i12 = i2 % ne12;
    const int i13 = i3 % ne13;

    const size_t i_src0 = i3 * s3 + i2 * s2 + i1 * s1;
    const size_t i_src1 = i13 * s13 + i12 * s12 + i11 * s11;
    const size_t i_dst  = i_src0;

    const src0_t * src0_row = src0 + i_src0;
    const src1_t * src1_row = src1 + i_src1;
    dst_t *        dst_row  = dst + i_dst;

    const int i10 = i0 % ne10;
    dst_row[i0]   = (dst_t) bin_op(src0 ? (float) src0_row[i0] : 0.0f, (float) src1_row[i10]);
}

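// Functor that sets up and launches one of the kernels above: it computes the broadcast
// ratios of src1 relative to dst, collapses leading non-broadcast dimensions to cut down
// on index arithmetic, then picks either the 3D launch or the 1D fallback.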
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
    template <typename src0_t, typename src1_t, typename dst_t>
    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
                    const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, queue_ptr stream) {
        GGML_TENSOR_BINARY_OP_LOCALS

        int nr0 = ne10 / ne0;
        int nr1 = ne11 / ne1;
        int nr2 = ne12 / ne2;
        int nr3 = ne13 / ne3;

        int nr[4] = { nr0, nr1, nr2, nr3 };

        // collapse dimensions until first broadcast dimension
        int64_t cne0[] = { ne0, ne1, ne2, ne3 };
        int64_t cne1[] = { ne10, ne11, ne12, ne13 };
        size_t  cnb0[] = { nb0, nb1, nb2, nb3 };
        size_t  cnb1[] = { nb10, nb11, nb12, nb13 };

        auto collapse = [](int64_t cne[]) {
            cne[0] *= cne[1];
            cne[1] = cne[2];
            cne[2] = cne[3];
            cne[3] = 1;
        };

        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
            cnb[1] *= cne[1];
            cnb[2] *= cne[2];
            cnb[3] *= cne[3];
        };

        for (int i = 0; i < 4; i++) {
            if (nr[i] != 1) {
                break;
            }
            if (i > 0) {
                collapse_nb(cnb0, cne0);
                collapse_nb(cnb1, cne1);
                collapse(cne0);
                collapse(cne1);
            }
        }
        {
            int64_t ne0 = cne0[0];
            int64_t ne1 = cne0[1];
            int64_t ne2 = cne0[2];
            int64_t ne3 = cne0[3];

            int64_t ne10 = cne1[0];
            int64_t ne11 = cne1[1];
            int64_t ne12 = cne1[2];
            int64_t ne13 = cne1[3];

            size_t nb0 = cnb0[0];
            size_t nb1 = cnb0[1];
            size_t nb2 = cnb0[2];
            size_t nb3 = cnb0[3];

            size_t nb10 = cnb1[0];
            size_t nb11 = cnb1[1];
            size_t nb12 = cnb1[2];
            size_t nb13 = cnb1[3];

            size_t s0 = nb0 / sizeof(dst_t);
            size_t s1 = nb1 / sizeof(dst_t);
            size_t s2 = nb2 / sizeof(dst_t);
            size_t s3 = nb3 / sizeof(dst_t);

            size_t s10 = nb10 / sizeof(src1_t);
            size_t s11 = nb11 / sizeof(src1_t);
            size_t s12 = nb12 / sizeof(src1_t);
            size_t s13 = nb13 / sizeof(src1_t);

            GGML_ASSERT(s0 == 1);
            GGML_ASSERT(s10 == 1);

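            // hne0 halves the i0 extent so that, combined with the strided loop in
            // k_bin_bcast, each work-item typically handles two elements per row
            // (presumably to amortize the index math); the work-group shape is then
            // packed into at most block_size (128) work-items.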
            const int block_size = 128;

            int64_t hne0 = std::max(ne0 / 2LL, 1LL);

            sycl::range<3> block_dims(1, 1, 1);
            block_dims[2] = std::min<unsigned int>(hne0, block_size);
            block_dims[1] = std::min<unsigned int>(ne1, block_size / (unsigned int) block_dims[2]);
            block_dims[0] = std::min(std::min<unsigned int>(ne2 * ne3, block_size / (unsigned int) block_dims[2] /
                                                                           (unsigned int) block_dims[1]),
                                     64U);

            sycl::range<3> block_nums((ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
                                      (ne1 + block_dims[1] - 1) / block_dims[1],
                                      (hne0 + block_dims[2] - 1) / block_dims[2]);

            if (block_nums[0] > 65535) {
                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
                int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
                {
                    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

                    stream->parallel_for(
                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
                                          sycl::range<3>(1, 1, block_size)),
                        [=](sycl::nd_item<3> item_ct1) {
                            k_bin_bcast_unravel<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12,
                                                        ne13, s1, s2, s3, s11, s12, s13, item_ct1);
                        });
                }
            } else {
                /*
                DPCT1049:16: The work-group size passed to the SYCL kernel may
                exceed the limit. To get the device limit, query
                info::device::max_work_group_size. Adjust the work-group size if
                needed.
                */
                dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

                stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                                     [=](sycl::nd_item<3> item_ct1) {
                                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11,
                                                             ne12, ne13, s1, s2, s3, s11, s12, s13, item_ct1);
                                     });
            }
        }
    }
};

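// Dispatch on the (src0, dst) type combination: src1 is treated as F32 on the
// floating-point paths and as the same integer type as src0 on the integer paths.
// Any other combination aborts with an error message.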
template <class op>
inline void ggml_sycl_op_bin_bcast(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                   const void * src0_dd, const void * src1_dd, void * dst_dd,
                                   const queue_ptr & main_stream) {
    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const float *) src0_dd, (const float *) src1_dd, (float *) dst_dd, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const sycl::half *) src0_dd, (const float *) src1_dd, (sycl::half *) dst_dd,
             main_stream);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const sycl::half *) src0_dd, (const float *) src1_dd, (float *) dst_dd, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
        op()(src0, src1, dst, (const int32_t *) src0_dd, (const int32_t *) src1_dd, (int32_t *) dst_dd, main_stream);
    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
        op()(src0, src1, dst, (const int16_t *) src0_dd, (const int16_t *) src1_dd, (int16_t *) dst_dd, main_stream);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
                ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const void *          src0_dd     = static_cast<void *>(dst->src[0]->data);
    const void *          src1_dd     = static_cast<void *>(dst->src[1]->data);
    void *                dst_dd      = static_cast<void *>(dst->data);
    const dpct::queue_ptr main_stream = ctx.stream();

    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
                                                   main_stream);
}

inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const void *          src0_dd     = static_cast<void *>(dst->src[0]->data);
    const void *          src1_dd     = static_cast<void *>(dst->src[1]->data);
    void *                dst_dd      = static_cast<void *>(dst->data);
    const dpct::queue_ptr main_stream = ctx.stream();

    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
                                                   main_stream);
}

inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const void *          src0_dd     = static_cast<void *>(dst->src[0]->data);
    const void *          src1_dd     = static_cast<void *>(dst->src[1]->data);
    void *                dst_dd      = static_cast<void *>(dst->data);
    const dpct::queue_ptr main_stream = ctx.stream();

    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
                                                   main_stream);
}

inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const void *          src0_dd     = static_cast<void *>(dst->src[0]->data);
    const void *          src1_dd     = static_cast<void *>(dst->src[1]->data);
    void *                dst_dd      = static_cast<void *>(dst->data);
    const dpct::queue_ptr main_stream = ctx.stream();

    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
                                                   main_stream);
}

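// GGML_OP_REPEAT reuses the broadcast path: dst is passed in place of src0 (its shape
// is used but the data pointer is nullptr) and src0 is passed as src1, so op_repeat
// ends up tiling src0's values across the larger dst tensor.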
inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const void *    src0_d      = static_cast<void *>(dst->src[0]->data);
    void *          dst_d       = static_cast<void *>(dst->data);
    dpct::queue_ptr main_stream = ctx.stream();

    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
}

void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_add(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_sub(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_mul(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_div(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_repeat(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}