Skip to content

Commit 13ee6a1

Browse files
committed
Add kernel for nan_to_num to ufunc extension
1 parent 11487c5 commit 13ee6a1

File tree

6 files changed

+487
-19
lines changed

6 files changed

+487
-19
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ set(_elementwise_sources
3838
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
3939
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
4040
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
41+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/nan_to_num.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/sinc.cpp
4344
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/spacing.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "lcm.hpp"
3939
#include "ldexp.hpp"
4040
#include "logaddexp2.hpp"
41+
#include "nan_to_num.hpp"
4142
#include "radians.hpp"
4243
#include "sinc.hpp"
4344
#include "spacing.hpp"
@@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
6465
init_lcm(m);
6566
init_ldexp(m);
6667
init_logaddexp2(m);
68+
init_nan_to_num(m);
6769
init_radians(m);
6870
init_sinc(m);
6971
init_spacing(m);
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <stdexcept>
27+
28+
#include <sycl/sycl.hpp>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/numpy.h>
32+
#include <pybind11/pybind11.h>
33+
#include <pybind11/stl.h>
34+
35+
#include "kernels/elementwise_functions/nan_to_num.hpp"
36+
37+
#include "../../elementwise_functions/simplify_iteration_space.hpp"
38+
39+
// dpctl tensor headers
40+
#include "utils/memory_overlap.hpp"
41+
#include "utils/offset_utils.hpp"
42+
#include "utils/output_validation.hpp"
43+
#include "utils/sycl_alloc_utils.hpp"
44+
#include "utils/type_dispatch.hpp"
45+
#include "utils/type_utils.hpp"
46+
47+
namespace py = pybind11;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
50+
// declare pybind11 wrappers in py_internal namespace
51+
namespace dpnp::extensions::ufunc
52+
{
53+
54+
namespace impl
55+
{
56+
typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &,
57+
int,
58+
size_t,
59+
py::ssize_t *,
60+
const py::object &,
61+
const py::object &,
62+
const py::object &,
63+
const char *,
64+
py::ssize_t,
65+
char *,
66+
py::ssize_t,
67+
const std::vector<sycl::event> &);
68+
69+
template <typename T>
70+
sycl::event nan_to_num_call(sycl::queue &exec_q,
71+
int nd,
72+
size_t nelems,
73+
py::ssize_t *shape_strides,
74+
const py::object &py_nan,
75+
const py::object &py_posinf,
76+
const py::object &py_neginf,
77+
const char *arg_p,
78+
py::ssize_t arg_offset,
79+
char *dst_p,
80+
py::ssize_t dst_offset,
81+
const std::vector<sycl::event> &depends)
82+
{
83+
sycl::event to_num_ev;
84+
85+
using dpctl::tensor::type_utils::is_complex;
86+
if constexpr (is_complex<T>::value) {
87+
using realT = typename T::value_type;
88+
realT nan_v = py::cast<realT>(py_nan);
89+
realT posinf_v = py::cast<realT>(py_posinf);
90+
realT neginf_v = py::cast<realT>(py_neginf);
91+
92+
using dpnp::kernels::nan_to_num::nan_to_num_impl;
93+
to_num_ev = nan_to_num_impl<T, realT>(
94+
exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
95+
arg_offset, dst_p, dst_offset, depends);
96+
}
97+
else {
98+
T nan_v = py::cast<T>(py_nan);
99+
T posinf_v = py::cast<T>(py_posinf);
100+
T neginf_v = py::cast<T>(py_neginf);
101+
102+
using dpnp::kernels::nan_to_num::nan_to_num_impl;
103+
to_num_ev = nan_to_num_impl<T, T>(
104+
exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
105+
arg_offset, dst_p, dst_offset, depends);
106+
}
107+
return to_num_ev;
108+
}
109+
110+
namespace td_ns = dpctl::tensor::type_dispatch;
111+
nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types];
112+
113+
std::pair<sycl::event, sycl::event>
114+
py_nan_to_num(const dpctl::tensor::usm_ndarray &src,
115+
const py::object &py_nan,
116+
const py::object &py_posinf,
117+
const py::object &py_neginf,
118+
const dpctl::tensor::usm_ndarray &dst,
119+
sycl::queue &q,
120+
const std::vector<sycl::event> &depends)
121+
{
122+
int src_typenum = src.get_typenum();
123+
int dst_typenum = dst.get_typenum();
124+
125+
const auto &array_types = td_ns::usm_ndarray_types();
126+
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
127+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
128+
129+
if (src_typeid != dst_typeid) {
130+
throw py::value_error("Array data types are not the same.");
131+
}
132+
133+
if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
134+
throw py::value_error(
135+
"Execution queue is not compatible with allocation queues");
136+
}
137+
138+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
139+
140+
int src_nd = src.get_ndim();
141+
if (src_nd != dst.get_ndim()) {
142+
throw py::value_error("Array dimensions are not the same.");
143+
}
144+
145+
const py::ssize_t *src_shape = src.get_shape_raw();
146+
const py::ssize_t *dst_shape = dst.get_shape_raw();
147+
148+
bool shapes_equal(true);
149+
size_t nelems(1);
150+
for (int i = 0; i < src_nd; ++i) {
151+
nelems *= static_cast<size_t>(src_shape[i]);
152+
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
153+
}
154+
if (!shapes_equal) {
155+
throw py::value_error("Array shapes are not the same.");
156+
}
157+
158+
// if nelems is zero, return
159+
if (nelems == 0) {
160+
return std::make_pair(sycl::event(), sycl::event());
161+
}
162+
163+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems);
164+
165+
// check memory overlap
166+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
167+
auto const &same_logical_tensors =
168+
dpctl::tensor::overlap::SameLogicalTensors();
169+
if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
170+
throw py::value_error("Arrays index overlapping segments of memory");
171+
}
172+
173+
const char *src_data = src.get_data();
174+
char *dst_data = dst.get_data();
175+
176+
auto const &src_strides = src.get_strides_vector();
177+
auto const &dst_strides = dst.get_strides_vector();
178+
179+
using shT = std::vector<py::ssize_t>;
180+
shT simplified_shape;
181+
shT simplified_src_strides;
182+
shT simplified_dst_strides;
183+
py::ssize_t src_offset(0);
184+
py::ssize_t dst_offset(0);
185+
186+
int nd = src_nd;
187+
const py::ssize_t *shape = src_shape;
188+
189+
py_internal::simplify_iteration_space(
190+
nd, shape, src_strides, dst_strides,
191+
// output
192+
simplified_shape, simplified_src_strides, simplified_dst_strides,
193+
src_offset, dst_offset);
194+
195+
auto fn = nan_to_num_dispatch_vector[src_typeid];
196+
197+
if (fn == nullptr) {
198+
throw std::runtime_error(
199+
"nan_to_num implementation is missing for src_typeid=" +
200+
std::to_string(src_typeid));
201+
}
202+
203+
using dpctl::tensor::offset_utils::device_allocate_and_pack;
204+
205+
std::vector<sycl::event> host_tasks{};
206+
host_tasks.reserve(2);
207+
208+
const auto &ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
209+
q, host_tasks, simplified_shape, simplified_src_strides,
210+
simplified_dst_strides);
211+
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_triple_);
212+
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
213+
214+
if (shape_strides == nullptr) {
215+
throw std::runtime_error("Device memory allocation failed");
216+
}
217+
218+
std::vector<sycl::event> all_deps;
219+
all_deps.reserve(depends.size() + 1);
220+
all_deps.insert(all_deps.end(), depends.begin(), depends.end());
221+
all_deps.push_back(copy_shape_ev);
222+
223+
sycl::event comp_ev =
224+
fn(q, nelems, nd, shape_strides, py_nan, py_posinf, py_neginf, src_data,
225+
src_offset, dst_data, dst_offset, all_deps);
226+
227+
// async free of shape_strides temporary
228+
auto ctx = q.get_context();
229+
sycl::event tmp_cleanup_ev = q.submit([&](sycl::handler &cgh) {
230+
cgh.depends_on(comp_ev);
231+
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
232+
cgh.host_task(
233+
[ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); });
234+
});
235+
host_tasks.push_back(tmp_cleanup_ev);
236+
237+
return std::make_pair(
238+
dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks), comp_ev);
239+
}
240+
241+
namespace py_int = dpnp::extensions::py_internal;
242+
243+
/**
244+
* @brief A factory to define pairs of supported types for which
245+
* nan_to_num_call<T> function is available.
246+
*
247+
* @tparam T Type of input vector `a` and of result vector `y`.
248+
*/
249+
template <typename T>
250+
struct NanToNumOutputType
251+
{
252+
using value_type = typename std::disjunction<
253+
td_ns::TypeMapResultEntry<T, sycl::half>,
254+
td_ns::TypeMapResultEntry<T, float>,
255+
td_ns::TypeMapResultEntry<T, double>,
256+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
257+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
258+
td_ns::DefaultResultEntry<void>>::result_type;
259+
};
260+
261+
template <typename fnT, typename T>
262+
struct NanToNumFactory
263+
{
264+
fnT get()
265+
{
266+
if constexpr (std::is_same_v<typename NanToNumOutputType<T>::value_type,
267+
void>) {
268+
return nullptr;
269+
}
270+
else {
271+
using ::dpnp::extensions::ufunc::impl::nan_to_num_call;
272+
return nan_to_num_call<T>;
273+
}
274+
}
275+
};
276+
277+
void populate_nan_to_num_dispatch_vector(void)
278+
{
279+
using namespace td_ns;
280+
281+
DispatchVectorBuilder<nan_to_num_fn_ptr_t, NanToNumFactory, num_types> dvb;
282+
dvb.populate_dispatch_vector(nan_to_num_dispatch_vector);
283+
}
284+
285+
} // namespace impl
286+
287+
void init_nan_to_num(py::module_ m)
288+
{
289+
{
290+
impl::populate_nan_to_num_dispatch_vector();
291+
292+
using impl::py_nan_to_num;
293+
m.def("_nan_to_num", &py_nan_to_num, "", py::arg("src"),
294+
py::arg("py_nan"), py::arg("py_posinf"), py::arg("py_neginf"),
295+
py::arg("dst"), py::arg("sycl_queue"),
296+
py::arg("depends") = py::list());
297+
}
298+
}
299+
300+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_nan_to_num(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc

0 commit comments

Comments
 (0)