@@ -114,8 +114,54 @@ sycl::event nan_to_num_call(sycl::queue &exec_q,
114114 return to_num_ev;
115115}
116116
117+ typedef sycl::event (*nan_to_num_contig_fn_ptr_t )(
118+ sycl::queue &,
119+ size_t ,
120+ const py::object &,
121+ const py::object &,
122+ const py::object &,
123+ const char *,
124+ char *,
125+ const std::vector<sycl::event> &);
126+
127+ template <typename T>
128+ sycl::event nan_to_num_contig_call (sycl::queue &exec_q,
129+ size_t nelems,
130+ const py::object &py_nan,
131+ const py::object &py_posinf,
132+ const py::object &py_neginf,
133+ const char *arg_p,
134+ char *dst_p,
135+ const std::vector<sycl::event> &depends)
136+ {
137+ sycl::event to_num_contig_ev;
138+
139+ using dpctl::tensor::type_utils::is_complex;
140+ if constexpr (is_complex<T>::value) {
141+ using realT = typename T::value_type;
142+ realT nan_v = py::cast<realT>(py_nan);
143+ realT posinf_v = py::cast<realT>(py_posinf);
144+ realT neginf_v = py::cast<realT>(py_neginf);
145+
146+ using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
147+ to_num_contig_ev = nan_to_num_contig_impl<T, realT>(
148+ exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
149+ }
150+ else {
151+ T nan_v = py::cast<T>(py_nan);
152+ T posinf_v = py::cast<T>(py_posinf);
153+ T neginf_v = py::cast<T>(py_neginf);
154+
155+ using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
156+ to_num_contig_ev = nan_to_num_contig_impl<T, T>(
157+ exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
158+ }
159+ return to_num_contig_ev;
160+ }
161+
117162namespace td_ns = dpctl::tensor::type_dispatch;
118163nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types];
164+ nan_to_num_contig_fn_ptr_t nan_to_num_contig_dispatch_vector[td_ns::num_types];
119165
120166std::pair<sycl::event, sycl::event>
121167 py_nan_to_num (const dpctl::tensor::usm_ndarray &src,
@@ -176,6 +222,37 @@ std::pair<sycl::event, sycl::event>
176222 const char *src_data = src.get_data ();
177223 char *dst_data = dst.get_data ();
178224
225+ // handle contiguous inputs
226+ bool is_src_c_contig = src.is_c_contiguous ();
227+ bool is_src_f_contig = src.is_f_contiguous ();
228+
229+ bool is_dst_c_contig = dst.is_c_contiguous ();
230+ bool is_dst_f_contig = dst.is_f_contiguous ();
231+
232+ bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
233+ bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
234+
235+ if (both_c_contig || both_f_contig) {
236+ auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
237+
238+ if (contig_fn == nullptr ) {
239+ throw std::runtime_error (
240+ " Contiguous implementation is missing for src_typeid=" +
241+ std::to_string (src_typeid));
242+ }
243+
244+ auto comp_ev = contig_fn (q, nelems, py_nan, py_posinf, py_neginf,
245+ src_data, dst_data, depends);
246+ sycl::event ht_ev =
247+ dpctl::utils::keep_args_alive (q, {src, dst}, {comp_ev});
248+
249+ return std::make_pair (ht_ev, comp_ev);
250+ }
251+
252+ // simplify iteration space
253+ // if 1d with strides 1 - input is contig
254+ // dispatch to strided
255+
179256 auto const &src_strides = src.get_strides_vector ();
180257 auto const &dst_strides = dst.get_strides_vector ();
181258
@@ -195,6 +272,30 @@ std::pair<sycl::event, sycl::event>
195272 simplified_shape, simplified_src_strides, simplified_dst_strides,
196273 src_offset, dst_offset);
197274
275+ if (nd == 1 && simplified_src_strides[0 ] == 1 &&
276+ simplified_dst_strides[0 ] == 1 ) {
277+ // Special case of contiguous data
278+ auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
279+
280+ if (contig_fn == nullptr ) {
281+ throw std::runtime_error (
282+ " Contiguous implementation is missing for src_typeid=" +
283+ std::to_string (src_typeid));
284+ }
285+
286+ int src_elem_size = src.get_elemsize ();
287+ int dst_elem_size = dst.get_elemsize ();
288+ auto comp_ev =
289+ contig_fn (q, nelems, py_nan, py_posinf, py_neginf,
290+ src_data + src_elem_size * src_offset,
291+ dst_data + dst_elem_size * dst_offset, depends);
292+
293+ sycl::event ht_ev =
294+ dpctl::utils::keep_args_alive (q, {src, dst}, {comp_ev});
295+
296+ return std::make_pair (ht_ev, comp_ev);
297+ }
298+
198299 auto fn = nan_to_num_dispatch_vector[src_typeid];
199300
200301 if (fn == nullptr ) {
@@ -277,20 +378,41 @@ struct NanToNumFactory
277378 }
278379};
279380
280- void populate_nan_to_num_dispatch_vector (void )
381+ template <typename fnT, typename T>
382+ struct NanToNumContigFactory
383+ {
384+ fnT get ()
385+ {
386+ if constexpr (std::is_same_v<typename NanToNumOutputType<T>::value_type,
387+ void >) {
388+ return nullptr ;
389+ }
390+ else {
391+ using ::dpnp::extensions::ufunc::impl::nan_to_num_contig_call;
392+ return nan_to_num_contig_call<T>;
393+ }
394+ }
395+ };
396+
397+ void populate_nan_to_num_dispatch_vectors (void )
281398{
282399 using namespace td_ns ;
283400
284- DispatchVectorBuilder<nan_to_num_fn_ptr_t , NanToNumFactory, num_types> dvb;
285- dvb.populate_dispatch_vector (nan_to_num_dispatch_vector);
401+ DispatchVectorBuilder<nan_to_num_fn_ptr_t , NanToNumFactory, num_types> dvb1;
402+ dvb1.populate_dispatch_vector (nan_to_num_dispatch_vector);
403+
404+ DispatchVectorBuilder<nan_to_num_contig_fn_ptr_t , NanToNumContigFactory,
405+ num_types>
406+ dvb2;
407+ dvb2.populate_dispatch_vector (nan_to_num_contig_dispatch_vector);
286408}
287409
288410} // namespace impl
289411
290412void init_nan_to_num (py::module_ m)
291413{
292414 {
293- impl::populate_nan_to_num_dispatch_vector ();
415+ impl::populate_nan_to_num_dispatch_vectors ();
294416
295417 using impl::py_nan_to_num;
296418 m.def (" _nan_to_num" , &py_nan_to_num, " " , py::arg (" src" ),
0 commit comments