Skip to content

Commit 62e2bfd

Browse files
committed
Add modf implementation to ufunc extension
1 parent 552c242 commit 62e2bfd

File tree

4 files changed

+299
-0
lines changed

4 files changed

+299
-0
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ set(_elementwise_sources
4747
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
4848
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
4949
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
50+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/modf.cpp
5051
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/nan_to_num.cpp
5152
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp
5253
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/sinc.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "lcm.hpp"
4747
#include "ldexp.hpp"
4848
#include "logaddexp2.hpp"
49+
#include "modf.hpp"
4950
#include "nan_to_num.hpp"
5051
#include "radians.hpp"
5152
#include "sinc.hpp"
@@ -78,6 +79,7 @@ void init_elementwise_functions(py::module_ m)
7879
init_lcm(m);
7980
init_ldexp(m);
8081
init_logaddexp2(m);
82+
init_modf(m);
8183
init_nan_to_num(m);
8284
init_radians(m);
8385
init_sinc(m);
dpnp/backend/extensions/ufunc/elementwise_functions/modf.cpp

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#include <cstdint>
30+
#include <type_traits>
31+
#include <utility>
32+
#include <vector>
33+
34+
#include <sycl/sycl.hpp>
35+
36+
#include "dpctl4pybind11.hpp"
37+
38+
#include "kernels/elementwise_functions/modf.hpp"
39+
#include "modf.hpp"
40+
41+
// include a local copy of elementwise common header from dpctl tensor:
42+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
43+
// TODO: replace by including dpctl header once available
44+
#include "../../elementwise_functions/elementwise_functions.hpp"
45+
46+
#include "../../elementwise_functions/common.hpp"
47+
#include "../../elementwise_functions/type_dispatch_building.hpp"
48+
49+
// utils extension header
50+
#include "ext/common.hpp"
51+
52+
// dpctl tensor headers
53+
#include "kernels/elementwise_functions/common.hpp"
54+
#include "utils/type_dispatch.hpp"
55+
56+
namespace dpnp::extensions::ufunc
57+
{
58+
namespace py = pybind11;
59+
namespace py_int = dpnp::extensions::py_internal;
60+
61+
namespace impl
62+
{
63+
namespace ew_cmn_ns = dpnp::extensions::py_internal::elementwise_common;
64+
namespace td_int_ns = py_int::type_dispatch;
65+
namespace td_ns = dpctl::tensor::type_dispatch;
66+
67+
using dpnp::kernels::modf::ModfFunctor;
68+
using ext::common::init_dispatch_vector;
69+
70+
template <typename T>
71+
struct ModfOutputType
72+
{
73+
using table_type = std::disjunction< // disjunction is C++17
74+
// feature, supported by DPC++
75+
td_int_ns::
76+
TypeMapTwoResultsEntry<T, sycl::half, sycl::half, sycl::half>,
77+
td_int_ns::TypeMapTwoResultsEntry<T, float, float, float>,
78+
td_int_ns::TypeMapTwoResultsEntry<T, double, double, double>,
79+
td_int_ns::DefaultTwoResultsEntry<void>>;
80+
using value_type1 = typename table_type::result_type1;
81+
using value_type2 = typename table_type::result_type2;
82+
};
83+
84+
// contiguous implementation
85+
86+
template <typename argTy,
87+
typename resTy1 = argTy,
88+
typename resTy2 = argTy,
89+
std::uint8_t vec_sz = 4u,
90+
std::uint8_t n_vecs = 2u,
91+
bool enable_sg_loadstore = true>
92+
using ModfContigFunctor =
93+
ew_cmn_ns::UnaryTwoOutputsContigFunctor<argTy,
94+
resTy1,
95+
resTy2,
96+
ModfFunctor<argTy, resTy1, resTy2>,
97+
vec_sz,
98+
n_vecs,
99+
enable_sg_loadstore>;
100+
101+
// strided implementation
102+
103+
template <typename argTy, typename resTy1, typename resTy2, typename IndexerT>
104+
using ModfStridedFunctor = ew_cmn_ns::UnaryTwoOutputsStridedFunctor<
105+
argTy,
106+
resTy1,
107+
resTy2,
108+
IndexerT,
109+
ModfFunctor<argTy, resTy1, resTy2>>;
110+
111+
template <typename T1,
112+
typename T2,
113+
typename T3,
114+
unsigned int vec_sz,
115+
unsigned int n_vecs>
116+
class modf_contig_kernel;
117+
118+
template <typename argTy>
119+
sycl::event modf_contig_impl(sycl::queue &exec_q,
120+
size_t nelems,
121+
const char *arg_p,
122+
char *res1_p,
123+
char *res2_p,
124+
const std::vector<sycl::event> &depends = {})
125+
{
126+
return ew_cmn_ns::unary_two_outputs_contig_impl<
127+
argTy, ModfOutputType, ModfContigFunctor, modf_contig_kernel>(
128+
exec_q, nelems, arg_p, res1_p, res2_p, depends);
129+
}
130+
131+
template <typename fnT, typename T>
132+
struct ModfContigFactory
133+
{
134+
fnT get()
135+
{
136+
if constexpr (std::is_same_v<typename ModfOutputType<T>::value_type1,
137+
void> ||
138+
std::is_same_v<typename ModfOutputType<T>::value_type2,
139+
void>)
140+
{
141+
fnT fn = nullptr;
142+
return fn;
143+
}
144+
else {
145+
fnT fn = modf_contig_impl<T>;
146+
return fn;
147+
}
148+
}
149+
};
150+
151+
template <typename T1, typename T2, typename T3, typename T4>
152+
class modf_strided_kernel;
153+
154+
template <typename argTy>
155+
sycl::event
156+
modf_strided_impl(sycl::queue &exec_q,
157+
size_t nelems,
158+
int nd,
159+
const ssize_t *shape_and_strides,
160+
const char *arg_p,
161+
ssize_t arg_offset,
162+
char *res1_p,
163+
ssize_t res1_offset,
164+
char *res2_p,
165+
ssize_t res2_offset,
166+
const std::vector<sycl::event> &depends,
167+
const std::vector<sycl::event> &additional_depends)
168+
{
169+
return ew_cmn_ns::unary_two_outputs_strided_impl<
170+
argTy, ModfOutputType, ModfStridedFunctor, modf_strided_kernel>(
171+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p,
172+
res1_offset, res2_p, res2_offset, depends, additional_depends);
173+
}
174+
175+
template <typename fnT, typename T>
176+
struct ModfStridedFactory
177+
{
178+
fnT get()
179+
{
180+
if constexpr (std::is_same_v<typename ModfOutputType<T>::value_type1,
181+
void> ||
182+
std::is_same_v<typename ModfOutputType<T>::value_type2,
183+
void>)
184+
{
185+
fnT fn = nullptr;
186+
return fn;
187+
}
188+
else {
189+
fnT fn = modf_strided_impl<T>;
190+
return fn;
191+
}
192+
}
193+
};
194+
195+
template <typename fnT, typename T>
196+
struct ModfTypeMapFactory
197+
{
198+
/*! @brief get typeid for output type of sycl::modf(T x) */
199+
std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,
200+
std::pair<int, int>>
201+
get()
202+
{
203+
using rT1 = typename ModfOutputType<T>::value_type1;
204+
using rT2 = typename ModfOutputType<T>::value_type2;
205+
return std::make_pair(td_ns::GetTypeid<rT1>{}.get(),
206+
td_ns::GetTypeid<rT2>{}.get());
207+
}
208+
};
209+
210+
using ew_cmn_ns::unary_two_outputs_contig_impl_fn_ptr_t;
211+
using ew_cmn_ns::unary_two_outputs_strided_impl_fn_ptr_t;
212+
213+
static unary_two_outputs_contig_impl_fn_ptr_t
214+
modf_contig_dispatch_vector[td_ns::num_types];
215+
static std::pair<int, int> modf_output_typeid_vector[td_ns::num_types];
216+
static unary_two_outputs_strided_impl_fn_ptr_t
217+
modf_strided_dispatch_vector[td_ns::num_types];
218+
219+
void populate_modf_dispatch_vectors(void)
220+
{
221+
init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t,
222+
ModfContigFactory>(modf_contig_dispatch_vector);
223+
init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t,
224+
ModfStridedFactory>(modf_strided_dispatch_vector);
225+
init_dispatch_vector<std::pair<int, int>, ModfTypeMapFactory>(
226+
modf_output_typeid_vector);
227+
};
228+
} // namespace impl
229+
230+
/**
 * @brief Registers the modf bindings in the Python module `m`.
 *
 * Populates the modf dispatch tables and then defines two module-level
 * functions:
 *  - `_modf(src, dst1, dst2, sycl_queue, depends=[])`, which forwards to
 *    py_int::py_unary_two_outputs_ufunc together with the modf tables;
 *  - `_modf_result_type(dtype)`, which forwards to
 *    py_int::py_unary_two_outputs_ufunc_result_type.
 *
 * @param m pybind11 module receiving the definitions.
 */
void init_modf(py::module_ m)
{
    using arrayT = dpctl::tensor::usm_ndarray;
    using event_vecT = std::vector<sycl::event>;

    impl::populate_modf_dispatch_vectors();
    using impl::modf_contig_dispatch_vector;
    using impl::modf_output_typeid_vector;
    using impl::modf_strided_dispatch_vector;

    // The callables below are stored by pybind11 and outlive this function,
    // so they must not capture locals by reference. They only touch the
    // namespace-scope tables (static storage duration), which need no
    // capture at all.
    auto modf_pyapi = [](const arrayT &src, const arrayT &dst1,
                         const arrayT &dst2, sycl::queue &exec_q,
                         const event_vecT &depends = {}) {
        return py_int::py_unary_two_outputs_ufunc(
            src, dst1, dst2, exec_q, depends, modf_output_typeid_vector,
            modf_contig_dispatch_vector, modf_strided_dispatch_vector);
    };
    m.def("_modf", modf_pyapi, "", py::arg("src"), py::arg("dst1"),
          py::arg("dst2"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());

    auto modf_result_type_pyapi = [](const py::dtype &dtype) {
        return py_int::py_unary_two_outputs_ufunc_result_type(
            dtype, modf_output_typeid_vector);
    };
    m.def("_modf_result_type", modf_result_type_pyapi);
}
258+
} // namespace dpnp::extensions::ufunc
dpnp/backend/extensions/ufunc/elementwise_functions/modf.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <pybind11/pybind11.h>
32+
33+
namespace py = pybind11;
34+
35+
namespace dpnp::extensions::ufunc
36+
{
37+
void init_modf(py::module_ m);
38+
} // namespace dpnp::extensions::ufunc

0 commit comments

Comments
 (0)