Skip to content

Commit eb9fc5c

Browse files
committed
Add mkl_vm::inv() extnesion to be used in dpnp.reciprocal
1 parent 5682a94 commit eb9fc5c

File tree

5 files changed

+176
-0
lines changed

5 files changed

+176
-0
lines changed

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ if(NOT _use_onemkl_interfaces)
5151
${CMAKE_CURRENT_SOURCE_DIR}/fmod.cpp
5252
${CMAKE_CURRENT_SOURCE_DIR}/hypot.cpp
5353
${CMAKE_CURRENT_SOURCE_DIR}/i0.cpp
54+
${CMAKE_CURRENT_SOURCE_DIR}/inv.cpp
5455
${CMAKE_CURRENT_SOURCE_DIR}/ln.cpp
5556
${CMAKE_CURRENT_SOURCE_DIR}/log10.cpp
5657
${CMAKE_CURRENT_SOURCE_DIR}/log1p.cpp

dpnp/backend/extensions/vm/inv.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <oneapi/mkl.hpp>
27+
#include <sycl/sycl.hpp>
28+
29+
#include "dpctl4pybind11.hpp"
30+
31+
#include "common.hpp"
32+
#include "inv.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "utils/type_dispatch.hpp"
42+
#include "utils/type_utils.hpp"
43+
44+
namespace dpnp::extensions::vm
45+
{
46+
namespace py = pybind11;
47+
namespace py_int = dpnp::extensions::py_internal;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
50+
namespace impl
51+
{
52+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
53+
namespace mkl_vm = oneapi::mkl::vm; // OneMKL namespace with VM functions
54+
namespace tu_ns = dpctl::tensor::type_utils;
55+
56+
/**
57+
* @brief A factory to define pairs of supported types for which
58+
* MKL VM library provides support in oneapi::mkl::vm::inv<T> function.
59+
*
60+
* @tparam T Type of input vector `a` and of result vector `y`.
61+
*/
62+
template <typename T>
63+
struct OutputType
64+
{
65+
using value_type =
66+
typename std::disjunction<td_ns::TypeMapResultEntry<T, double>,
67+
td_ns::TypeMapResultEntry<T, float>,
68+
td_ns::TypeMapResultEntry<T, sycl::half>,
69+
td_ns::DefaultResultEntry<void>>::result_type;
70+
};
71+
72+
template <typename T>
73+
static sycl::event inv_contig_impl(sycl::queue &exec_q,
74+
std::size_t in_n,
75+
const char *in_a,
76+
char *out_y,
77+
const std::vector<sycl::event> &depends)
78+
{
79+
tu_ns::validate_type_for_device<T>(exec_q);
80+
81+
std::int64_t n = static_cast<std::int64_t>(in_n);
82+
const T *a = reinterpret_cast<const T *>(in_a);
83+
84+
using resTy = typename OutputType<T>::value_type;
85+
resTy *y = reinterpret_cast<resTy *>(out_y);
86+
87+
return mkl_vm::inv(exec_q,
88+
n, // number of elements to be calculated
89+
a, // pointer `a` containing input vector of size n
90+
y, // pointer `y` to the output vector of size n
91+
depends);
92+
}
93+
94+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
95+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
96+
97+
static int output_typeid_vector[td_ns::num_types];
98+
static unary_contig_impl_fn_ptr_t contig_dispatch_vector[td_ns::num_types];
99+
100+
MACRO_POPULATE_DISPATCH_VECTORS(inv);
101+
} // namespace impl
102+
103+
void init_inv(py::module_ m)
104+
{
105+
using arrayT = dpctl::tensor::usm_ndarray;
106+
using event_vecT = std::vector<sycl::event>;
107+
108+
impl::populate_dispatch_vectors();
109+
using impl::contig_dispatch_vector;
110+
using impl::output_typeid_vector;
111+
112+
auto inv_pyapi = [&](sycl::queue &exec_q, const arrayT &src,
113+
const arrayT &dst, const event_vecT &depends = {}) {
114+
return py_int::py_unary_ufunc(
115+
src, dst, exec_q, depends, output_typeid_vector,
116+
contig_dispatch_vector,
117+
// no support of strided implementation in OneMKL
118+
td_ns::NullPtrVector<impl::unary_strided_impl_fn_ptr_t>{});
119+
};
120+
m.def("_inv", inv_pyapi,
121+
"Call `inv` function from OneMKL VM library to compute "
122+
"the inverse tangent of vector elements",
123+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
124+
py::arg("depends") = py::list());
125+
126+
auto inv_need_to_call_pyapi = [&](sycl::queue &exec_q, const arrayT &src,
127+
const arrayT &dst) {
128+
return py_internal::need_to_call_unary_ufunc(
129+
exec_q, src, dst, output_typeid_vector, contig_dispatch_vector);
130+
};
131+
m.def("_mkl_inv_to_call", inv_need_to_call_pyapi,
132+
"Check input arguments to answer if `inv` function from "
133+
"OneMKL VM library can be used",
134+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
135+
}
136+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/inv.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::vm
33+
{
34+
void init_inv(py::module_ m);
35+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include "fmod.hpp"
5555
#include "hypot.hpp"
5656
#include "i0.hpp"
57+
#include "inv.hpp"
5758
#include "ln.hpp"
5859
#include "log10.hpp"
5960
#include "log1p.hpp"
@@ -105,6 +106,7 @@ PYBIND11_MODULE(_vm_impl, m)
105106
vm_ns::init_fmod(m);
106107
vm_ns::init_hypot(m);
107108
vm_ns::init_i0(m);
109+
vm_ns::init_inv(m);
108110
vm_ns::init_ln(m);
109111
vm_ns::init_log10(m);
110112
vm_ns::init_log1p(m);

dpnp/dpnp_iface_trigonometric.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,6 +2102,8 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
21022102
ti._reciprocal_result_type,
21032103
ti._reciprocal,
21042104
_RECIPROCAL_DOCSTRING,
2105+
mkl_fn_to_call="_mkl_inv_to_call",
2106+
mkl_impl_fn="_inv",
21052107
)
21062108

21072109

0 commit comments

Comments
 (0)