@@ -92,12 +92,12 @@ class DPNPUnaryFunc(UnaryElementwiseFunc):
9292 corresponds to computational tasks associated with function evaluation.
9393 docs : {str}
9494 Documentation string for the unary function.
95- mkl_fn_to_call : {callable }
95+ mkl_fn_to_call : {None, str }
9696 Check input arguments to answer if function from OneMKL VM library
9797 can be used.
98- mkl_impl_fn : {callable }
98+ mkl_impl_fn : {None, str }
9999 Function from OneMKL VM library to call.
100- acceptance_fn : {callable}, optional
100+ acceptance_fn : {None, callable}, optional
101101 Function to influence type promotion behavior of this unary
102102 function. The function takes 4 arguments:
103103 arg_dtype - Data type of the first argument
@@ -130,7 +130,9 @@ def _call_func(src, dst, sycl_queue, depends=None):
130130 if depends is None :
131131 depends = []
132132
133- if vmi ._is_available () and mkl_fn_to_call is not None :
133+ if vmi ._is_available () and not (
134+ mkl_impl_fn is None or mkl_fn_to_call is None
135+ ):
134136 if getattr (vmi , mkl_fn_to_call )(sycl_queue , src , dst ):
135137 # call pybind11 extension for unary function from OneMKL VM
136138 return getattr (vmi , mkl_impl_fn )(
@@ -231,6 +233,11 @@ class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
231233 corresponds to computational tasks associated with function evaluation.
232234 docs : {str}
233235 Documentation string for the unary function.
236+ mkl_fn_to_call : {None, str}
237+ Check input arguments to answer if function from OneMKL VM library
238+ can be used.
239+ mkl_impl_fn : {None, str}
240+ Function from OneMKL VM library to call.
234241
235242 """
236243
@@ -240,11 +247,29 @@ def __init__(
240247 result_type_resolver_fn ,
241248 unary_dp_impl_fn ,
242249 docs ,
250+ mkl_fn_to_call = None ,
251+ mkl_impl_fn = None ,
243252 ):
253+ def _call_func (src , dst1 , dst2 , sycl_queue , depends = None ):
254+ """A callback to register in UnaryElementwiseFunc class."""
255+
256+ if depends is None :
257+ depends = []
258+
259+ if vmi ._is_available () and not (
260+ mkl_impl_fn is None or mkl_fn_to_call is None
261+ ):
262+ if getattr (vmi , mkl_fn_to_call )(sycl_queue , src , dst1 , dst2 ):
263+ # call pybind11 extension for unary function from OneMKL VM
264+ return getattr (vmi , mkl_impl_fn )(
265+ sycl_queue , src , dst1 , dst2 , depends
266+ )
267+ return unary_dp_impl_fn (src , dst1 , dst2 , sycl_queue , depends )
268+
244269 super ().__init__ (
245270 name ,
246271 result_type_resolver_fn ,
247- unary_dp_impl_fn ,
272+ _call_func ,
248273 docs ,
249274 )
250275 self .__name__ = "DPNPUnaryTwoOutputsFunc"
@@ -458,12 +483,12 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
458483 evaluation.
459484 docs : {str}
460485 Documentation string for the unary function.
461- mkl_fn_to_call : {callable }
486+ mkl_fn_to_call : {None, str }
462487 Check input arguments to answer if function from OneMKL VM library
463488 can be used.
464- mkl_impl_fn : {callable }
489+ mkl_impl_fn : {None, str }
465490 Function from OneMKL VM library to call.
466- binary_inplace_fn : {callable}, optional
491+ binary_inplace_fn : {None, callable}, optional
467492 Data-parallel implementation function with signature
468493 `impl_fn(src: usm_ndarray, dst: usm_ndarray,
469494 sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
@@ -475,7 +500,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
475500 including async lifetime management of Python arguments,
476501 while the second event corresponds to computational tasks
477502 associated with function evaluation.
478- acceptance_fn : {callable}, optional
503+ acceptance_fn : {None, callable}, optional
479504 Function to influence type promotion behavior of this binary
480505 function. The function takes 6 arguments:
481506 arg1_dtype - Data type of the first argument
@@ -488,7 +513,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
488513 The function is only called when both arguments of the binary
489514 function require casting, e.g. both arguments of
490515 `dpctl.tensor.logaddexp` are arrays with integral data type.
491- weak_type_resolver : {callable}, optional
516+ weak_type_resolver : {None, callable}, optional
492517 Function to influence type promotion behavior for Python scalar types
493518 of this binary function. The function takes 3 arguments:
494519 o1_dtype - Data type or Python scalar type of the first argument
@@ -520,7 +545,9 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
520545 if depends is None :
521546 depends = []
522547
523- if vmi ._is_available () and mkl_fn_to_call is not None :
548+ if vmi ._is_available () and not (
549+ mkl_impl_fn is None or mkl_fn_to_call is None
550+ ):
524551 if getattr (vmi , mkl_fn_to_call )(sycl_queue , src1 , src2 , dst ):
525552 # call pybind11 extension for binary function from OneMKL VM
526553 return getattr (vmi , mkl_impl_fn )(
0 commit comments