Skip to content

Commit 1ea57a8

Browse files
committed
binbcast: use void pointer to prevent intermediate type conversions
1 parent 0cb2933 commit 1ea57a8

File tree

3 files changed

+28
-35
lines changed

3 files changed

+28
-35
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
508508
template<float (*bin_op)(const float, const float)>
509509
struct bin_bcast_sycl {
510510
template <typename src0_t, typename src1_t, typename dst_t>
511-
void operator()(ggml_backend_sycl_context & ctx,
512-
const struct ggml_tensor *src0,
511+
void operator()(const struct ggml_tensor *src0,
513512
const struct ggml_tensor *src1, struct ggml_tensor *dst,
514513
const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
515514
queue_ptr stream) {
@@ -643,30 +642,29 @@ struct bin_bcast_sycl {
643642
});
644643
}
645644
}
646-
GGML_UNUSED(ctx);
647645
}
648646
};
649647

650648
template <class op>
651-
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
649+
inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0,
652650
const ggml_tensor *src1, ggml_tensor *dst,
653-
const float *src0_dd, const float *src1_dd,
654-
float *dst_dd,
651+
const void *src0_dd, const void *src1_dd,
652+
void *dst_dd,
655653
const queue_ptr &main_stream) {
656654

657655
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
658-
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
656+
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, main_stream);
659657
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
660-
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
658+
op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd,
661659
(sycl::half *)dst_dd, main_stream);
662660
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
663-
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
661+
op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd, (float *)dst_dd,
664662
main_stream);
665663
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
666-
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
664+
op()(src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
667665
main_stream);
668666
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
669-
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
667+
op()(src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
670668
main_stream);
671669
} else {
672670
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -756,43 +756,39 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx,
756756

757757
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx,
758758
ggml_tensor *dst) {
759-
// TODO: remove duplicate variables
760-
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
761-
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
762-
float * dst_dd = static_cast<float *>(dst->data);
759+
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
760+
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
761+
void * dst_dd = static_cast<void *>(dst->data);
763762
const dpct::queue_ptr main_stream = ctx.stream();
764763

765-
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
764+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
766765
}
767766

768767
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
769-
// TODO: remove duplicate variables
770-
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
771-
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
772-
float * dst_dd = static_cast<float *>(dst->data);
768+
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
769+
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
770+
void * dst_dd = static_cast<void *>(dst->data);
773771
const dpct::queue_ptr main_stream = ctx.stream();
774772

775-
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
773+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
776774
}
777775

778776
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
779-
// TODO: remove duplicate variables
780-
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
781-
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
782-
float * dst_dd = static_cast<float *>(dst->data);
777+
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
778+
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
779+
void * dst_dd = static_cast<void *>(dst->data);
783780
const dpct::queue_ptr main_stream = ctx.stream();
784781

785-
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
782+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
786783
}
787784

788785
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
789-
// TODO: remove duplicate variables
790-
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
791-
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
792-
float * dst_dd = static_cast<float *>(dst->data);
786+
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
787+
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
788+
void * dst_dd = static_cast<void *>(dst->data);
793789
const dpct::queue_ptr main_stream = ctx.stream();
794790

795-
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
791+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
796792
}
797793

798794

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2534,12 +2534,11 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor
25342534

25352535

25362536
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
2537-
// TODO: remove duplicate variables
2538-
const float * src0_d = static_cast<float *>(dst->src[0]->data);
2539-
float * dst_d = static_cast<float *>(dst->data);
2537+
const void * src0_d = static_cast<void *>(dst->src[0]->data);
2538+
void * dst_d = static_cast<void *>(dst->data);
25402539
dpct::queue_ptr main_stream = ctx.stream();
25412540

2542-
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
2541+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
25432542
}
25442543

25452544

0 commit comments

Comments
 (0)