Skip to content

Commit a9db942

Browse files
authored
ERF add kernel (#546)
1 parent 43a4b53 commit a9db942

File tree

9 files changed

+98
-9
lines changed

9 files changed

+98
-9
lines changed

dpnp/backend/include/dpnp_gen_1arg_1type_tbl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757

5858
MACRO_1ARG_1TYPE_OP(dpnp_conjugate_c, std::conj(input_elem), DPNP_QUEUE.submit(kernel_func))
5959
MACRO_1ARG_1TYPE_OP(dpnp_copy_c, input_elem, DPNP_QUEUE.submit(kernel_func))
60+
MACRO_1ARG_1TYPE_OP(dpnp_erf_c, input_elem, oneapi::mkl::vm::erf(DPNP_QUEUE, size, array1, result))
6061
MACRO_1ARG_1TYPE_OP(dpnp_recip_c,
6162
_DataType(1) / input_elem,
6263
DPNP_QUEUE.submit(kernel_func)) // error: no member named 'recip' in namespace 'cl::sycl'

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ enum class DPNPFuncName : size_t
9090
DPNP_FN_DOT, /**< Used in numpy.dot() implementation */
9191
DPNP_FN_EIG, /**< Used in numpy.linalg.eig() implementation */
9292
DPNP_FN_EIGVALS, /**< Used in numpy.linalg.eigvals() implementation */
93+
DPNP_FN_ERF, /**< Used in scipy.special.erf implementation */
9394
DPNP_FN_EXP, /**< Used in numpy.exp() implementation */
9495
DPNP_FN_EXP2, /**< Used in numpy.exp2() implementation */
9596
DPNP_FN_EXPM1, /**< Used in numpy.expm1() implementation */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
235235
cgh.parallel_for<class __name__##_kernel<_DataType>>(gws, kernel_parallel_for_func); \
236236
}; \
237237
\
238-
if constexpr (std::is_same<_DataType, double>::value) \
238+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value) \
239239
{ \
240240
event = __operation2__; \
241241
} \
@@ -297,6 +297,11 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
297297
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_copy_c<int>};
298298
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_copy_c<long>};
299299

300+
fmap[DPNPFuncName::DPNP_FN_ERF][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_erf_c<int>};
301+
fmap[DPNPFuncName::DPNP_FN_ERF][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_erf_c<long>};
302+
fmap[DPNPFuncName::DPNP_FN_ERF][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_erf_c<float>};
303+
fmap[DPNPFuncName::DPNP_FN_ERF][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_erf_c<double>};
304+
300305
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_recip_c<double>};
301306
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_recip_c<float>};
302307
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_recip_c<int>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
6363
DPNP_FN_DOT
6464
DPNP_FN_EIG
6565
DPNP_FN_EIGVALS
66+
DPNP_FN_ERF
6667
DPNP_FN_EXP
6768
DPNP_FN_EXP2
6869
DPNP_FN_EXPM1

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ include "dpnp_algo_manipulation.pyx"
6363
include "dpnp_algo_mathematical.pyx"
6464
include "dpnp_algo_searching.pyx"
6565
include "dpnp_algo_sorting.pyx"
66+
include "dpnp_algo_special.pyx"
6667
include "dpnp_algo_statistics.pyx"
6768
include "dpnp_algo_trigonometric.pyx"
6869

dpnp/dpnp_algo/dpnp_algo_special.pyx

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# cython: language_level=3
2+
# -*- coding: utf-8 -*-
3+
# *****************************************************************************
4+
# Copyright (c) 2016-2020, Intel Corporation
5+
# All rights reserved.
6+
#
7+
# Redistribution and use in source and binary forms, with or without
8+
# modification, are permitted provided that the following conditions are met:
9+
# - Redistributions of source code must retain the above copyright notice,
10+
# this list of conditions and the following disclaimer.
11+
# - Redistributions in binary form must reproduce the above copyright notice,
12+
# this list of conditions and the following disclaimer in the documentation
13+
# and/or other materials provided with the distribution.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
19+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
20+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
21+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
22+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
23+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
25+
# THE POSSIBILITY OF SUCH DAMAGE.
26+
# *****************************************************************************
27+
28+
"""Module Backend (Special part)
29+
30+
This module contains interface functions between C backend layer
31+
and the rest of the library
32+
33+
"""
34+
35+
36+
from dpnp.dpnp_utils cimport *
37+
38+
39+
__all__ += [
40+
'dpnp_erf',
41+
]
42+
43+
44+
cpdef dparray dpnp_erf(dparray x1):
45+
return call_fptr_1in_1out(DPNP_FN_ERF, x1, x1.shape)

dpnp/dpnp_iface_libmath.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@
3939
4040
"""
4141

42-
4342
import math
4443

44+
from dpnp.dpnp_algo import *
4545
from dpnp.dparray import dparray
46+
from dpnp.dpnp_utils import *
4647

4748

4849
__all__ = [
@@ -76,12 +77,14 @@ def erf(in_array1):
7677
[0.99532227, 0.99853728, 0.99959305, 0.99989938, 0.99997791]
7778
7879
"""
80+
if not use_origin_backend(in_array1):
81+
if not isinstance(in_array1, dparray):
82+
pass
83+
else:
84+
return dpnp_erf(in_array1)
7985

80-
if isinstance(in_array1, dparray):
81-
result = dparray(in_array1.shape, dtype=in_array1.dtype)
82-
for i in range(result.size):
83-
result[i] = math.erf(in_array1[i])
84-
85-
return result
86+
result = dparray(in_array1.shape, dtype=in_array1.dtype)
87+
for i in range(result.size):
88+
result[i] = math.erf(in_array1[i])
8689

87-
return math.erf(in_array1)
90+
return result

tests/test_special.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import math
2+
import dpnp
3+
import numpy
4+
5+
6+
def test_erf():
7+
a = numpy.linspace(2.0, 3.0, num=10)
8+
ia = dpnp.linspace(2.0, 3.0, num=10)
9+
10+
numpy.testing.assert_array_equal(a, ia)
11+
12+
expected = numpy.empty_like(a)
13+
for idx, val in enumerate(a):
14+
expected[idx] = math.erf(val)
15+
16+
result = dpnp.erf(ia)
17+
18+
numpy.testing.assert_array_equal(result, expected)
19+
20+
def test_erf_fallback():
21+
a = numpy.linspace(2.0, 3.0, num=10)
22+
23+
expected = numpy.empty_like(a)
24+
for idx, val in enumerate(a):
25+
expected[idx] = math.erf(val)
26+
27+
result = dpnp.erf(a)
28+
29+
numpy.testing.assert_array_equal(result, expected)

tests_external/skipped_tests_numpy_aborted.tbl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,9 @@ tests/test_pocketfft.py::TestFFT1D::test_axes[fftn]
556556
tests/test_pocketfft.py::TestFFT1D::test_axes[ifftn]
557557
tests/test_pocketfft.py::TestFFT1D::test_axes[rfftn]
558558
tests/test_pocketfft.py::TestFFT1D::test_all_1d_norm_preserving
559+
tests/test_pocketfft.py::TestFFTThreadSafe::test_fft
560+
tests/test_pocketfft.py::TestFFTThreadSafe::test_ifft
561+
tests/test_pocketfft.py::TestFFTThreadSafe::test_rfft
559562
tests/test_randomstate_regression.py::TestRegression::test_logseries_convergence
560563
tests/test_regression.py::TestRegression::test_0d_string_scalar
561564
tests/test_regression.py::TestRegression::test_add_identity

0 commit comments

Comments
 (0)