Skip to content

Commit ee89a6e

Browse files
committed
Add _ldexp to ufunc extension
1 parent 6c945e2 commit ee89a6e

File tree

5 files changed

+240
-0
lines changed

5 files changed

+240
-0
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ set(_elementwise_sources
3535
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
3636
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
38+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
3839
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp
4041
)

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "gcd.hpp"
3636
#include "heaviside.hpp"
3737
#include "lcm.hpp"
38+
#include "ldexp.hpp"
3839
#include "logaddexp2.hpp"
3940
#include "radians.hpp"
4041

@@ -57,6 +58,7 @@ void init_elementwise_functions(py::module_ m)
5758
init_gcd(m);
5859
init_heaviside(m);
5960
init_lcm(m);
61+
init_ldexp(m);
6062
init_logaddexp2(m);
6163
init_radians(m);
6264
}
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// maxification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <sycl/sycl.hpp>
27+
28+
#include "dpctl4pybind11.hpp"
29+
30+
#include "kernels/elementwise_functions/ldexp.hpp"
31+
#include "ldexp.hpp"
32+
#include "populate.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "kernels/elementwise_functions/maximum.hpp"
42+
#include "utils/type_dispatch.hpp"
43+
44+
namespace dpnp::extensions::ufunc
45+
{
46+
namespace py = pybind11;
47+
namespace py_int = dpnp::extensions::py_internal;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
50+
namespace impl
51+
{
52+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
53+
namespace max_ns = dpctl::tensor::kernels::maximum;
54+
55+
// Supports the same types table as for maximum function in dpctl
56+
// template <typename T1, typename T2>
57+
// using OutputType = max_ns::MaximumOutputType<T1, T2>;
58+
template <typename T1, typename T2>
59+
struct OutputType
60+
{
61+
using value_type = typename std::disjunction< // disjunction is C++17
62+
// feature, supported by DPC++
63+
td_ns::BinaryTypeMapResultEntry<T1,
64+
sycl::half,
65+
T2,
66+
std::int32_t,
67+
sycl::half>,
68+
td_ns::BinaryTypeMapResultEntry<T1,
69+
sycl::half,
70+
T2,
71+
std::int64_t,
72+
sycl::half>,
73+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, std::int32_t, float>,
74+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, std::int64_t, float>,
75+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, std::int32_t, double>,
76+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, std::int64_t, double>,
77+
td_ns::DefaultResultEntry<void>>::result_type;
78+
};
79+
80+
using dpnp::kernels::ldexp::LdexpFunctor;
81+
82+
template <typename argT1,
83+
typename argT2,
84+
typename resT,
85+
unsigned int vec_sz = 4,
86+
unsigned int n_vecs = 2,
87+
bool enable_sg_loadstore = true>
88+
using ContigFunctor =
89+
ew_cmn_ns::BinaryContigFunctor<argT1,
90+
argT2,
91+
resT,
92+
LdexpFunctor<argT1, argT2, resT>,
93+
vec_sz,
94+
n_vecs,
95+
enable_sg_loadstore>;
96+
97+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
98+
using StridedFunctor =
99+
ew_cmn_ns::BinaryStridedFunctor<argT1,
100+
argT2,
101+
resT,
102+
IndexerT,
103+
LdexpFunctor<argT1, argT2, resT>>;
104+
105+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
106+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
107+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
108+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
109+
110+
static binary_contig_impl_fn_ptr_t
111+
ldexp_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
112+
static int ldexp_output_typeid_table[td_ns::num_types][td_ns::num_types];
113+
static binary_strided_impl_fn_ptr_t
114+
ldexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
115+
116+
MACRO_POPULATE_DISPATCH_TABLES(ldexp);
117+
} // namespace impl
118+
119+
void init_ldexp(py::module_ m)
120+
{
121+
using arrayT = dpctl::tensor::usm_ndarray;
122+
using event_vecT = std::vector<sycl::event>;
123+
{
124+
impl::populate_ldexp_dispatch_tables();
125+
using impl::ldexp_contig_dispatch_table;
126+
using impl::ldexp_output_typeid_table;
127+
using impl::ldexp_strided_dispatch_table;
128+
129+
auto ldexp_pyapi = [&](const arrayT &src1, const arrayT &src2,
130+
const arrayT &dst, sycl::queue &exec_q,
131+
const event_vecT &depends = {}) {
132+
return py_int::py_binary_ufunc(
133+
src1, src2, dst, exec_q, depends, ldexp_output_typeid_table,
134+
ldexp_contig_dispatch_table, ldexp_strided_dispatch_table,
135+
// no support of C-contig row with broadcasting in OneMKL
136+
td_ns::NullPtrTable<
137+
impl::
138+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
139+
td_ns::NullPtrTable<
140+
impl::
141+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
142+
};
143+
m.def("_ldexp", ldexp_pyapi, "", py::arg("src1"), py::arg("src2"),
144+
py::arg("dst"), py::arg("sycl_queue"),
145+
py::arg("depends") = py::list());
146+
147+
auto ldexp_result_type_pyapi = [&](const py::dtype &dtype1,
148+
const py::dtype &dtype2) {
149+
return py_int::py_binary_ufunc_result_type(
150+
dtype1, dtype2, ldexp_output_typeid_table);
151+
};
152+
m.def("_ldexp_result_type", ldexp_result_type_pyapi);
153+
}
154+
}
155+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_ldexp(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <sycl/sycl.hpp>
29+
30+
// dpctl tensor headers
31+
#include "utils/math_utils.hpp"
32+
#include "utils/type_utils.hpp"
33+
34+
namespace dpnp::kernels::ldexp
35+
{
36+
template <typename argT1, typename argT2, typename resT>
37+
struct LdexpFunctor
38+
{
39+
using supports_sg_loadstore = typename std::true_type;
40+
using supports_vec = typename std::false_type;
41+
42+
resT operator()(const argT1 &in1, const argT2 &in2) const
43+
{
44+
return sycl::ldexp(in1, in2);
45+
}
46+
};
47+
} // namespace dpnp::kernels::ldexp

0 commit comments

Comments
 (0)