@@ -508,8 +508,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
508508template <float (*bin_op)(const float , const float )>
509509struct bin_bcast_sycl {
510510 template <typename src0_t , typename src1_t , typename dst_t >
511- void operator ()(ggml_backend_sycl_context & ctx,
512- const struct ggml_tensor *src0,
511+ void operator ()(const struct ggml_tensor *src0,
513512 const struct ggml_tensor *src1, struct ggml_tensor *dst,
514513 const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
515514 queue_ptr stream) {
@@ -643,30 +642,29 @@ struct bin_bcast_sycl {
643642 });
644643 }
645644 }
646- GGML_UNUSED (ctx);
647645 }
648646};
649647
650648template <class op >
651- inline void ggml_sycl_op_bin_bcast (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
649+ inline void ggml_sycl_op_bin_bcast (const ggml_tensor *src0,
652650 const ggml_tensor *src1, ggml_tensor *dst,
653- const float *src0_dd, const float *src1_dd,
654- float *dst_dd,
651+ const void *src0_dd, const void *src1_dd,
652+ void *dst_dd,
655653 const queue_ptr &main_stream) {
656654
657655 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
658- op ()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
656+ op ()(src0, src1, dst, ( const float *) src0_dd, ( const float *) src1_dd, ( float *) dst_dd, main_stream);
659657 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
660- op ()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
658+ op ()(src0, src1, dst, (const sycl::half *)src0_dd, ( const float *) src1_dd,
661659 (sycl::half *)dst_dd, main_stream);
662660 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
663- op ()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
661+ op ()(src0, src1, dst, (const sycl::half *)src0_dd, ( const float *) src1_dd, ( float *) dst_dd,
664662 main_stream);
665663 } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
666- op ()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
664+ op ()(src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
667665 main_stream);
668666 } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
669- op ()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
667+ op ()(src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
670668 main_stream);
671669 } else {
672670 fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__,
0 commit comments