
#include <sycl/sycl.hpp>
// dpctl tensor headers
+#include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "utils/offset_utils.hpp"
+#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"

namespace dpnp::kernels::nan_to_num
@@ -49,6 +51,14 @@ inline T to_num(const T v, const T nan, const T posinf, const T neginf)
template <typename T, typename scT, typename InOutIndexerT>
struct NanToNumFunctor
{
+private:
+    const T *inp_ = nullptr;
+    T *out_ = nullptr;
+    const InOutIndexerT inp_out_indexer_;
+    const scT nan_;
+    const scT posinf_;
+    const scT neginf_;
+
public:
    NanToNumFunctor(const T *inp,
                    T *out,
@@ -80,18 +90,104 @@ struct NanToNumFunctor
            out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_);
        }
    }
+};

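+// Functor for the contiguous case: each work-item handles n_vecs * vec_sz
+// consecutive elements and, when enable_sg_loadstore is true and T is not a
+// complex type, reads/writes them through sub-group block loads and stores.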
+template <typename T,
+          typename scT,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
+          bool enable_sg_loadstore = true>
+struct NanToNumContigFunctor
+{
private:
-    const T *inp_ = nullptr;
+    const T *in_ = nullptr;
    T *out_ = nullptr;
-    const InOutIndexerT inp_out_indexer_;
+    std::size_t nelems_;
    const scT nan_;
    const scT posinf_;
    const scT neginf_;
-};

-template <typename T>
-class NanToNumKernel;
+public:
+    NanToNumContigFunctor(const T *in,
+                          T *out,
+                          const std::size_t n_elems,
+                          const scT nan,
+                          const scT posinf,
+                          const scT neginf)
+        : in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf),
+          neginf_(neginf)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
+        /* Each work-item processes vec_sz elements, contiguous in memory */
+        /* NOTE: work-group size must be divisible by sub-group size */
+
+        using dpctl::tensor::type_utils::is_complex_v;
+        if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
+            auto sg = ndit.get_sub_group();
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const std::size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);
+
+            if (base + elems_per_wi * sgSize < nelems_) {
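+                // Full tile: every lane can load/store vec_sz elements at a
+                // time without running past the end of the array.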
+                using dpctl::tensor::sycl_utils::sub_group_load;
+                using dpctl::tensor::sycl_utils::sub_group_store;
+#pragma unroll
+                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                    const std::size_t offset = base + it * sgSize;
+                    auto in_multi_ptr = sycl::address_space_cast<
+                        sycl::access::address_space::global_space,
+                        sycl::access::decorated::yes>(&in_[offset]);
+                    auto out_multi_ptr = sycl::address_space_cast<
+                        sycl::access::address_space::global_space,
+                        sycl::access::decorated::yes>(&out_[offset]);
+
+                    sycl::vec<T, vec_sz> arg_vec =
+                        sub_group_load<vec_sz>(sg, in_multi_ptr);
+#pragma unroll
+                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
+                        arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_);
+                    }
+                    sub_group_store<vec_sz>(sg, arg_vec, out_multi_ptr);
+                }
+            }
+            else {
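+                // Partial tile at the end of the array: fall back to scalar
+                // accesses, each lane striding by the sub-group size.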
+                const std::size_t lane_id = sg.get_local_id()[0];
+                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
+                    out_[k] = to_num(in_[k], nan_, posinf_, neginf_);
+                }
+            }
+        }
+        else {
+            const std::uint16_t sgSize =
+                ndit.get_sub_group().get_local_range()[0];
+            const std::size_t gid = ndit.get_global_linear_id();
+            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
+
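+            // Each sub-group covers a contiguous chunk of elems_per_sg
+            // elements; 'start' is this lane's first element within that
+            // chunk (sub-group index * elems_per_sg + lane id).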
+            const std::size_t start =
+                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
+            const std::size_t end = std::min(nelems_, start + elems_per_sg);
+            for (std::size_t offset = start; offset < end; offset += sgSize) {
+                if constexpr (is_complex_v<T>) {
+                    using realT = typename T::value_type;
+                    static_assert(std::is_same_v<realT, scT>);
+
+                    T z = in_[offset];
+                    realT x = to_num(z.real(), nan_, posinf_, neginf_);
+                    realT y = to_num(z.imag(), nan_, posinf_, neginf_);
+                    out_[offset] = T{x, y};
+                }
+                else {
+                    out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_);
+                }
+            }
+        }
+    }
+};

template <typename T, typename scT>
sycl::event nan_to_num_impl(sycl::queue &q,
@@ -119,48 +215,69 @@ sycl::event nan_to_num_impl(sycl::queue &q,
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-        using KernelName = NanToNumKernel<T>;
-        cgh.parallel_for<KernelName>(
-            {nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
-                          in_tp, out_tp, indexer, nan, posinf, neginf));
+        using NanToNumFunc = NanToNumFunctor<T, scT, InOutIndexerT>;
+        cgh.parallel_for<NanToNumFunc>(
+            {nelems},
+            NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf));
    });
    return comp_ev;
}

-template <typename T>
-class NanToNumContigKernel;
-
-template <typename T, typename scT>
-sycl::event nan_to_num_contig_impl(sycl::queue &q,
-                                   const size_t nelems,
+template <typename T,
+          typename scT,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u>
+sycl::event nan_to_num_contig_impl(sycl::queue &exec_q,
+                                   std::size_t nelems,
                                    const scT nan,
                                    const scT posinf,
                                    const scT neginf,
                                    const char *in_cp,
                                    char *out_cp,
-                                   const std::vector<sycl::event> &depends)
+                                   const std::vector<sycl::event> &depends = {})
{
-    dpctl::tensor::type_utils::validate_type_for_device<T>(q);
+    constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
+    const std::size_t n_work_items_needed = nelems / elems_per_wi;
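+    // Empirically chosen work-group size: smaller problems use 128 work-items
+    // per group, larger ones 256.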
+    const std::size_t empirical_threshold = std::size_t(1) << 21;
+    const std::size_t lws = (n_work_items_needed <= empirical_threshold)
+                                ? std::size_t(128)
+                                : std::size_t(256);
+
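+    // Each work-group covers lws * elems_per_wi elements; round up so that
+    // n_groups * lws * elems_per_wi >= nelems.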
+    const std::size_t n_groups =
+        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
+    const auto gws_range = sycl::range<1>(n_groups * lws);
+    const auto lws_range = sycl::range<1>(lws);

    const T *in_tp = reinterpret_cast<const T *>(in_cp);
    T *out_tp = reinterpret_cast<T *>(out_cp);

-    using dpctl::tensor::offset_utils::NoOpIndexer;
-    using InOutIndexerT =
-        dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<NoOpIndexer,
-                                                                NoOpIndexer>;
-    constexpr NoOpIndexer in_indexer{};
-    constexpr NoOpIndexer out_indexer{};
-    constexpr InOutIndexerT indexer{in_indexer, out_indexer};
-
-    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-        using KernelName = NanToNumContigKernel<T>;
-        cgh.parallel_for<KernelName>(
-            {nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
-                          in_tp, out_tp, indexer, nan, posinf, neginf));
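+        // Sub-group block loads/stores require suitably aligned pointers;
+        // otherwise dispatch the variant with sub-group load/store disabled.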
+        using dpctl::tensor::kernels::alignment_utils::is_aligned;
+        using dpctl::tensor::kernels::alignment_utils::required_alignment;
+        if (is_aligned<required_alignment>(in_tp) &&
+            is_aligned<required_alignment>(out_tp))
+        {
+            constexpr bool enable_sg_loadstore = true;
+            using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
+                                                       enable_sg_loadstore>;
+
+            cgh.parallel_for<NanToNumFunc>(
+                sycl::nd_range<1>(gws_range, lws_range),
+                NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
+        }
+        else {
+            constexpr bool disable_sg_loadstore = false;
+            using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
+                                                       disable_sg_loadstore>;
+
+            cgh.parallel_for<NanToNumFunc>(
+                sycl::nd_range<1>(gws_range, lws_range),
+                NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
+        }
    });
+
    return comp_ev;
}
