@@ -44,8 +44,8 @@ namespace copy_as_contig
 
 template <typename T,
           typename IndexerT,
-          std::uint32_t vec_sz = 4u,
-          std::uint32_t n_vecs = 2u,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
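Note: the vectorization parameters narrow from std::uint32_t to std::uint8_t. Since
unsigned char operands promote to int before arithmetic, products such as
vec_sz * n_vecs are still computed exactly. A minimal standalone sketch of that
promotion (variable names hypothetical, not from this patch):

    #include <cstdint>
    #include <type_traits>

    constexpr std::uint8_t vs = 4u, nv = 2u;
    // both operands promote to int, so the multiply cannot wrap at 8 bits
    static_assert(std::is_same_v<decltype(vs * nv), int>);
    static_assert(vs * nv == 8);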
@@ -68,25 +68,23 @@ class CopyAsCContigFunctor
     {
         static_assert(vec_sz > 0);
         static_assert(n_vecs > 0);
-        static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
 
-        constexpr std::uint8_t elems_per_wi =
-            static_cast<std::uint8_t>(vec_sz * n_vecs);
+        constexpr std::uint8_t elems_per_wi = vec_sz * n_vecs;
 
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
             const std::uint16_t sgSize =
                 ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
 
-            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // start = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
             // gid % sgSize == gid - (gid / sgSize) * sgSize
-            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
-            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
-            const std::size_t offset_max =
-                std::min(nelems, base + sgSize * elems_per_wi);
+            const std::size_t elems_per_sg = sgSize * elems_per_wi;
+            const std::size_t start =
+                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
+            const std::size_t end = std::min(nelems, start + elems_per_sg);
 
-            for (size_t offset = base; offset < offset_max; offset += sgSize) {
+            for (size_t offset = start; offset < end; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
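Note: the renamed bounds are algebraically identical to the removed ones. With
elems_per_sg = sgSize * elems_per_wi and gid % sgSize == gid - (gid / sgSize) * sgSize,
start = (gid / sgSize) * (elems_per_sg - sgSize) + gid reduces to the quantity in the
kernel comment, (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize). A standalone
compile-time check with hypothetical sample values (sgSize = 32, elems_per_wi = 8,
gid = 37):

    #include <cstddef>

    constexpr std::size_t sgSize = 32, elems_per_wi = 8, gid = 37;
    constexpr std::size_t elems_per_sg = sgSize * elems_per_wi; // 256
    constexpr std::size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid;
    // same value as the formula in the kernel comment
    static_assert(start == (gid / sgSize) * sgSize * elems_per_wi + gid % sgSize);
    static_assert(start == 261);

With these values the work-item copies offsets 261, 293, ..., 485 (elems_per_wi
iterations at stride sgSize), and its sub-group as a whole covers the contiguous
span [256, 512) of elems_per_sg elements.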
@@ -132,8 +130,8 @@ class CopyAsCContigFunctor
 
 template <typename T,
           typename IndexerT,
-          std::uint32_t vec_sz,
-          std::uint32_t n_vecs,
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs,
           bool enable_sg_load,
           typename KernelName>
 sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
@@ -145,7 +143,6 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
 {
     static_assert(vec_sz > 0);
     static_assert(n_vecs > 0);
-    static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
 
     constexpr std::size_t preferred_lws = 256;
 
@@ -187,8 +184,8 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
 
 template <typename T,
           typename IndexT,
-          std::uint32_t vec_sz,
-          std::uint32_t n_vecs,
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs,
           bool enable_sgload>
 class as_contig_krn;
 
@@ -210,8 +207,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
     const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
 
-    constexpr std::uint32_t vec_sz = 4u;
-    constexpr std::uint32_t n_vecs = 2u;
+    constexpr std::uint8_t vec_sz = 4u;
+    constexpr std::uint8_t n_vecs = 2u;
 
     using dpctl::tensor::kernels::alignment_utils::
         disabled_sg_loadstore_wrapper_krn;
@@ -256,8 +253,8 @@ template <typename fnT, typename T> struct AsCContigFactory
 
 template <typename T,
           typename IndexerT,
-          std::uint32_t tile_size,
-          std::uint32_t n_lines>
+          std::uint16_t tile_size,
+          std::uint16_t n_lines>
 class as_contig_batch_of_square_matrices_krn;
 
 namespace detail
@@ -283,14 +280,14 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
     const T *src_tp = reinterpret_cast<const T *>(src_p);
     T *dst_tp = reinterpret_cast<T *>(dst_p);
 
-    constexpr std::uint32_t private_tile_size = 4;
-    constexpr std::uint32_t n_lines = 2;
-    constexpr std::uint32_t block_size =
+    constexpr std::uint16_t private_tile_size = 4;
+    constexpr std::uint16_t n_lines = 2;
+    constexpr std::uint16_t block_size =
         n_lines * private_tile_size * private_tile_size;
 
-    constexpr std::uint32_t lws0 = block_size;
-    constexpr std::uint32_t lws1 = n_lines;
-    constexpr std::uint32_t nelems_per_wi = (block_size / lws1);
+    constexpr std::uint16_t lws0 = block_size;
+    constexpr std::uint16_t lws1 = n_lines;
+    constexpr std::uint16_t nelems_per_wi = (block_size / lws1);
 
     static_assert(nelems_per_wi * lws1 == block_size);
    static_assert(nelems_per_wi == private_tile_size * private_tile_size);
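Note: with these constants, block_size = 2 * 4 * 4 = 32, a work-group is
lws0 * lws1 = 32 * 2 = 64 work-items, and nelems_per_wi = 32 / 2 = 16 = 4 * 4,
so both static_asserts hold and one work-group moves a full 32 x 32 tile
(64 work-items x 16 elements = 1024). A standalone restatement of that arithmetic:

    #include <cstdint>

    constexpr std::uint16_t private_tile_size = 4, n_lines = 2;
    constexpr std::uint16_t block_size =
        n_lines * private_tile_size * private_tile_size;          // 32
    constexpr std::uint16_t lws0 = block_size, lws1 = n_lines;    // 32 x 2 work-items
    constexpr std::uint16_t nelems_per_wi = block_size / lws1;    // 16 == 4 * 4
    // each work-group covers exactly one block_size x block_size tile
    static_assert(lws0 * lws1 * nelems_per_wi == block_size * block_size);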
@@ -377,40 +374,41 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             std::array<T, nelems_per_wi> private_block_01 = {T(0)};
             std::array<T, nelems_per_wi> private_block_10 = {T(0)};
 
-            // 0 <= lid_lin < lws0 * lws1 == (block_size * block_size /
-            // nelems_per_wi) == (block_size/private_tile_size)**2
-            constexpr std::uint32_t n_private_tiles_per_axis =
+            // 0 <= lid_lin < lws0 * lws1 ==
+            //      (block_size * block_size / nelems_per_wi) ==
+            //      (block_size/private_tile_size)**2
+            constexpr std::uint16_t n_private_tiles_per_axis =
                 block_size / private_tile_size;
-            const std::uint32_t local_tile_id0 =
+            const std::uint16_t local_tile_id0 =
                 lid_lin / n_private_tiles_per_axis;
-            const std::uint32_t local_tile_id1 =
+            const std::uint16_t local_tile_id1 =
                 lid_lin - local_tile_id0 * n_private_tiles_per_axis;
 
             if (local_tile_id0 <= local_tile_id1) {
-                for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size;
+                for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size;
                      ++pr_i0)
                 {
-                    for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size;
+                    for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size;
                          ++pr_i1)
                     {
-                        const std::uint32_t t0_offset =
+                        const std::uint16_t t0_offset =
                             local_tile_id0 * private_tile_size;
-                        const std::uint32_t t1_offset =
+                        const std::uint16_t t1_offset =
                             local_tile_id1 * private_tile_size;
 
-                        const std::uint32_t pr_offset =
+                        const std::uint16_t pr_offset =
                             pr_i1 * private_tile_size + pr_i0;
-                        const std::uint32_t rel_offset =
+                        const std::uint16_t rel_offset =
                             pr_i0 + pr_i1 * block_size;
 
                         // read (local_tile_id0, local_tile_id1)
-                        const std::uint32_t local_01_offset =
+                        const std::uint16_t local_01_offset =
                             (t0_offset + t1_offset * block_size) + rel_offset;
                         private_block_01[pr_offset] =
                             local_block[local_01_offset];
 
                         // read (local_tile_id1, local_tile_id0)
-                        const std::uint32_t local_10_offset =
+                        const std::uint16_t local_10_offset =
                             (t1_offset + t0_offset * block_size) + rel_offset;
                         private_block_10[pr_offset] =
                             local_block[local_10_offset];
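Note: lws0 * lws1 = 64 equals (block_size / private_tile_size)^2 = 8^2, so lid_lin
decomposes exactly into coordinates on an 8 x 8 grid of 4 x 4 private tiles, and the
local_tile_id0 <= local_tile_id1 guard assigns each work-item a mirror pair of tiles
across the diagonal. A standalone check of the decomposition with a hypothetical
lid_lin = 19:

    #include <cstdint>

    constexpr std::uint16_t n_tiles = 32 / 4; // n_private_tiles_per_axis == 8
    constexpr std::uint16_t lid_lin = 19;
    constexpr std::uint16_t t0 = lid_lin / n_tiles;      // 2
    constexpr std::uint16_t t1 = lid_lin - t0 * n_tiles; // 3
    // round-trip: the (row, column) pair reassembles the linear id
    static_assert(t0 == 2 && t1 == 3 && t0 * n_tiles + t1 == lid_lin);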
@@ -422,20 +420,20 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                                 sycl::memory_scope::work_group);
 
             if (local_tile_id0 <= local_tile_id1) {
-                for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size;
+                for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size;
                      ++pr_i0)
                 {
-                    for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size;
+                    for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size;
                          ++pr_i1)
                     {
-                        const std::uint32_t t0_offset =
+                        const std::uint16_t t0_offset =
                             local_tile_id0 * private_tile_size;
-                        const std::uint32_t t1_offset =
+                        const std::uint16_t t1_offset =
                             local_tile_id1 * private_tile_size;
-                        const std::uint32_t pr_offset =
+                        const std::uint16_t pr_offset =
                             pr_i0 * private_tile_size + pr_i1;
 
-                        const std::uint32_t rel_offset =
+                        const std::uint16_t rel_offset =
                             pr_i0 + pr_i1 * block_size;
 
                         // write back permuted private blocks
@@ -444,7 +442,7 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                         local_block[local_01_offset] =
                             private_block_10[pr_offset];
 
-                        const std::uint32_t local_10_offset =
+                        const std::uint16_t local_10_offset =
                             (t1_offset + t0_offset * block_size) + rel_offset;
                         local_block[local_10_offset] =
                             private_block_01[pr_offset];
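Note: the reads above use pr_offset = pr_i1 * private_tile_size + pr_i0 while the
writes here use pr_offset = pr_i0 * private_tile_size + pr_i1, and each private copy
is stored back into the mirrored tile, so the net effect is an in-place transpose of
the work-group's local block. The element-level analogue, as a hedged sketch rather
than the kernel's actual code:

    #include <array>
    #include <cstddef>
    #include <utility>

    // in-place transpose of an N x N row-major matrix by swapping mirror
    // elements across the diagonal, the element-level counterpart of
    // exchanging the (t0, t1) and (t1, t0) private tiles
    template <typename T, std::size_t N>
    void transpose_in_place(std::array<T, N * N> &m)
    {
        for (std::size_t i = 0; i < N; ++i)
            for (std::size_t j = i + 1; j < N; ++j)
                std::swap(m[i * N + j], m[j * N + i]);
    }

Here j starts at i + 1 because a lone diagonal element is its own mirror; the
kernel's guard is <= instead, since a diagonal 4 x 4 tile still has to transpose
its own interior.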
@@ -461,8 +459,8 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             const std::size_t dst_tile_start1 = src_tile_start1;
 
             if (local_dim0 == block_size && local_dim1 == block_size) {
-                const std::uint32_t dst_i0 = src_i1;
-                const std::uint32_t dst_i1 = src_i0;
+                const std::uint16_t dst_i0 = src_i1;
+                const std::uint16_t dst_i1 = src_i0;
 
                 const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0);
                 const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1);
@@ -471,11 +469,11 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                     dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1;
                 const std::size_t pr_step_dst = lws1 * dst_stride;
 
-                const std::uint32_t _local_offset0 =
+                const std::uint16_t _local_offset0 =
                     dst_i0 * block_size + dst_i1;
-                const std::uint32_t _pr_step_local = lws1 * block_size;
+                const std::uint16_t _pr_step_local = lws1 * block_size;
 
-                for (std::uint32_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) {
+                for (std::uint16_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) {
                     if ((dst_gid1 < n) && ((dst_gid0 + pr_id * lws1) < n)) {
                         dst_tp[dst_offset0 + pr_step_dst * pr_id] =
                             local_block[_local_offset0 +
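Note: in this full-tile fast path each work-item writes one column of the tile at row
stride lws1 = 2: rows dst_i0, dst_i0 + 2, ..., dst_i0 + 30, i.e. nelems_per_wi = 16
of the 32 rows, with matching strides pr_step_dst = lws1 * dst_stride in global
memory and _pr_step_local = lws1 * block_size in local memory. A small compile-time
check of the last row reached (sample dst_i0 hypothetical):

    #include <cstdint>

    constexpr std::uint16_t lws1 = 2, nelems_per_wi = 16, dst_i0 = 1;
    // 16 writes at stride 2 starting from row 1 end at row 31, inside the tile
    static_assert(dst_i0 + (nelems_per_wi - 1) * lws1 == 31);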
@@ -485,24 +483,24 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             }
             else {
                 // map local_linear_id into (local_dim0, local_dim1)
-                for (std::uint32_t el_id = lid_lin;
+                for (std::uint16_t el_id = lid_lin;
                      el_id < local_dim0 * local_dim1; el_id += lws0 * lws1)
                 {
 
                     // 0 <= local_i0 < local_dim0
-                    const std::uint32_t loc_i0 = el_id / local_dim1;
+                    const std::uint16_t loc_i0 = el_id / local_dim1;
                     // 0 <= local_i1 < local_dim1
-                    const std::uint32_t loc_i1 = el_id - loc_i0 * local_dim1;
+                    const std::uint16_t loc_i1 = el_id - loc_i0 * local_dim1;
 
-                    const std::uint32_t dst_i0 = loc_i0;
-                    const std::uint32_t dst_i1 = loc_i1;
+                    const std::uint16_t dst_i0 = loc_i0;
+                    const std::uint16_t dst_i1 = loc_i1;
 
                     const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0);
                     const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1);
 
                     const std::size_t dst_offset =
                         dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1;
-                    const std::uint32_t local_offset =
+                    const std::uint16_t local_offset =
                         loc_i0 * block_size + loc_i1;
 
                     if ((dst_gid1 < n) && (dst_gid0 < n)) {
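Note: the remainder path handles partial edge tiles by letting all lws0 * lws1 = 64
work-items sweep the local_dim0 x local_dim1 region linearly, splitting el_id into
2D coordinates by division and remainder. A standalone check with hypothetical
edge-tile values (local_dim1 = 20, el_id = 67):

    #include <cstdint>

    constexpr std::uint16_t local_dim1 = 20, el_id = 67;
    constexpr std::uint16_t loc_i0 = el_id / local_dim1;          // 3
    constexpr std::uint16_t loc_i1 = el_id - loc_i0 * local_dim1; // 7
    static_assert(loc_i0 * local_dim1 + loc_i1 == el_id);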