@@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() {
66 return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
77}
88}
9+
910template <typename TIn, typename TOut>
1011static inline std::enable_if_t <utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void >
1112convert (const char * src, char * dst) {
1213 auto src_val = *reinterpret_cast <const TIn*>(src);
1314 auto dst_val = sycl::vec<TIn, 1 >(src_val).template convert <TOut, sycl::rounding_mode::automatic>()[0 ];
14- *reinterpret_cast <TOut*>(dst) = dst_val;;
15+ *reinterpret_cast <TOut*>(dst) = dst_val;
1516}
1617
1718template <typename TIn, typename TOut>
1819static void k_set_rows (
1920 const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
20- const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
21+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
22+ const int64_t ne11, const int64_t ne12,
2123 const size_t nb01, const size_t nb02, const size_t nb03,
2224 const size_t nb10, const size_t nb11, const size_t nb12,
2325 const size_t nb1, const size_t nb2, const size_t nb3,
2426 const size_t src_type_size, const size_t dst_type_size,
25- const sycl::nd_item<3 > & item_ct1) {
26-
27- const int i03 = item_ct1.get_group (0 );
28- const int i02 = item_ct1.get_group (1 );
29- const int i01 = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 ); // Row index
27+ const int64_t total_elements,
28+ const sycl::nd_item<1 > & item_ct1) {
3029
31- if (i01 >= ne01) {
30+ const int64_t i = item_ct1.get_global_linear_id ();
31+ if (i >= total_elements) {
3232 return ;
3333 }
3434
35- const int i12 = i03 % ne12;
36- const int i11 = i02 % ne11;
37- const int i10 = i01;
35+ const int64_t i03 = i / (ne00 * ne01 * ne02);
36+ const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
37+ const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
38+ const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
39+
40+ const int64_t i12 = i03 % ne12;
41+ const int64_t i11 = i02 % ne11;
42+ const int64_t i10 = i01;
3843
3944 const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3 >({nb10, nb11, nb12}, {i10, i11, i12}));
4045
4146 const char * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
42- char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
47+ const char * src_elem = src0_row + i00 * src_type_size;
48+ char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
49+ char * dst_elem = dst_row_ptr + i00 * dst_type_size;
4350
44- for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
45- const char * src_elem = src0_row + col * src_type_size;
46- char * dst_elem = dst_row_ptr + col * dst_type_size;
47- convert<TIn, TOut>(src_elem, dst_elem);
48- }
51+ convert<TIn, TOut>(src_elem, dst_elem);
4952}
5053
5154template <typename TIn, typename TOut>
@@ -58,32 +61,29 @@ static void set_rows_sycl(
5861 const size_t src_type_size, const size_t dst_type_size,
5962 queue_ptr stream) {
6063
61- constexpr int max_threads_per_row = 64 ; // KEEPING 64 for now
62- const int threads_per_row = std::min ((int )ne00, max_threads_per_row);
63-
64- constexpr int max_threads_per_block = 64 ;
65- const int rows_per_block = std::max (1 , max_threads_per_block / threads_per_row);
66-
67- const sycl::range<3 > block_size (1 , rows_per_block, threads_per_row);
68- const sycl::range<3 > grid_size (ne03, ne02, (ne01 + rows_per_block - 1 ) / rows_per_block);
69-
70- sycl_parallel_for (
71- stream,
72- sycl::nd_range<3 >(grid_size * block_size, block_size),
73- [=](sycl::nd_item<3 > item_ct1) {
74- k_set_rows<TIn, TOut>(
75- src0_d, src1_d, dst_d,
76- ne00, ne01, ne11, ne12,
77- nb01, nb02, nb03,
78- nb10, nb11, nb12,
79- nb1, nb2, nb3,
80- src_type_size, dst_type_size,
81- item_ct1
82- );
83- }
84- );
85- }
64+ const int64_t total_elements = ne00 * ne01 * ne02 * ne03;
8665
66+ constexpr int block_size = 64 ;
67+ const int64_t grid_size = ceil_div (total_elements, block_size);
68+
69+ sycl_parallel_for (
70+ stream,
71+ sycl::nd_range<1 >(grid_size * block_size, block_size),
72+ [=](sycl::nd_item<1 > item_ct1) {
73+ k_set_rows<TIn, TOut>(
74+ src0_d, src1_d, dst_d,
75+ ne00, ne01, ne02,
76+ ne11, ne12,
77+ nb01, nb02, nb03,
78+ nb10, nb11, nb12,
79+ nb1, nb2, nb3,
80+ src_type_size, dst_type_size,
81+ total_elements,
82+ item_ct1
83+ );
84+ }
85+ );
86+ }
8787
8888void ggml_sycl_op_set_rows (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
8989 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 2 );
@@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
122122 nb1, nb2, nb3,
123123 sizeof (float ), sizeof (sycl::half),
124124 stream
125- );
125+ );
126126 break ;
127127 default :
128128 GGML_ABORT (" Unsupported tensor type!" );
0 commit comments