@@ -36,25 +36,16 @@ static void k_set_rows(
3636 const int i11 = i02 % ne11;
3737 const int i10 = i01;
3838
39- const int64_t dst_row = *(const int64_t *)((const char *)src1 + i10* nb10 + i11* nb11 + i12*nb12 );
39+ const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset< 3 >({ nb10, nb11, nb12}, {i10, i11, i12}) );
4040
41- const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
41+
42+ const char * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
4243 char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
43- // Optimize for same-type operations: use collective memory copy
44- if (src_type_size == dst_type_size) {
45- // All threads in the work-group cooperatively copy the row
46- const size_t row_bytes = ne00 * src_type_size;
47- // Each thread copies a chunk of the row
48- for (size_t byte_idx = item_ct1.get_local_id (0 ); byte_idx < row_bytes; byte_idx += item_ct1.get_local_range (0 )) {
49- dst_row_ptr[byte_idx] = src0_row[byte_idx];
50- }
51- } else {
52- // Type conversion required, use element-wise approach
53- for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
54- const char * src_elem = src0_row + col * src_type_size;
55- char * dst_elem = dst_row_ptr + col * dst_type_size;
56- set_rows_1 (src_elem, dst_elem);
57- }
44+
45+ for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
46+ const char * src_elem = src0_row + col * src_type_size;
47+ char * dst_elem = dst_row_ptr + col * dst_type_size;
48+ set_rows_1 (src_elem, dst_elem);
5849 }
5950}
6051
@@ -68,10 +59,10 @@ static void set_rows_sycl(
6859 const size_t src_type_size, const size_t dst_type_size,
6960 queue_ptr stream) {
7061
71- const int max_threads_per_row = 256 ; // KEEPING 256 for now
62+ constexpr int max_threads_per_row = 64 ; // KEEPING 64 for now
7263 const int threads_per_row = std::min ((int )ne00, max_threads_per_row);
7364
74- const int max_threads_per_block = 256 ;
65+ constexpr int max_threads_per_block = 64 ;
7566 const int rows_per_block = std::max (1 , max_threads_per_block / threads_per_row);
7667
7768 const sycl::range<3 > block_size (1 , rows_per_block, threads_per_row);
0 commit comments