 #include "dpctl_tensor_types.hpp"
 #include "kernels/alignment.hpp"
 #include "utils/offset_utils.hpp"
+#include "utils/sycl_utils.hpp"
 #include "utils/type_utils.hpp"
 
 namespace dpctl
@@ -50,15 +51,18 @@ using dpctl::tensor::kernels::alignment_utils::
 using dpctl::tensor::kernels::alignment_utils::is_aligned;
 using dpctl::tensor::kernels::alignment_utils::required_alignment;
 
+using dpctl::tensor::sycl_utils::sub_group_load;
+using dpctl::tensor::sycl_utils::sub_group_store;
+
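+// sub_group_load/sub_group_store come from utils/sycl_utils.hpp; they stand
+// in for the sub-group load/store member functions (sg.load/sg.store), which
+// are deprecated, and are assumed here to provide the same cooperative
+// blocked load/store semantics.
+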
 template <typename T, typename condT, typename IndexerT>
 class where_strided_kernel;
-template <typename T, typename condT, int vec_sz, int n_vecs>
+template <typename T, typename condT, std::uint8_t vec_sz, std::uint8_t n_vecs>
 class where_contig_kernel;
 
 template <typename T,
           typename condT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class WhereContigFunctor
 {
@@ -82,42 +86,40 @@ class WhereContigFunctor
 
     void operator()(sycl::nd_item<1> ndit) const
     {
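+        // Each work-item is responsible for n_vecs vector chunks of vec_sz
+        // elements each, i.e. nelems_per_wi elements overall.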
+        constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
+
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<condT>::value ||
                       is_complex<T>::value)
         {
-            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
-            size_t base = ndit.get_global_linear_id();
-
-            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            const std::uint16_t sgSize =
+                ndit.get_sub_group().get_local_range()[0];
+            const size_t gid = ndit.get_global_linear_id();
+
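+            // Each sub-group owns a contiguous chunk of nelems_per_sg
+            // elements; start is equivalent to
+            // (gid / sgSize) * sgSize * nelems_per_wi + (gid % sgSize),
+            // so the lanes sweep the chunk with a stride of sgSize.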
+            const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi;
+            const size_t start =
+                (gid / sgSize) * (nelems_per_sg - sgSize) + gid;
+            const size_t end = std::min(nelems, start + nelems_per_sg);
+            for (size_t offset = start; offset < end; offset += sgSize) {
                 using dpctl::tensor::type_utils::convert_impl;
-                bool check = convert_impl<bool, condT>(cond_p[offset]);
+                const bool check = convert_impl<bool, condT>(cond_p[offset]);
                 dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_local_range()[0];
-            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
-                          (ndit.get_group(0) * ndit.get_local_range(0) +
-                           sg.get_group_id()[0] * max_sgSize);
-
-            if (base + n_vecs * vec_sz * sgSize < nelems &&
-                sgSize == max_sgSize)
-            {
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+
+            const size_t base =
+                nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                 sg.get_group_id()[0] * sgSize);
+
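+            // Fast path: the sub-group's full span of nelems_per_wi * sgSize
+            // elements lies inside the array, so vectorized sub-group
+            // loads/stores are safe without per-element bounds checks.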
+            if (base + nelems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;
-                sycl::vec<T, vec_sz> x1_vec;
-                sycl::vec<T, vec_sz> x2_vec;
-                sycl::vec<condT, vec_sz> cond_vec;
 
 #pragma unroll
                 for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
-                    auto idx = base + it * sgSize;
+                    const size_t idx = base + it * sgSize;
                     auto x1_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
                         sycl::access::decorated::yes>(&x1_p[idx]);
@@ -131,20 +133,22 @@ class WhereContigFunctor
                         sycl::access::address_space::global_space,
                         sycl::access::decorated::yes>(&dst_p[idx]);
 
-                    x1_vec = sg.load<vec_sz>(x1_multi_ptr);
-                    x2_vec = sg.load<vec_sz>(x2_multi_ptr);
-                    cond_vec = sg.load<vec_sz>(cond_multi_ptr);
+                    const sycl::vec<T, vec_sz> x1_vec =
+                        sub_group_load<vec_sz>(sg, x1_multi_ptr);
+                    const sycl::vec<T, vec_sz> x2_vec =
+                        sub_group_load<vec_sz>(sg, x2_multi_ptr);
+                    const sycl::vec<condT, vec_sz> cond_vec =
+                        sub_group_load<vec_sz>(sg, cond_multi_ptr);
 #pragma unroll
                     for (std::uint8_t k = 0; k < vec_sz; ++k) {
                         dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
                     }
-                    sg.store<vec_sz>(dst_multi_ptr, dst_vec);
+                    sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr);
                 }
             }
             else {
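+                // Tail path: the last chunk may run past nelems, so fall
+                // back to scalar accesses with an explicit bound.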
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems; k += sgSize) {
                     dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
                 }
             }
@@ -179,8 +183,8 @@ sycl::event where_contig_impl(sycl::queue &q,
     cgh.depends_on(depends);
 
     size_t lws = 64;
-    constexpr unsigned int vec_sz = 4;
-    constexpr unsigned int n_vecs = 2;
+    constexpr std::uint8_t vec_sz = 4u;
+    constexpr std::uint8_t n_vecs = 2u;
     const size_t n_groups =
         ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
     const auto gws_range = sycl::range<1>(n_groups * lws);