11#include " binbcast.hpp"
2+ #include < cstddef>
3+ #include < cstdint>
24#include < sycl/sycl.hpp>
35#include " ggml.h"
46
@@ -85,15 +87,14 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
8587template <float (*bin_op)(const float , const float )>
8688struct bin_bcast_sycl {
8789 template <typename src0_t , typename src1_t , typename dst_t >
88- void operator ()(ggml_backend_sycl_context & ctx,
89- const struct ggml_tensor *src0,
90- const struct ggml_tensor *src1, struct ggml_tensor *dst,
91- const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
92- queue_ptr stream) {
93-
94- GGML_TENSOR_BINARY_OP_LOCALS
95-
96- int nr0 = ne10/ne0;
90+ void operator ()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
91+ const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
92+ const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,
93+ const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
94+ const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
95+ const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguos,
96+ const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
97+ int nr0 = ne10 / ne0;
9798 int nr1 = ne11/ne1;
9899 int nr2 = ne12/ne2;
99100 int nr3 = ne13/ne3;
@@ -120,7 +121,7 @@ struct bin_bcast_sycl {
120121 cnb[3 ] *= cne[3 ];
121122 };
122123
123- if (ggml_is_contiguous (src0) && ggml_is_contiguous (src1) && ggml_is_contiguous (dst) ) {
124+ if (src0_is_contiguos && src1_is_contiguous && dst_is_contiguous ) {
124125 for (int i = 0 ; i < 4 ; i++) {
125126 if (nr[i] != 1 ) {
126127 break ;
@@ -253,32 +254,39 @@ struct bin_bcast_sycl {
253254 });
254255 }
255256 }
256- GGML_UNUSED (ctx);
257257 }
258258};
259259
260260template <class op >
261- inline void ggml_sycl_op_bin_bcast (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
262- const ggml_tensor *src1, ggml_tensor * dst) {
261+ inline void ggml_sycl_op_bin_bcast (ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1 ,
262+ ggml_tensor * dst) {
263263 dpct::queue_ptr main_stream = ctx.stream ();
264+ GGML_TENSOR_BINARY_OP_LOCALS
264265
265- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
266- op ()(ctx, src0, src1, dst, (const float *)src0->data , (const float *)src1->data , (float *)dst->data , main_stream);
266+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
267+ op ()((const float *) src0->data , (const float *) src1->data , (float *) dst->data , ne00, ne01, ne02, ne03, ne10,
268+ ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
269+ ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
267270 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
268- op ()(ctx, src0, src1, dst, (const sycl::half *)src0->data , (const sycl::half *)src1->data ,
269- (sycl::half *)dst->data , main_stream);
270- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
271- op ()(ctx, src0, src1, dst, (const sycl::half *)src0->data , (const float *)src1->data , (sycl::half *)dst->data ,
271+ op ()((const sycl::half *) src0->data , (const sycl::half *) src1->data , (sycl::half *) dst->data , ne00, ne01,
272+ ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
273+ nb0, nb1, nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst),
272274 main_stream);
275+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
276+ op ()((const sycl::half *) src0->data , (const float *) src1->data , (sycl::half *) dst->data , ne00, ne01, ne02,
277+ ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
278+ nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
273279 } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
274- op ()(ctx, src0, src1, dst, (const int32_t *)src0->data , (const int32_t *)src1->data , (int32_t *)dst->data ,
275- main_stream);
280+ op ()((const int32_t *) src0->data , (const int32_t *) src1->data , (int32_t *) dst->data , ne00, ne01, ne02, ne03,
281+ ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
282+ nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
276283 } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
277- op ()(ctx, src0, src1, dst, (const int16_t *)src0->data , (const int16_t *)src1->data , (int16_t *)dst->data ,
278- main_stream);
284+ op ()((const int16_t *) src0->data , (const int16_t *) src1->data , (int16_t *) dst->data , ne00, ne01, ne02, ne03,
285+ ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
286+ nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
279287 } else {
280- fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__,
281- ggml_type_name (dst-> type ), ggml_type_name (src0->type ), ggml_type_name (src1->type ));
288+ fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__, ggml_type_name (dst-> type ),
289+ ggml_type_name (src0->type ), ggml_type_name (src1->type ));
282290 GGML_ABORT (" fatal error" );
283291 }
284292}
0 commit comments