11#include " set_rows.hpp"
22
3- typedef void (*set_rows_kernel_t )(const char * src, char * dst);
4-
5- static void set_rows_1_f32_f32 (const char * src, char * dst) {
6- const float * src_f = (const float *) src;
7- float * dst_f = (float *) dst;
8- *dst_f = *src_f;
9- }
10-
11- static void set_rows_1_f32_f16 (const char * src, char * dst) {
12- const float * src_f = (const float *) src;
13- sycl::half * dst_h = (sycl::half *) dst;
14- *dst_h = sycl::vec<float , 1 >(*src_f).convert <sycl::half, sycl::rounding_mode::automatic>()[0 ];
3+ template <typename TIn, typename TOut>
4+ static inline void convert (const char * src, char * dst) {
5+ auto src_val = *reinterpret_cast <const TIn*>(src);
6+ auto dst_val = sycl::vec<TIn, 1 >(src_val).template convert <TOut>()[0 ];
7+ *reinterpret_cast <TOut*>(dst) = dst_val;
158}
169
17- template <set_rows_kernel_t set_rows_1 >
10+ template <typename TIn, typename TOut >
1811static void k_set_rows (
1912 const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
2013 const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
@@ -38,18 +31,17 @@ static void k_set_rows(
3831
3932 const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3 >({nb10, nb11, nb12}, {i10, i11, i12}));
4033
41-
4234 const char * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
4335 char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
4436
4537 for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
4638 const char * src_elem = src0_row + col * src_type_size;
4739 char * dst_elem = dst_row_ptr + col * dst_type_size;
48- set_rows_1 (src_elem, dst_elem);
40+ convert<TIn, TOut> (src_elem, dst_elem);
4941 }
5042}
5143
52- template <set_rows_kernel_t set_rows_1 >
44+ template <typename TIn, typename TOut >
5345static void set_rows_sycl (
5446 const char * src0_d, const int64_t * src1_d, char * dst_d,
5547 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -73,7 +65,7 @@ static void set_rows_sycl(
7365 stream,
7466 sycl::nd_range<3 >(grid_size * block_size, block_size),
7567 [=](sycl::nd_item<3 > item_ct1) {
76- k_set_rows<set_rows_1 >(
68+ k_set_rows<TIn, TOut >(
7769 src0_d, src1_d, dst_d,
7870 ne00, ne01, ne11, ne12,
7971 nb01, nb02, nb03,
@@ -103,7 +95,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10395 dpct::queue_ptr stream = ctx.stream ();
10496 switch (dst->type ) {
10597 case GGML_TYPE_F32:
106- set_rows_sycl<set_rows_1_f32_f32 >(
98+ set_rows_sycl<float , float >(
10799 (const char *)dst->src [0 ]->data , src1_dd, (char *)dst->data ,
108100 ne00, ne01, ne02, ne03,
109101 ne11, ne12,
@@ -116,7 +108,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
116108 break ;
117109 case GGML_TYPE_F16:
118110 dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
119- set_rows_sycl<set_rows_1_f32_f16 >(
111+ set_rows_sycl<float , sycl::half >(
120112 (const char *)dst->src [0 ]->data , src1_dd, (char *)dst->data ,
121113 ne00, ne01, ne02, ne03,
122114 ne11, ne12,
0 commit comments