2525# *****************************************************************************
2626
2727import dpctl .tensor as dpt
28+ import dpctl .tensor ._tensor_impl as dti
29+ import dpctl .tensor ._type_utils as dtu
2830import numpy
2931from dpctl .tensor ._elementwise_common import (
3032 BinaryElementwiseFunc ,
4648 "DPNPReal" ,
4749 "DPNPRound" ,
4850 "DPNPUnaryFunc" ,
51+ "resolve_weak_types_2nd_arg_int" ,
4952]
5053
5154
@@ -244,6 +247,14 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
244247 The function is only called when both arguments of the binary
245248 function require casting, e.g. both arguments of
246249 `dpctl.tensor.logaddexp` are arrays with integral data type.
250+ weak_type_resolver : {callable}, optional
251+ Function to influence type promotion behavior for Python scalar types
252+ of this binary function. The function takes 3 arguments:
253+ o1_dtype - Data type or Python scalar type of the first argument
254+ o2_dtype - Data type or Python scalar type of of the second argument
255+ sycl_dev - The :class:`dpctl.SyclDevice` where the function
256+ evaluation is carried out.
257+ One of `o1_dtype` and `o2_dtype` must be a ``dtype`` instance.
247258 """
248259
249260 def __init__ (
@@ -256,6 +267,7 @@ def __init__(
256267 mkl_impl_fn = None ,
257268 binary_inplace_fn = None ,
258269 acceptance_fn = None ,
270+ weak_type_resolver = None ,
259271 ):
260272 def _call_func (src1 , src2 , dst , sycl_queue , depends = None ):
261273 """
@@ -281,6 +293,7 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
281293 docs ,
282294 binary_inplace_fn ,
283295 acceptance_fn = acceptance_fn ,
296+ weak_type_resolver = weak_type_resolver ,
284297 )
285298 self .__name__ = "DPNPBinaryFunc"
286299
@@ -606,3 +619,22 @@ def acceptance_fn_subtract(
606619 )
607620 else :
608621 return True
622+
623+
624+ def resolve_weak_types_2nd_arg_int (o1_dtype , o2_dtype , sycl_dev ):
625+ """
626+ The second weak dtype has to be upcasting up to default integer dtype
627+ for a SYCL device where it is possible.
628+ For other cases the default weak types resolving will be applied.
629+
630+ """
631+
632+ if dtu ._is_weak_dtype (o2_dtype ):
633+ o1_kind_num = dtu ._strong_dtype_num_kind (o1_dtype )
634+ o2_kind_num = dtu ._weak_type_num_kind (o2_dtype )
635+ if o2_kind_num < o1_kind_num :
636+ if isinstance (o2_dtype , (dtu .WeakBooleanType , dtu .WeakIntegralType )):
637+ print ()
638+ print (o1_dtype , dpt .dtype (dti .default_device_int_type (sycl_dev )))
639+ return o1_dtype , dpt .dtype (dti .default_device_int_type (sycl_dev ))
640+ return dtu ._resolve_weak_types (o1_dtype , o2_dtype , sycl_dev )
0 commit comments