@@ -91,12 +91,12 @@ class DPNPUnaryFunc(UnaryElementwiseFunc):
9191 corresponds to computational tasks associated with function evaluation.
9292 docs : {str}
9393 Documentation string for the unary function.
94- mkl_fn_to_call : {callable }
94+ mkl_fn_to_call : {None, str }
9595 Check input arguments to answer if function from OneMKL VM library
9696 can be used.
97- mkl_impl_fn : {callable }
97+ mkl_impl_fn : {None, str }
9898 Function from OneMKL VM library to call.
99- acceptance_fn : {callable}, optional
99+ acceptance_fn : {None, callable}, optional
100100 Function to influence type promotion behavior of this unary
101101 function. The function takes 4 arguments:
102102 arg_dtype - Data type of the first argument
@@ -129,7 +129,9 @@ def _call_func(src, dst, sycl_queue, depends=None):
129129 if depends is None :
130130 depends = []
131131
132- if vmi ._is_available () and mkl_fn_to_call is not None :
132+ if vmi ._is_available () and not (
133+ mkl_impl_fn is None or mkl_fn_to_call is None
134+ ):
133135 if getattr (vmi , mkl_fn_to_call )(sycl_queue , src , dst ):
134136 # call pybind11 extension for unary function from OneMKL VM
135137 return getattr (vmi , mkl_impl_fn )(
@@ -232,6 +234,11 @@ class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
232234 corresponds to computational tasks associated with function evaluation.
233235 docs : {str}
234236 Documentation string for the unary function.
237+ mkl_fn_to_call : {None, str}
238+ Check input arguments to answer if function from OneMKL VM library
239+ can be used.
240+ mkl_impl_fn : {None, str}
241+ Function from OneMKL VM library to call.
235242
236243 """
237244
@@ -241,11 +248,29 @@ def __init__(
241248 result_type_resolver_fn ,
242249 unary_dp_impl_fn ,
243250 docs ,
251+ mkl_fn_to_call = None ,
252+ mkl_impl_fn = None ,
244253 ):
254+ def _call_func (src , dst1 , dst2 , sycl_queue , depends = None ):
255+ """A callback to register in UnaryElementwiseFunc class."""
256+
257+ if depends is None :
258+ depends = []
259+
260+ if vmi ._is_available () and not (
261+ mkl_impl_fn is None or mkl_fn_to_call is None
262+ ):
263+ if getattr (vmi , mkl_fn_to_call )(sycl_queue , src , dst1 , dst2 ):
264+ # call pybind11 extension for unary function from OneMKL VM
265+ return getattr (vmi , mkl_impl_fn )(
266+ sycl_queue , src , dst1 , dst2 , depends
267+ )
268+ return unary_dp_impl_fn (src , dst1 , dst2 , sycl_queue , depends )
269+
245270 super ().__init__ (
246271 name ,
247272 result_type_resolver_fn ,
248- unary_dp_impl_fn ,
273+ _call_func ,
249274 docs ,
250275 )
251276 self .__name__ = "DPNPUnaryTwoOutputsFunc"
@@ -459,12 +484,12 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
459484 evaluation.
460485 docs : {str}
461486 Documentation string for the unary function.
462- mkl_fn_to_call : {callable }
487+ mkl_fn_to_call : {None, str }
463488 Check input arguments to answer if function from OneMKL VM library
464489 can be used.
465- mkl_impl_fn : {callable }
490+ mkl_impl_fn : {None, str }
466491 Function from OneMKL VM library to call.
467- binary_inplace_fn : {callable}, optional
492+ binary_inplace_fn : {None, callable}, optional
468493 Data-parallel implementation function with signature
469494 `impl_fn(src: usm_ndarray, dst: usm_ndarray,
470495 sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
@@ -476,7 +501,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
476501 including async lifetime management of Python arguments,
477502 while the second event corresponds to computational tasks
478503 associated with function evaluation.
479- acceptance_fn : {callable}, optional
504+ acceptance_fn : {None, callable}, optional
480505 Function to influence type promotion behavior of this binary
481506 function. The function takes 6 arguments:
482507 arg1_dtype - Data type of the first argument
@@ -489,7 +514,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
489514 The function is only called when both arguments of the binary
490515 function require casting, e.g. both arguments of
491516 `dpctl.tensor.logaddexp` are arrays with integral data type.
492- weak_type_resolver : {callable}, optional
517+ weak_type_resolver : {None, callable}, optional
493518 Function to influence type promotion behavior for Python scalar types
494519 of this binary function. The function takes 3 arguments:
495520 o1_dtype - Data type or Python scalar type of the first argument
@@ -521,7 +546,9 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
521546 if depends is None :
522547 depends = []
523548
524- if vmi ._is_available () and mkl_fn_to_call is not None :
549+ if vmi ._is_available () and not (
550+ mkl_impl_fn is None or mkl_fn_to_call is None
551+ ):
525552 if getattr (vmi , mkl_fn_to_call )(sycl_queue , src1 , src2 , dst ):
526553 # call pybind11 extension for binary function from OneMKL VM
527554 return getattr (vmi , mkl_impl_fn )(
0 commit comments