Skip to content

Commit 9cb9c4a

Browse files
committed
Add modf implementation to VM extension
1 parent 8db6cfc commit 9cb9c4a

File tree

6 files changed

+361
-4
lines changed

6 files changed

+361
-4
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/modf.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,9 @@ struct OutputType
7272
{
7373
using table_type = std::disjunction< // disjunction is C++17
7474
// feature, supported by DPC++
75-
td_int_ns::
76-
TypeMapTwoResultsEntry<T, sycl::half, sycl::half, sycl::half>,
77-
td_int_ns::TypeMapTwoResultsEntry<T, float, float, float>,
78-
td_int_ns::TypeMapTwoResultsEntry<T, double, double, double>,
75+
td_int_ns::TypeMapTwoResultsEntry<T, sycl::half>,
76+
td_int_ns::TypeMapTwoResultsEntry<T, float>,
77+
td_int_ns::TypeMapTwoResultsEntry<T, double>,
7978
td_int_ns::DefaultTwoResultsEntry<void>>;
8079
using value_type1 = typename table_type::result_type1;
8180
using value_type2 = typename table_type::result_type2;

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ if(NOT _use_onemath)
6161
${CMAKE_CURRENT_SOURCE_DIR}/log10.cpp
6262
${CMAKE_CURRENT_SOURCE_DIR}/log1p.cpp
6363
${CMAKE_CURRENT_SOURCE_DIR}/log2.cpp
64+
${CMAKE_CURRENT_SOURCE_DIR}/modf.cpp
6465
${CMAKE_CURRENT_SOURCE_DIR}/mul.cpp
6566
${CMAKE_CURRENT_SOURCE_DIR}/nextafter.cpp
6667
${CMAKE_CURRENT_SOURCE_DIR}/pow.cpp

dpnp/backend/extensions/vm/common.hpp

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,115 @@ bool need_to_call_unary_ufunc(sycl::queue &exec_q,
155155
return true;
156156
}
157157

158+
template <typename output_typesT, typename contig_dispatchT>
159+
bool need_to_call_unary_two_outputs_ufunc(
160+
sycl::queue &exec_q,
161+
const dpctl::tensor::usm_ndarray &src,
162+
const dpctl::tensor::usm_ndarray &dst1,
163+
const dpctl::tensor::usm_ndarray &dst2,
164+
const output_typesT &output_type_vec,
165+
const contig_dispatchT &contig_dispatch_vector)
166+
{
167+
// check type_nums
168+
int src_typenum = src.get_typenum();
169+
int dst1_typenum = dst1.get_typenum();
170+
int dst2_typenum = dst2.get_typenum();
171+
172+
const auto &array_types = td_ns::usm_ndarray_types();
173+
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
174+
int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
175+
int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);
176+
177+
std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];
178+
179+
// check that types are supported
180+
if (dst1_typeid != func_output_typeids.first ||
181+
dst2_typeid != func_output_typeids.second)
182+
{
183+
return false;
184+
}
185+
186+
// OneMKL VM functions perform a copy on host if no double type support
187+
if (!exec_q.get_device().has(sycl::aspect::fp64)) {
188+
return false;
189+
}
190+
191+
// check that queues are compatible
192+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst1, dst2})) {
193+
return false;
194+
}
195+
196+
// dimensions must be the same
197+
int src_nd = src.get_ndim();
198+
int dst1_nd = dst1.get_ndim();
199+
int dst2_nd = dst2.get_ndim();
200+
if (src_nd != dst1_nd || src_nd != dst2_nd) {
201+
return false;
202+
}
203+
else if (dst1_nd == 0 || dst2_nd == 0) {
204+
// don't call OneMKL for 0d arrays
205+
return false;
206+
}
207+
208+
// shapes must be the same
209+
const py::ssize_t *src_shape = src.get_shape_raw();
210+
const py::ssize_t *dst1_shape = dst1.get_shape_raw();
211+
const py::ssize_t *dst2_shape = dst2.get_shape_raw();
212+
bool shapes_equal(true);
213+
size_t src_nelems(1);
214+
215+
for (int i = 0; i < src_nd; ++i) {
216+
src_nelems *= static_cast<std::size_t>(src_shape[i]);
217+
shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
218+
(src_shape[i] == dst2_shape[i]);
219+
}
220+
if (!shapes_equal) {
221+
return false;
222+
}
223+
224+
// if nelems is zero, return false
225+
if (src_nelems == 0) {
226+
return false;
227+
}
228+
229+
// ensure that outputs are ample enough to accommodate all elements
230+
auto dst1_offsets = dst1.get_minmax_offsets();
231+
auto dst2_offsets = dst2.get_minmax_offsets();
232+
// destinations must be ample enough to accommodate all elements
233+
{
234+
size_t range1 =
235+
static_cast<size_t>(dst1_offsets.second - dst1_offsets.first);
236+
size_t range2 =
237+
static_cast<size_t>(dst2_offsets.second - dst2_offsets.first);
238+
if ((range1 + 1 < src_nelems) || (range2 + 1 < src_nelems)) {
239+
return false;
240+
}
241+
}
242+
243+
// check memory overlap
244+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
245+
if (overlap(src, dst1) || overlap(src, dst2) || overlap(dst1, dst2)) {
246+
return false;
247+
}
248+
249+
// support only contiguous inputs
250+
bool is_src_c_contig = src.is_c_contiguous();
251+
bool is_dst1_c_contig = dst1.is_c_contiguous();
252+
bool is_dst2_c_contig = dst2.is_c_contiguous();
253+
254+
bool all_c_contig =
255+
(is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
256+
if (!all_c_contig) {
257+
return false;
258+
}
259+
260+
// MKL function is not defined for the type
261+
if (contig_dispatch_vector[src_typeid] == nullptr) {
262+
return false;
263+
}
264+
return true;
265+
}
266+
158267
template <typename output_typesT, typename contig_dispatchT>
159268
bool need_to_call_binary_ufunc(sycl::queue &exec_q,
160269
const dpctl::tensor::usm_ndarray &src1,
@@ -299,6 +408,54 @@ bool need_to_call_binary_ufunc(sycl::queue &exec_q,
299408
ContigFactory>(contig_dispatch_vector); \
300409
};
301410

411+
/**
 * @brief A macro used to define factories and a populating function for a
 * unary function with two output arrays, dispatching to a callback with the
 * proper OneMKL function within VM extension scope.
 *
 * Expands to a `ContigFactory` (contiguous implementation, or `nullptr` when
 * either result type is unsupported), a `TypeMapFactory` (pair of result type
 * ids per input type), and `populate_dispatch_vectors()` which fills
 * `output_typeid_vector` and `contig_dispatch_vector`.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(__name__)                        \
    template <typename fnT, typename T>                                        \
    struct ContigFactory                                                       \
    {                                                                          \
        fnT get()                                                              \
        {                                                                      \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type1,  \
                                         void> ||                              \
                          std::is_same_v<typename OutputType<T>::value_type2,  \
                                         void>)                                \
            {                                                                  \
                return nullptr;                                                \
            }                                                                  \
            else {                                                             \
                return __name__##_contig_impl<T>;                              \
            }                                                                  \
        }                                                                      \
    };                                                                         \
                                                                               \
    template <typename fnT, typename T>                                        \
    struct TypeMapFactory                                                      \
    {                                                                          \
        std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,        \
                         std::pair<int, int>>                                  \
        get()                                                                  \
        {                                                                      \
            using rT1 = typename OutputType<T>::value_type1;                   \
            using rT2 = typename OutputType<T>::value_type2;                   \
            return std::make_pair(td_ns::GetTypeid<rT1>{}.get(),               \
                                  td_ns::GetTypeid<rT2>{}.get());              \
        }                                                                      \
    };                                                                         \
                                                                               \
    static void populate_dispatch_vectors()                                    \
    {                                                                          \
        ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>(     \
            output_typeid_vector);                                             \
        ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t,   \
                                     ContigFactory>(contig_dispatch_vector);   \
    };
458+
302459
/**
303460
* @brief A macro used to define factories and a populating binary functions
304461
* to dispatch to a callback with proper OneMKL function within VM extension
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#include <complex>
30+
#include <cstddef>
31+
#include <cstdint>
32+
#include <type_traits>
33+
#include <vector>
34+
35+
#include <oneapi/mkl.hpp>
36+
#include <sycl/sycl.hpp>
37+
38+
#include "dpctl4pybind11.hpp"
39+
40+
#include "common.hpp"
41+
#include "modf.hpp"
42+
43+
// include a local copy of elementwise common header from dpctl tensor:
44+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
45+
// TODO: replace by including dpctl header once available
46+
#include "../elementwise_functions/elementwise_functions.hpp"
47+
48+
#include "../elementwise_functions/common.hpp"
49+
#include "../elementwise_functions/type_dispatch_building.hpp"
50+
51+
// dpctl tensor headers
52+
#include "utils/type_dispatch.hpp"
53+
#include "utils/type_utils.hpp"
54+
55+
namespace dpnp::extensions::vm
56+
{
57+
namespace py = pybind11;
58+
namespace py_int = dpnp::extensions::py_internal;
59+
namespace td_ns = dpctl::tensor::type_dispatch;
60+
61+
namespace impl
62+
{
63+
namespace ew_cmn_ns = dpnp::extensions::py_internal::elementwise_common;
64+
namespace mkl_vm = oneapi::mkl::vm; // OneMKL namespace with VM functions
65+
namespace td_int_ns = py_int::type_dispatch;
66+
namespace tu_ns = dpctl::tensor::type_utils;
67+
68+
/**
69+
* @brief A factory to define pairs of supported types for which
70+
* MKL VM library provides support in oneapi::mkl::vm::modf<T> function.
71+
*
72+
* @tparam T Type of input vector `a` and of result vectors `y` and `z`.
73+
*/
74+
template <typename T>
75+
struct OutputType
76+
{
77+
using table_type =
78+
std::disjunction<td_int_ns::TypeMapTwoResultsEntry<T, sycl::half>,
79+
td_int_ns::TypeMapTwoResultsEntry<T, float>,
80+
td_int_ns::TypeMapTwoResultsEntry<T, double>,
81+
td_int_ns::DefaultTwoResultsEntry<void>>;
82+
using value_type1 = typename table_type::result_type1;
83+
using value_type2 = typename table_type::result_type2;
84+
};
85+
86+
template <typename T>
87+
static sycl::event modf_contig_impl(sycl::queue &exec_q,
88+
std::size_t in_n,
89+
const char *in_a,
90+
char *out_y,
91+
char *out_z,
92+
const std::vector<sycl::event> &depends)
93+
{
94+
tu_ns::validate_type_for_device<T>(exec_q);
95+
96+
std::int64_t n = static_cast<std::int64_t>(in_n);
97+
const T *a = reinterpret_cast<const T *>(in_a);
98+
99+
using fractT = typename OutputType<T>::value_type1;
100+
using intT = typename OutputType<T>::value_type2;
101+
fractT *y = reinterpret_cast<fractT *>(out_y);
102+
intT *z = reinterpret_cast<intT *>(out_z);
103+
104+
return mkl_vm::modf(exec_q,
105+
n, // number of elements to be calculated
106+
a, // pointer `a` containing input vector of size n
107+
z, // pointer `z` to the output truncated integer values
108+
y, // pointer `y` to the output remaining fraction parts
109+
depends);
110+
}
111+
112+
using ew_cmn_ns::unary_two_outputs_contig_impl_fn_ptr_t;
113+
using ew_cmn_ns::unary_two_outputs_strided_impl_fn_ptr_t;
114+
115+
static std::pair<int, int> output_typeid_vector[td_ns::num_types];
116+
static unary_two_outputs_contig_impl_fn_ptr_t
117+
contig_dispatch_vector[td_ns::num_types];
118+
119+
MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(modf);
120+
} // namespace impl
121+
122+
/**
 * @brief Adds pybind11 bindings for the VM `modf` implementation to @p m.
 *
 * Registers `_modf`, which computes a truncated integer value and the
 * remaining fraction part for each element of `src` via OneMKL VM, and
 * `_mkl_modf_to_call`, which reports whether the OneMKL implementation
 * can be used for the given queue and arrays.
 */
