Skip to content

Commit 7a9de99

Browse files
committed
Add frexp implementation to ufunc extension
1 parent 59bd1c9 commit 7a9de99

File tree

5 files changed

+355
-0
lines changed

5 files changed

+355
-0
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ set(_elementwise_sources
3838
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmax.cpp
3939
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmin.cpp
4040
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmod.cpp
41+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/frexp.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
4344
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "fmax.hpp"
3838
#include "fmin.hpp"
3939
#include "fmod.hpp"
40+
#include "frexp.hpp"
4041
#include "gcd.hpp"
4142
#include "heaviside.hpp"
4243
#include "i0.hpp"
@@ -68,6 +69,7 @@ void init_elementwise_functions(py::module_ m)
6869
init_fmax(m);
6970
init_fmin(m);
7071
init_fmod(m);
72+
init_frexp(m);
7173
init_gcd(m);
7274
init_heaviside(m);
7375
init_i0(m);
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// maxification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#include <cstdint>
30+
#include <type_traits>
31+
#include <utility>
32+
#include <vector>
33+
34+
#include <sycl/sycl.hpp>
35+
36+
#include "dpctl4pybind11.hpp"
37+
38+
#include "frexp.hpp"
39+
#include "kernels/elementwise_functions/frexp.hpp"
40+
#include "populate.hpp"
41+
42+
// include a local copy of elementwise common header from dpctl tensor:
43+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
44+
// TODO: replace by including dpctl header once available
45+
#include "../../elementwise_functions/elementwise_functions.hpp"
46+
47+
#include "../../elementwise_functions/common.hpp"
48+
#include "../../elementwise_functions/type_dispatch_building.hpp"
49+
50+
// utils extension header
51+
#include "ext/common.hpp"
52+
53+
// dpctl tensor headers
54+
#include "kernels/elementwise_functions/common.hpp"
55+
#include "utils/type_dispatch.hpp"
56+
57+
namespace dpnp::extensions::ufunc
58+
{
59+
namespace py = pybind11;
60+
namespace py_int = dpnp::extensions::py_internal;
61+
62+
namespace impl
63+
{
64+
namespace ew_cmn_ns = dpnp::extensions::py_internal::elementwise_common;
65+
namespace td_int_ns = py_int::type_dispatch;
66+
namespace td_ns = dpctl::tensor::type_dispatch;
67+
68+
using dpnp::kernels::frexp::FrexpFunctor;
69+
using ext::common::init_dispatch_vector;
70+
71+
template <typename T>
72+
struct FrexpOutputType
73+
{
74+
using table_type = std::disjunction< // disjunction is C++17
75+
// feature, supported by DPC++
76+
td_int_ns::
77+
TypeMapTwoResultsEntry<T, sycl::half, sycl::half, std::int32_t>,
78+
td_int_ns::TypeMapTwoResultsEntry<T, float, float, std::int32_t>,
79+
td_int_ns::TypeMapTwoResultsEntry<T, double, double, std::int32_t>,
80+
td_int_ns::DefaultTwoResultsEntry<void>>;
81+
using value_type1 = typename table_type::result_type1;
82+
using value_type2 = typename table_type::result_type2;
83+
};
84+
85+
// contiguous implementation
86+
87+
template <typename argTy,
88+
typename resTy1 = argTy,
89+
typename resTy2 = argTy,
90+
std::uint8_t vec_sz = 4u,
91+
std::uint8_t n_vecs = 2u,
92+
bool enable_sg_loadstore = true>
93+
using FrexpContigFunctor =
94+
ew_cmn_ns::UnaryTwoOutputsContigFunctor<argTy,
95+
resTy1,
96+
resTy2,
97+
FrexpFunctor<argTy, resTy1, resTy2>,
98+
vec_sz,
99+
n_vecs,
100+
enable_sg_loadstore>;
101+
102+
// strided implementation
103+
104+
template <typename argTy, typename resTy1, typename resTy2, typename IndexerT>
105+
using FrexpStridedFunctor = ew_cmn_ns::UnaryTwoOutputsStridedFunctor<
106+
argTy,
107+
resTy1,
108+
resTy2,
109+
IndexerT,
110+
FrexpFunctor<argTy, resTy1, resTy2>>;
111+
112+
template <typename T1,
113+
typename T2,
114+
typename T3,
115+
unsigned int vec_sz,
116+
unsigned int n_vecs>
117+
class frexp_contig_kernel;
118+
119+
template <typename argTy>
120+
sycl::event frexp_contig_impl(sycl::queue &exec_q,
121+
size_t nelems,
122+
const char *arg_p,
123+
char *res1_p,
124+
char *res2_p,
125+
const std::vector<sycl::event> &depends = {})
126+
{
127+
return ew_cmn_ns::unary_two_outputs_contig_impl<
128+
argTy, FrexpOutputType, FrexpContigFunctor, frexp_contig_kernel>(
129+
exec_q, nelems, arg_p, res1_p, res2_p, depends);
130+
}
131+
132+
template <typename fnT, typename T>
133+
struct FrexpContigFactory
134+
{
135+
fnT get()
136+
{
137+
if constexpr (std::is_same_v<typename FrexpOutputType<T>::value_type1,
138+
void> ||
139+
std::is_same_v<typename FrexpOutputType<T>::value_type2,
140+
void>)
141+
{
142+
fnT fn = nullptr;
143+
return fn;
144+
}
145+
else {
146+
fnT fn = frexp_contig_impl<T>;
147+
return fn;
148+
}
149+
}
150+
};
151+
152+
template <typename T1, typename T2, typename T3, typename T4>
153+
class frexp_strided_kernel;
154+
155+
template <typename argTy>
156+
sycl::event
157+
frexp_strided_impl(sycl::queue &exec_q,
158+
size_t nelems,
159+
int nd,
160+
const ssize_t *shape_and_strides,
161+
const char *arg_p,
162+
ssize_t arg_offset,
163+
char *res1_p,
164+
ssize_t res1_offset,
165+
char *res2_p,
166+
ssize_t res2_offset,
167+
const std::vector<sycl::event> &depends,
168+
const std::vector<sycl::event> &additional_depends)
169+
{
170+
return ew_cmn_ns::unary_two_outputs_strided_impl<
171+
argTy, FrexpOutputType, FrexpStridedFunctor, frexp_strided_kernel>(
172+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p,
173+
res1_offset, res2_p, res2_offset, depends, additional_depends);
174+
}
175+
176+
template <typename fnT, typename T>
177+
struct FrexpStridedFactory
178+
{
179+
fnT get()
180+
{
181+
if constexpr (std::is_same_v<typename FrexpOutputType<T>::value_type1,
182+
void> ||
183+
std::is_same_v<typename FrexpOutputType<T>::value_type2,
184+
void>)
185+
{
186+
fnT fn = nullptr;
187+
return fn;
188+
}
189+
else {
190+
fnT fn = frexp_strided_impl<T>;
191+
return fn;
192+
}
193+
}
194+
};
195+
196+
template <typename fnT, typename T>
197+
struct FrexpTypeMapFactory
198+
{
199+
/*! @brief get typeid for output type of sycl::frexp(T x) */
200+
std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,
201+
std::pair<int, int>>
202+
get()
203+
{
204+
using rT1 = typename FrexpOutputType<T>::value_type1;
205+
using rT2 = typename FrexpOutputType<T>::value_type2;
206+
return std::make_pair(td_ns::GetTypeid<rT1>{}.get(),
207+
td_ns::GetTypeid<rT2>{}.get());
208+
}
209+
};
210+
211+
using ew_cmn_ns::unary_two_outputs_contig_impl_fn_ptr_t;
212+
using ew_cmn_ns::unary_two_outputs_strided_impl_fn_ptr_t;
213+
214+
static unary_two_outputs_contig_impl_fn_ptr_t
215+
frexp_contig_dispatch_vector[td_ns::num_types];
216+
static std::pair<int, int> frexp_output_typeid_vector[td_ns::num_types];
217+
static unary_two_outputs_strided_impl_fn_ptr_t
218+
frexp_strided_dispatch_vector[td_ns::num_types];
219+
220+
void populate_frexp_dispatch_vectors(void)
221+
{
222+
init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t,
223+
FrexpContigFactory>(frexp_contig_dispatch_vector);
224+
init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t,
225+
FrexpStridedFactory>(frexp_strided_dispatch_vector);
226+
init_dispatch_vector<std::pair<int, int>, FrexpTypeMapFactory>(
227+
frexp_output_typeid_vector);
228+
};
229+
230+
// MACRO_POPULATE_DISPATCH_TABLES(ldexp);
231+
} // namespace impl
232+
233+
void init_frexp(py::module_ m)
234+
{
235+
using arrayT = dpctl::tensor::usm_ndarray;
236+
using event_vecT = std::vector<sycl::event>;
237+
{
238+
impl::populate_frexp_dispatch_vectors();
239+
using impl::frexp_contig_dispatch_vector;
240+
using impl::frexp_output_typeid_vector;
241+
using impl::frexp_strided_dispatch_vector;
242+
243+
auto frexp_pyapi = [&](const arrayT &src, const arrayT &dst1,
244+
const arrayT &dst2, sycl::queue &exec_q,
245+
const event_vecT &depends = {}) {
246+
return py_int::py_unary_two_outputs_ufunc(
247+
src, dst1, dst2, exec_q, depends, frexp_output_typeid_vector,
248+
frexp_contig_dispatch_vector, frexp_strided_dispatch_vector);
249+
};
250+
m.def("_frexp", frexp_pyapi, "", py::arg("src"), py::arg("dst1"),
251+
py::arg("dst2"), py::arg("sycl_queue"),
252+
py::arg("depends") = py::list());
253+
254+
auto frexp_result_type_pyapi = [&](const py::dtype &dtype) {
255+
return py_int::py_unary_two_outputs_ufunc_result_type(
256+
dtype, frexp_output_typeid_vector);
257+
};
258+
m.def("_frexp_result_type", frexp_result_type_pyapi);
259+
}
260+
}
261+
} // namespace dpnp::extensions::ufunc
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <pybind11/pybind11.h>
32+
33+
namespace py = pybind11;
34+
35+
namespace dpnp::extensions::ufunc
36+
{
37+
void init_frexp(py::module_ m);
38+
} // namespace dpnp::extensions::ufunc
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <sycl/sycl.hpp>
32+
33+
namespace dpnp::kernels::frexp
34+
{
35+
template <typename argT, typename mantT, typename expT>
36+
struct FrexpFunctor
37+
{
38+
// is function constant for given argT
39+
using is_constant = typename std::false_type;
40+
// constant value, if constant
41+
// constexpr resT1 constant_value1 = resT1{};
42+
// constexpr resT2 constant_value2 = resT2{};
43+
// is function defined for sycl::vec
44+
using supports_vec = typename std::false_type;
45+
// do both argT and mantT, expT support subgroup store/load operation
46+
using supports_sg_loadstore = typename std::true_type;
47+
48+
mantT operator()(const argT &in, expT &exp) const
49+
{
50+
return sycl::frexp(in, &exp);
51+
}
52+
};
53+
} // namespace dpnp::kernels::frexp

0 commit comments

Comments
 (0)