@@ -60,6 +60,22 @@ namespace dpnp::extensions::ufunc
6060
6161namespace impl
6262{
63+
64+ template <typename T>
65+ struct value_type_of
66+ {
67+ using type = T;
68+ };
69+
70+ template <typename T>
71+ struct value_type_of <std::complex <T>>
72+ {
73+ using type = T;
74+ };
75+
76+ template <typename T>
77+ using value_type_of_t = typename value_type_of<T>::type;
78+
6379typedef sycl::event (*nan_to_num_fn_ptr_t )(sycl::queue &,
6480 int ,
6581 size_t ,
@@ -87,30 +103,18 @@ sycl::event nan_to_num_call(sycl::queue &exec_q,
87103 py::ssize_t dst_offset,
88104 const std::vector<sycl::event> &depends)
89105{
90- sycl::event to_num_ev;
91-
92- using dpctl::tensor::type_utils::is_complex;
93- if constexpr (is_complex<T>::value) {
94- using realT = typename T::value_type;
95- realT nan_v = py::cast<realT>(py_nan);
96- realT posinf_v = py::cast<realT>(py_posinf);
97- realT neginf_v = py::cast<realT>(py_neginf);
98-
99- using dpnp::kernels::nan_to_num::nan_to_num_impl;
100- to_num_ev = nan_to_num_impl<T, realT>(
101- exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
102- arg_offset, dst_p, dst_offset, depends);
103- }
104- else {
105- T nan_v = py::cast<T>(py_nan);
106- T posinf_v = py::cast<T>(py_posinf);
107- T neginf_v = py::cast<T>(py_neginf);
108-
109- using dpnp::kernels::nan_to_num::nan_to_num_impl;
110- to_num_ev = nan_to_num_impl<T, T>(
111- exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
112- arg_offset, dst_p, dst_offset, depends);
113- }
106+ using dpctl::tensor::type_utils::is_complex_v;
107+ using scT = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
108+
109+ scT nan_v = py::cast<scT>(py_nan);
110+ scT posinf_v = py::cast<scT>(py_posinf);
111+ scT neginf_v = py::cast<scT>(py_neginf);
112+
113+ using dpnp::kernels::nan_to_num::nan_to_num_impl;
114+ sycl::event to_num_ev = nan_to_num_impl<T, scT>(
115+ exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
116+ arg_offset, dst_p, dst_offset, depends);
117+
114118 return to_num_ev;
115119}
116120
@@ -134,28 +138,17 @@ sycl::event nan_to_num_contig_call(sycl::queue &exec_q,
134138 char *dst_p,
135139 const std::vector<sycl::event> &depends)
136140{
137- sycl::event to_num_contig_ev;
138-
139- using dpctl::tensor::type_utils::is_complex;
140- if constexpr (is_complex<T>::value) {
141- using realT = typename T::value_type;
142- realT nan_v = py::cast<realT>(py_nan);
143- realT posinf_v = py::cast<realT>(py_posinf);
144- realT neginf_v = py::cast<realT>(py_neginf);
145-
146- using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
147- to_num_contig_ev = nan_to_num_contig_impl<T, realT>(
148- exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
149- }
150- else {
151- T nan_v = py::cast<T>(py_nan);
152- T posinf_v = py::cast<T>(py_posinf);
153- T neginf_v = py::cast<T>(py_neginf);
154-
155- using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
156- to_num_contig_ev = nan_to_num_contig_impl<T, T>(
157- exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
158- }
141+ using dpctl::tensor::type_utils::is_complex_v;
142+ using scT = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
143+
144+ scT nan_v = py::cast<scT>(py_nan);
145+ scT posinf_v = py::cast<scT>(py_posinf);
146+ scT neginf_v = py::cast<scT>(py_neginf);
147+
148+ using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
149+ sycl::event to_num_contig_ev = nan_to_num_contig_impl<T, scT>(
150+ exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
151+
159152 return to_num_contig_ev;
160153}
161154
0 commit comments