Skip to content

Commit 467e44b

Browse files
committed
Add support of weak_type_resolver keyword to DPNPBinaryFunc class
1 parent 0bb6761 commit 467e44b

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
# *****************************************************************************
2626

2727
import dpctl.tensor as dpt
28+
import dpctl.tensor._tensor_impl as dti
29+
import dpctl.tensor._type_utils as dtu
2830
import numpy
2931
from dpctl.tensor._elementwise_common import (
3032
BinaryElementwiseFunc,
@@ -46,6 +48,7 @@
4648
"DPNPReal",
4749
"DPNPRound",
4850
"DPNPUnaryFunc",
51+
"resolve_weak_types_2nd_arg_int",
4952
]
5053

5154

@@ -244,6 +247,14 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
244247
The function is only called when both arguments of the binary
245248
function require casting, e.g. both arguments of
246249
`dpctl.tensor.logaddexp` are arrays with integral data type.
250+
weak_type_resolver : {callable}, optional
251+
Function to influence type promotion behavior for Python scalar types
252+
of this binary function. The function takes 3 arguments:
253+
o1_dtype - Data type or Python scalar type of the first argument
254+
o2_dtype - Data type or Python scalar type of of the second argument
255+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
256+
evaluation is carried out.
257+
One of `o1_dtype` and `o2_dtype` must be a ``dtype`` instance.
247258
"""
248259

249260
def __init__(
@@ -256,6 +267,7 @@ def __init__(
256267
mkl_impl_fn=None,
257268
binary_inplace_fn=None,
258269
acceptance_fn=None,
270+
weak_type_resolver=None,
259271
):
260272
def _call_func(src1, src2, dst, sycl_queue, depends=None):
261273
"""
@@ -281,6 +293,7 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
281293
docs,
282294
binary_inplace_fn,
283295
acceptance_fn=acceptance_fn,
296+
weak_type_resolver=weak_type_resolver,
284297
)
285298
self.__name__ = "DPNPBinaryFunc"
286299

@@ -606,3 +619,22 @@ def acceptance_fn_subtract(
606619
)
607620
else:
608621
return True
622+
623+
624+
def resolve_weak_types_2nd_arg_int(o1_dtype, o2_dtype, sycl_dev):
625+
"""
626+
The second weak dtype has to be upcasting up to default integer dtype
627+
for a SYCL device where it is possible.
628+
For other cases the default weak types resolving will be applied.
629+
630+
"""
631+
632+
if dtu._is_weak_dtype(o2_dtype):
633+
o1_kind_num = dtu._strong_dtype_num_kind(o1_dtype)
634+
o2_kind_num = dtu._weak_type_num_kind(o2_dtype)
635+
if o2_kind_num < o1_kind_num:
636+
if isinstance(o2_dtype, (dtu.WeakBooleanType, dtu.WeakIntegralType)):
637+
print()
638+
print(o1_dtype, dpt.dtype(dti.default_device_int_type(sycl_dev)))
639+
return o1_dtype, dpt.dtype(dti.default_device_int_type(sycl_dev))
640+
return dtu._resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
acceptance_fn_positive,
7171
acceptance_fn_sign,
7272
acceptance_fn_subtract,
73+
resolve_weak_types_2nd_arg_int,
7374
)
7475
from .dpnp_array import dpnp_array
7576
from .dpnp_utils import call_origin, get_usm_allocations
@@ -2486,6 +2487,7 @@ def gradient(f, *varargs, axis=None, edge_order=1):
24862487
ufi._ldexp_result_type,
24872488
ufi._ldexp,
24882489
_LDEXP_DOCSTRING,
2490+
weak_type_resolver=resolve_weak_types_2nd_arg_int,
24892491
)
24902492

24912493

0 commit comments

Comments
 (0)