Skip to content

Commit aad63d9

Browse files
committed
Extend DPNPUnaryTwoOutputsFunc class with support of oneMKL VM callbacks
1 parent 14da88a commit aad63d9

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,12 @@ class DPNPUnaryFunc(UnaryElementwiseFunc):
9191
corresponds to computational tasks associated with function evaluation.
9292
docs : {str}
9393
Documentation string for the unary function.
94-
mkl_fn_to_call : {callable}
94+
mkl_fn_to_call : {None, str}
9595
Check input arguments to answer if function from OneMKL VM library
9696
can be used.
97-
mkl_impl_fn : {callable}
97+
mkl_impl_fn : {None, str}
9898
Function from OneMKL VM library to call.
99-
acceptance_fn : {callable}, optional
99+
acceptance_fn : {None, callable}, optional
100100
Function to influence type promotion behavior of this unary
101101
function. The function takes 4 arguments:
102102
arg_dtype - Data type of the first argument
@@ -129,7 +129,9 @@ def _call_func(src, dst, sycl_queue, depends=None):
129129
if depends is None:
130130
depends = []
131131

132-
if vmi._is_available() and mkl_fn_to_call is not None:
132+
if vmi._is_available() and not (
133+
mkl_impl_fn is None or mkl_fn_to_call is None
134+
):
133135
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src, dst):
134136
# call pybind11 extension for unary function from OneMKL VM
135137
return getattr(vmi, mkl_impl_fn)(
@@ -232,6 +234,11 @@ class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
232234
corresponds to computational tasks associated with function evaluation.
233235
docs : {str}
234236
Documentation string for the unary function.
237+
mkl_fn_to_call : {None, str}
238+
Check input arguments to answer if function from OneMKL VM library
239+
can be used.
240+
mkl_impl_fn : {None, str}
241+
Function from OneMKL VM library to call.
235242
236243
"""
237244

@@ -241,11 +248,29 @@ def __init__(
241248
result_type_resolver_fn,
242249
unary_dp_impl_fn,
243250
docs,
251+
mkl_fn_to_call=None,
252+
mkl_impl_fn=None,
244253
):
254+
def _call_func(src, dst1, dst2, sycl_queue, depends=None):
255+
"""A callback to register in UnaryElementwiseFunc class."""
256+
257+
if depends is None:
258+
depends = []
259+
260+
if vmi._is_available() and not (
261+
mkl_impl_fn is None or mkl_fn_to_call is None
262+
):
263+
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src, dst1, dst2):
264+
# call pybind11 extension for unary function from OneMKL VM
265+
return getattr(vmi, mkl_impl_fn)(
266+
sycl_queue, src, dst1, dst2, depends
267+
)
268+
return unary_dp_impl_fn(src, dst1, dst2, sycl_queue, depends)
269+
245270
super().__init__(
246271
name,
247272
result_type_resolver_fn,
248-
unary_dp_impl_fn,
273+
_call_func,
249274
docs,
250275
)
251276
self.__name__ = "DPNPUnaryTwoOutputsFunc"
@@ -459,12 +484,12 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
459484
evaluation.
460485
docs : {str}
461486
Documentation string for the unary function.
462-
mkl_fn_to_call : {callable}
487+
mkl_fn_to_call : {None, str}
463488
Check input arguments to answer if function from OneMKL VM library
464489
can be used.
465-
mkl_impl_fn : {callable}
490+
mkl_impl_fn : {None, str}
466491
Function from OneMKL VM library to call.
467-
binary_inplace_fn : {callable}, optional
492+
binary_inplace_fn : {None, callable}, optional
468493
Data-parallel implementation function with signature
469494
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
470495
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
@@ -476,7 +501,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
476501
including async lifetime management of Python arguments,
477502
while the second event corresponds to computational tasks
478503
associated with function evaluation.
479-
acceptance_fn : {callable}, optional
504+
acceptance_fn : {None, callable}, optional
480505
Function to influence type promotion behavior of this binary
481506
function. The function takes 6 arguments:
482507
arg1_dtype - Data type of the first argument
@@ -489,7 +514,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
489514
The function is only called when both arguments of the binary
490515
function require casting, e.g. both arguments of
491516
`dpctl.tensor.logaddexp` are arrays with integral data type.
492-
weak_type_resolver : {callable}, optional
517+
weak_type_resolver : {None, callable}, optional
493518
Function to influence type promotion behavior for Python scalar types
494519
of this binary function. The function takes 3 arguments:
495520
o1_dtype - Data type or Python scalar type of the first argument
@@ -521,7 +546,9 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
521546
if depends is None:
522547
depends = []
523548

524-
if vmi._is_available() and mkl_fn_to_call is not None:
549+
if vmi._is_available() and not (
550+
mkl_impl_fn is None or mkl_fn_to_call is None
551+
):
525552
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src1, src2, dst):
526553
# call pybind11 extension for binary function from OneMKL VM
527554
return getattr(vmi, mkl_impl_fn)(

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3422,6 +3422,8 @@ def interp(x, xp, fp, left=None, right=None, period=None):
34223422
ufi._modf_result_type,
34233423
ufi._modf,
34243424
_MODF_DOCSTRING,
3425+
mkl_fn_to_call="_mkl_modf_to_call",
3426+
mkl_impl_fn="_modf",
34253427
)
34263428

34273429

0 commit comments

Comments
 (0)