Skip to content

Commit 5682a94

Browse files
committed
Add mkl_vm::i0() extnesion to be used in dpnp.i0
1 parent 6c9037d commit 5682a94

File tree

6 files changed

+180
-0
lines changed

6 files changed

+180
-0
lines changed

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ if(NOT _use_onemkl_interfaces)
5050
${CMAKE_CURRENT_SOURCE_DIR}/fmin.cpp
5151
${CMAKE_CURRENT_SOURCE_DIR}/fmod.cpp
5252
${CMAKE_CURRENT_SOURCE_DIR}/hypot.cpp
53+
${CMAKE_CURRENT_SOURCE_DIR}/i0.cpp
5354
${CMAKE_CURRENT_SOURCE_DIR}/ln.cpp
5455
${CMAKE_CURRENT_SOURCE_DIR}/log10.cpp
5556
${CMAKE_CURRENT_SOURCE_DIR}/log1p.cpp

dpnp/backend/extensions/vm/i0.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <oneapi/mkl.hpp>
27+
#include <sycl/sycl.hpp>
28+
29+
#include "dpctl4pybind11.hpp"
30+
31+
#include "common.hpp"
32+
#include "i0.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "utils/type_dispatch.hpp"
42+
#include "utils/type_utils.hpp"
43+
44+
namespace dpnp::extensions::vm
45+
{
46+
namespace py = pybind11;
47+
namespace py_int = dpnp::extensions::py_internal;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
50+
namespace impl
51+
{
52+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
53+
namespace mkl_vm = oneapi::mkl::vm; // OneMKL namespace with VM functions
54+
namespace tu_ns = dpctl::tensor::type_utils;
55+
56+
/**
57+
* @brief A factory to define pairs of supported types for which
58+
* MKL VM library provides support in oneapi::mkl::vm::i0<T> function.
59+
*
60+
* @tparam T Type of input vector `a` and of result vector `y`.
61+
*/
62+
template <typename T>
63+
struct OutputType
64+
{
65+
using value_type =
66+
typename std::disjunction<td_ns::TypeMapResultEntry<T, double>,
67+
td_ns::TypeMapResultEntry<T, float>,
68+
td_ns::TypeMapResultEntry<T, sycl::half>,
69+
td_ns::DefaultResultEntry<void>>::result_type;
70+
};
71+
72+
template <typename T>
73+
static sycl::event i0_contig_impl(sycl::queue &exec_q,
74+
std::size_t in_n,
75+
const char *in_a,
76+
char *out_y,
77+
const std::vector<sycl::event> &depends)
78+
{
79+
tu_ns::validate_type_for_device<T>(exec_q);
80+
81+
std::int64_t n = static_cast<std::int64_t>(in_n);
82+
const T *a = reinterpret_cast<const T *>(in_a);
83+
84+
using resTy = typename OutputType<T>::value_type;
85+
resTy *y = reinterpret_cast<resTy *>(out_y);
86+
87+
return mkl_vm::i0(exec_q,
88+
n, // number of elements to be calculated
89+
a, // pointer `a` containing input vector of size n
90+
y, // pointer `y` to the output vector of size n
91+
depends);
92+
}
93+
94+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
95+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
96+
97+
static int output_typeid_vector[td_ns::num_types];
98+
static unary_contig_impl_fn_ptr_t contig_dispatch_vector[td_ns::num_types];
99+
100+
MACRO_POPULATE_DISPATCH_VECTORS(i0);
101+
} // namespace impl
102+
103+
void init_i0(py::module_ m)
104+
{
105+
using arrayT = dpctl::tensor::usm_ndarray;
106+
using event_vecT = std::vector<sycl::event>;
107+
108+
impl::populate_dispatch_vectors();
109+
using impl::contig_dispatch_vector;
110+
using impl::output_typeid_vector;
111+
112+
auto i0_pyapi = [&](sycl::queue &exec_q, const arrayT &src,
113+
const arrayT &dst, const event_vecT &depends = {}) {
114+
return py_int::py_unary_ufunc(
115+
src, dst, exec_q, depends, output_typeid_vector,
116+
contig_dispatch_vector,
117+
// no support of strided implementation in OneMKL
118+
td_ns::NullPtrVector<impl::unary_strided_impl_fn_ptr_t>{});
119+
};
120+
m.def("_i0", i0_pyapi,
121+
"Call `i0` function from OneMKL VM library to compute "
122+
"the inverse tangent of vector elements",
123+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
124+
py::arg("depends") = py::list());
125+
126+
auto i0_need_to_call_pyapi = [&](sycl::queue &exec_q, const arrayT &src,
127+
const arrayT &dst) {
128+
return py_internal::need_to_call_unary_ufunc(
129+
exec_q, src, dst, output_typeid_vector, contig_dispatch_vector);
130+
};
131+
m.def("_mkl_i0_to_call", i0_need_to_call_pyapi,
132+
"Check input arguments to answer if `i0` function from "
133+
"OneMKL VM library can be used",
134+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
135+
}
136+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/i0.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::vm
33+
{
34+
void init_i0(py::module_ m);
35+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "fmin.hpp"
5454
#include "fmod.hpp"
5555
#include "hypot.hpp"
56+
#include "i0.hpp"
5657
#include "ln.hpp"
5758
#include "log10.hpp"
5859
#include "log1p.hpp"
@@ -103,6 +104,7 @@ PYBIND11_MODULE(_vm_impl, m)
103104
vm_ns::init_fmin(m);
104105
vm_ns::init_fmod(m);
105106
vm_ns::init_hypot(m);
107+
vm_ns::init_i0(m);
106108
vm_ns::init_ln(m);
107109
vm_ns::init_log10(m);
108110
vm_ns::init_log1p(m);

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,12 +567,16 @@ def __init__(
567567
result_type_resolver_fn,
568568
unary_dp_impl_fn,
569569
docs,
570+
mkl_fn_to_call=None,
571+
mkl_impl_fn=None,
570572
):
571573
super().__init__(
572574
name,
573575
result_type_resolver_fn,
574576
unary_dp_impl_fn,
575577
docs,
578+
mkl_fn_to_call=mkl_fn_to_call,
579+
mkl_impl_fn=mkl_impl_fn,
576580
)
577581

578582
def __call__(self, x, out=None, order="K"):

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,6 +2681,8 @@ def gradient(f, *varargs, axis=None, edge_order=1):
26812681
ufi._i0_result_type,
26822682
ufi._i0,
26832683
_I0_DOCSTRING,
2684+
mkl_fn_to_call="_mkl_i0_to_call",
2685+
mkl_impl_fn="_i0",
26842686
)
26852687

26862688

0 commit comments

Comments
 (0)