Skip to content

Commit 2268f9f

Browse files
committed
Reuse VM implementation of erf from oneMKL
1 parent 6cf7ec1 commit 2268f9f

File tree

4 files changed

+176
-0
lines changed

4 files changed

+176
-0
lines changed

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ if(NOT _use_onemath)
4242
${CMAKE_CURRENT_SOURCE_DIR}/cos.cpp
4343
${CMAKE_CURRENT_SOURCE_DIR}/cosh.cpp
4444
${CMAKE_CURRENT_SOURCE_DIR}/div.cpp
45+
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
4546
${CMAKE_CURRENT_SOURCE_DIR}/exp.cpp
4647
${CMAKE_CURRENT_SOURCE_DIR}/exp2.cpp
4748
${CMAKE_CURRENT_SOURCE_DIR}/expm1.cpp

dpnp/backend/extensions/vm/erf.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <type_traits>
27+
#include <vector>
28+
29+
#include <oneapi/mkl.hpp>
30+
#include <sycl/sycl.hpp>
31+
32+
#include "dpctl4pybind11.hpp"
33+
34+
#include "common.hpp"
35+
#include "erf.hpp"
36+
37+
// include a local copy of elementwise common header from dpctl tensor:
38+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
39+
// TODO: replace by including dpctl header once available
40+
#include "../elementwise_functions/elementwise_functions.hpp"
41+
42+
// dpctl tensor headers
43+
#include "kernels/elementwise_functions/common.hpp"
44+
#include "utils/type_dispatch.hpp"
45+
#include "utils/type_utils.hpp"
46+
47+
namespace dpnp::extensions::vm
48+
{
49+
namespace py = pybind11;
50+
namespace py_int = dpnp::extensions::py_internal;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
namespace impl
54+
{
55+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
56+
namespace mkl_vm = oneapi::mkl::vm; // OneMKL namespace with VM functions
57+
namespace tu_ns = dpctl::tensor::type_utils;
58+
59+
/**
60+
* @brief A factory to define pairs of supported types for which
61+
* MKL VM library provides support in oneapi::mkl::vm::erf<T> function.
62+
*
63+
* @tparam T Type of input vector `a` and of result vector `y`.
64+
*/
65+
template <typename T>
66+
struct OutputType
67+
{
68+
using value_type =
69+
typename std::disjunction<td_ns::TypeMapResultEntry<T, float>,
70+
td_ns::TypeMapResultEntry<T, double>,
71+
td_ns::DefaultResultEntry<void>>::result_type;
72+
};
73+
74+
template <typename T>
75+
static sycl::event erf_contig_impl(sycl::queue &exec_q,
76+
std::size_t in_n,
77+
const char *in_a,
78+
char *out_y,
79+
const std::vector<sycl::event> &depends)
80+
{
81+
tu_ns::validate_type_for_device<T>(exec_q);
82+
83+
std::int64_t n = static_cast<std::int64_t>(in_n);
84+
const T *a = reinterpret_cast<const T *>(in_a);
85+
86+
using resTy = typename OutputType<T>::value_type;
87+
resTy *y = reinterpret_cast<resTy *>(out_y);
88+
89+
return mkl_vm::erf(exec_q,
90+
n, // number of elements to be calculated
91+
a, // pointer `a` containing input vector of size n
92+
y, // pointer `y` to the output vector of size n
93+
depends);
94+
}
95+
96+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
97+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
98+
99+
static int output_typeid_vector[td_ns::num_types];
100+
static unary_contig_impl_fn_ptr_t contig_dispatch_vector[td_ns::num_types];
101+
102+
MACRO_POPULATE_DISPATCH_VECTORS(erf);
103+
} // namespace impl
104+
105+
void init_erf(py::module_ m)
106+
{
107+
using arrayT = dpctl::tensor::usm_ndarray;
108+
using event_vecT = std::vector<sycl::event>;
109+
110+
impl::populate_dispatch_vectors();
111+
using impl::contig_dispatch_vector;
112+
using impl::output_typeid_vector;
113+
114+
auto erf_pyapi = [&](sycl::queue &exec_q, const arrayT &src,
115+
const arrayT &dst, const event_vecT &depends = {}) {
116+
return py_int::py_unary_ufunc(
117+
src, dst, exec_q, depends, output_typeid_vector,
118+
contig_dispatch_vector,
119+
// no support of strided implementation in OneMKL
120+
td_ns::NullPtrVector<impl::unary_strided_impl_fn_ptr_t>{});
121+
};
122+
m.def("_erf", erf_pyapi,
123+
"Call `erf` function from OneMKL VM library to compute "
124+
"the natural (base-e) erfonential of vector elements",
125+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
126+
py::arg("depends") = py::list());
127+
128+
auto erf_need_to_call_pyapi = [&](sycl::queue &exec_q, const arrayT &src,
129+
const arrayT &dst) {
130+
return py_internal::need_to_call_unary_ufunc(
131+
exec_q, src, dst, output_typeid_vector, contig_dispatch_vector);
132+
};
133+
m.def("_mkl_erf_to_call", erf_need_to_call_pyapi,
134+
"Check input arguments to answer if `erf` function from "
135+
"OneMKL VM library can be used",
136+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
137+
}
138+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/erf.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::vm
33+
{
34+
void init_erf(py::module_ m);
35+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "cos.hpp"
4646
#include "cosh.hpp"
4747
#include "div.hpp"
48+
#include "erf.hpp"
4849
#include "exp.hpp"
4950
#include "exp2.hpp"
5051
#include "expm1.hpp"
@@ -97,6 +98,7 @@ PYBIND11_MODULE(_vm_impl, m)
9798
vm_ns::init_cos(m);
9899
vm_ns::init_cosh(m);
99100
vm_ns::init_div(m);
101+
vm_ns::init_erf(m);
100102
vm_ns::init_exp(m);
101103
vm_ns::init_exp2(m);
102104
vm_ns::init_expm1(m);

0 commit comments

Comments
 (0)