Skip to content

Commit 9214fea

Browse files
committed
Extend DPNPUnaryTwoOutputsFunc class with support of oneMKL VM callbacks
1 parent 9cb9c4a commit 9214fea

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ class DPNPUnaryFunc(UnaryElementwiseFunc):
9292
corresponds to computational tasks associated with function evaluation.
9393
docs : {str}
9494
Documentation string for the unary function.
95-
mkl_fn_to_call : {callable}
95+
mkl_fn_to_call : {None, str}
9696
Check input arguments to answer if function from OneMKL VM library
9797
can be used.
98-
mkl_impl_fn : {callable}
98+
mkl_impl_fn : {None, str}
9999
Function from OneMKL VM library to call.
100-
acceptance_fn : {callable}, optional
100+
acceptance_fn : {None, callable}, optional
101101
Function to influence type promotion behavior of this unary
102102
function. The function takes 4 arguments:
103103
arg_dtype - Data type of the first argument
@@ -130,7 +130,9 @@ def _call_func(src, dst, sycl_queue, depends=None):
130130
if depends is None:
131131
depends = []
132132

133-
if vmi._is_available() and mkl_fn_to_call is not None:
133+
if vmi._is_available() and not (
134+
mkl_impl_fn is None or mkl_fn_to_call is None
135+
):
134136
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src, dst):
135137
# call pybind11 extension for unary function from OneMKL VM
136138
return getattr(vmi, mkl_impl_fn)(
@@ -231,6 +233,11 @@ class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
231233
corresponds to computational tasks associated with function evaluation.
232234
docs : {str}
233235
Documentation string for the unary function.
236+
mkl_fn_to_call : {None, str}
237+
Check input arguments to answer if function from OneMKL VM library
238+
can be used.
239+
mkl_impl_fn : {None, str}
240+
Function from OneMKL VM library to call.
234241
235242
"""
236243

@@ -240,11 +247,29 @@ def __init__(
240247
result_type_resolver_fn,
241248
unary_dp_impl_fn,
242249
docs,
250+
mkl_fn_to_call=None,
251+
mkl_impl_fn=None,
243252
):
253+
def _call_func(src, dst1, dst2, sycl_queue, depends=None):
254+
"""A callback to register in UnaryElementwiseFunc class."""
255+
256+
if depends is None:
257+
depends = []
258+
259+
if vmi._is_available() and not (
260+
mkl_impl_fn is None or mkl_fn_to_call is None
261+
):
262+
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src, dst1, dst2):
263+
# call pybind11 extension for unary function from OneMKL VM
264+
return getattr(vmi, mkl_impl_fn)(
265+
sycl_queue, src, dst1, dst2, depends
266+
)
267+
return unary_dp_impl_fn(src, dst1, dst2, sycl_queue, depends)
268+
244269
super().__init__(
245270
name,
246271
result_type_resolver_fn,
247-
unary_dp_impl_fn,
272+
_call_func,
248273
docs,
249274
)
250275
self.__name__ = "DPNPUnaryTwoOutputsFunc"
@@ -458,12 +483,12 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
458483
evaluation.
459484
docs : {str}
460485
Documentation string for the unary function.
461-
mkl_fn_to_call : {callable}
486+
mkl_fn_to_call : {None, str}
462487
Check input arguments to answer if function from OneMKL VM library
463488
can be used.
464-
mkl_impl_fn : {callable}
489+
mkl_impl_fn : {None, str}
465490
Function from OneMKL VM library to call.
466-
binary_inplace_fn : {callable}, optional
491+
binary_inplace_fn : {None, callable}, optional
467492
Data-parallel implementation function with signature
468493
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
469494
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
@@ -475,7 +500,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
475500
including async lifetime management of Python arguments,
476501
while the second event corresponds to computational tasks
477502
associated with function evaluation.
478-
acceptance_fn : {callable}, optional
503+
acceptance_fn : {None, callable}, optional
479504
Function to influence type promotion behavior of this binary
480505
function. The function takes 6 arguments:
481506
arg1_dtype - Data type of the first argument
@@ -488,7 +513,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
488513
The function is only called when both arguments of the binary
489514
function require casting, e.g. both arguments of
490515
`dpctl.tensor.logaddexp` are arrays with integral data type.
491-
weak_type_resolver : {callable}, optional
516+
weak_type_resolver : {None, callable}, optional
492517
Function to influence type promotion behavior for Python scalar types
493518
of this binary function. The function takes 3 arguments:
494519
o1_dtype - Data type or Python scalar type of the first argument
@@ -520,7 +545,9 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
520545
if depends is None:
521546
depends = []
522547

523-
if vmi._is_available() and mkl_fn_to_call is not None:
548+
if vmi._is_available() and not (
549+
mkl_impl_fn is None or mkl_fn_to_call is None
550+
):
524551
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src1, src2, dst):
525552
# call pybind11 extension for binary function from OneMKL VM
526553
return getattr(vmi, mkl_impl_fn)(

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3423,6 +3423,8 @@ def interp(x, xp, fp, left=None, right=None, period=None):
34233423
ufi._modf_result_type,
34243424
ufi._modf,
34253425
_MODF_DOCSTRING,
3426+
mkl_fn_to_call="_mkl_modf_to_call",
3427+
mkl_impl_fn="_modf",
34263428
)
34273429

34283430

0 commit comments

Comments
 (0)