void init_modf(py::module_ m)
{
    using arrayT = dpctl::tensor::usm_ndarray;
    using event_vecT = std::vector<sycl::event>;

    impl::populate_dispatch_vectors();
    using impl::contig_dispatch_vector;
    using impl::output_typeid_vector;

    // capture-less: the lambdas reference only parameters, namespaces and
    // namespace-scope statics, so `[&]` would capture nothing anyway
    auto modf_pyapi = [](sycl::queue &exec_q, const arrayT &src,
                         const arrayT &dst1, const arrayT &dst2,
                         const event_vecT &depends = {}) {
        return py_int::py_unary_two_outputs_ufunc(
            src, dst1, dst2, exec_q, depends, output_typeid_vector,
            contig_dispatch_vector,
            // no support of strided implementation in OneMKL
            td_ns::NullPtrVector<
                impl::unary_two_outputs_strided_impl_fn_ptr_t>{});
    };
    m.def("_modf", modf_pyapi,
          "Call `modf` function from OneMKL VM library to compute "
          "a truncated integer value and the remaining fraction part "
          "for each vector element",
          py::arg("sycl_queue"), py::arg("src"), py::arg("dst1"),
          py::arg("dst2"), py::arg("depends") = py::list());

    auto modf_need_to_call_pyapi = [](sycl::queue &exec_q, const arrayT &src,
                                      const arrayT &dst1, const arrayT &dst2) {
        return py_internal::need_to_call_unary_two_outputs_ufunc(
            exec_q, src, dst1, dst2, output_typeid_vector,
            contig_dispatch_vector);
    };
    m.def("_mkl_modf_to_call", modf_need_to_call_pyapi,
          "Check input arguments to answer if `modf` function from "
          "OneMKL VM library can be used",
          py::arg("sycl_queue"), py::arg("src"), py::arg("dst1"),
          py::arg("dst2"));
}
160+
} // namespace dpnp::extensions::vm
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace dpnp::extensions::vm
{
/// Registers the `_modf` and `_mkl_modf_to_call` bindings on module @p m.
void init_modf(py::module_ m);
} // namespace dpnp::extensions::vm

0 commit comments

Comments
 (0)