Skip to content

Commit 94d8624

Browse files
dpt.result_type changed to raise unless at least one argument is usm_ndarray
Implementation of dpt.where was calling `dpt.result_type` with just dtypes, although it has the device available. Implementation of result_type was factored out into a helper function which can accommodate know device, and reused by `dpt.result_type` and `_where_result_type` function. Test files were tweaked to avoid calling `dpt.result_type` with just dtypes.
1 parent 9eb8f03 commit 94d8624

File tree

5 files changed

+58
-36
lines changed

5 files changed

+58
-36
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_all_data_types,
3636
_can_cast,
3737
_is_weak_dtype,
38+
_result_type_fn_impl,
3839
_strong_dtype_num_kind,
3940
_to_device_supported_dtype,
4041
_weak_type_num_kind,
@@ -95,7 +96,14 @@ def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):
9596

9697

9798
def _where_result_type(dt1, dt2, dev):
98-
res_dtype = dpt.result_type(dt1, dt2)
99+
print(dev)
100+
res_dtype = _result_type_fn_impl(
101+
(
102+
dt1,
103+
dt2,
104+
),
105+
sycl_device=dev,
106+
)
99107
fp16 = dev.has_aspect_fp16
100108
fp64 = dev.has_aspect_fp64
101109

dpctl/tensor/_type_utils.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
import dpctl.tensor._tensor_impl as ti
2222

2323

24+
def _supported_dtype(dtypes):
25+
for dtype in dtypes:
26+
if dtype.char not in "?bBhHiIlLqQefdFD":
27+
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
28+
return True
29+
30+
2431
def _all_data_types(_fp16, _fp64):
2532
_non_fp_types = [
2633
dpt.bool,
@@ -708,26 +715,11 @@ def can_cast(from_, to, /, *, casting="safe") -> bool:
708715
return _can_cast(dtype_from, dtype_to, True, True, casting=casting)
709716

710717

711-
def result_type(*arrays_and_dtypes):
712-
"""
713-
result_type(*arrays_and_dtypes)
714-
715-
Returns the dtype that results from applying the Type Promotion Rules to \
716-
the arguments.
717-
718-
Args:
719-
arrays_and_dtypes (Union[usm_ndarray, dtype]):
720-
An arbitrary length sequence of usm_ndarray objects or dtypes.
721-
722-
Returns:
723-
dtype:
724-
The dtype resulting from an operation involving the
725-
input arrays and dtypes.
726-
"""
718+
def _result_type_fn_impl(arrays_and_dtypes_tuple, sycl_device=None):
727719
dtypes = []
728-
devices = []
720+
devices = [] if sycl_device is None else [sycl_device]
729721
weak_dtypes = []
730-
for arg_i in arrays_and_dtypes:
722+
for arg_i in arrays_and_dtypes_tuple:
731723
if isinstance(arg_i, dpt.usm_ndarray):
732724
devices.append(arg_i.sycl_device)
733725
dtypes.append(arg_i.dtype)
@@ -766,6 +758,10 @@ def result_type(*arrays_and_dtypes):
766758
has_fp64 = d.has_aspect_fp64
767759
target_dev = d
768760
inspected = True
761+
else:
762+
raise ValueError(
763+
"At least only argument must have type `dpctl.tensor.usm_ndarray`"
764+
)
769765

770766
if not (has_fp16 and has_fp64):
771767
for dt in dtypes:
@@ -788,6 +784,25 @@ def result_type(*arrays_and_dtypes):
788784
return res_dt
789785

790786

787+
def result_type(*arrays_and_dtypes):
788+
"""
789+
result_type(*arrays_and_dtypes)
790+
791+
Returns the dtype that results from applying the Type Promotion Rules to \
792+
the arguments.
793+
794+
Args:
795+
arrays_and_dtypes (Union[usm_ndarray, dtype]):
796+
An arbitrary length sequence of usm_ndarray objects or dtypes.
797+
798+
Returns:
799+
dtype:
800+
The dtype resulting from an operation involving the
801+
input arrays and dtypes.
802+
"""
803+
return _result_type_fn_impl(arrays_and_dtypes)
804+
805+
791806
def iinfo(dtype, /):
792807
"""iinfo(dtype)
793808
@@ -855,13 +870,6 @@ def finfo(dtype, /):
855870
return finfo_object(dtype)
856871

857872

858-
def _supported_dtype(dtypes):
859-
for dtype in dtypes:
860-
if dtype.char not in "?bBhHiIlLqQefdFD":
861-
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
862-
return True
863-
864-
865873
def isdtype(dtype, kind):
866874
"""isdtype(dtype, kind)
867875

dpctl/tests/helper/_helper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,14 @@ def skip_if_dtype_not_supported(dt, q_or_dev):
5555
import dpctl.tensor as dpt
5656

5757
dt = dpt.dtype(dt)
58-
if type(q_or_dev) is dpctl.SyclQueue:
58+
if isinstance(q_or_dev, dpctl.SyclQueue):
5959
dev = q_or_dev.sycl_device
60-
elif type(q_or_dev) is dpctl.SyclDevice:
61-
dev = q_or_dev
6260
else:
61+
dev = q_or_dev
62+
63+
if not hasattr(dev, "has_aspect_fp16") or not hasattr(
64+
dev, "has_aspect_fp64"
65+
):
6366
raise TypeError(
6467
"Expected dpctl.SyclQueue or dpctl.SyclDevice, "
6568
f"got {type(q_or_dev)}"

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -998,11 +998,6 @@ def test_result_type():
998998

999999
assert dpt.result_type(*X) == np.result_type(*X_np)
10001000

1001-
X = [dpt.int32, "int64", 2]
1002-
X_np = [np.int32, "int64", 2]
1003-
1004-
assert dpt.result_type(*X) == np.result_type(*X_np)
1005-
10061001
X = [usm_ar, dpt.int32, "int64", 2.0]
10071002
X_np = [np_ar, np.int32, "int64", 2.0]
10081003

@@ -1013,6 +1008,10 @@ def test_result_type():
10131008

10141009
assert dpt.result_type(*X).kind == np.result_type(*X_np).kind
10151010

1011+
X = [dpt.int32, "int64", 2]
1012+
with pytest.raises(ValueError):
1013+
dpt.result_type(*X)
1014+
10161015

10171016
def test_swapaxes_1d():
10181017
get_queue_or_skip()

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
class mock_device:
5050
def __init__(self, fp16, fp64):
51+
self.name = "Mock device"
5152
self.has_aspect_fp16 = fp16
5253
self.has_aspect_fp64 = fp64
5354

@@ -101,14 +102,17 @@ def test_where_result_types(dt1, dt2, fp16, fp64):
101102
dev = mock_device(fp16, fp64)
102103

103104
dt1 = dpt.dtype(dt1)
105+
skip_if_dtype_not_supported(dt1, dev)
104106
dt2 = dpt.dtype(dt2)
107+
skip_if_dtype_not_supported(dt2, dev)
108+
105109
res_t = _where_result_type(dt1, dt2, dev)
106110

107111
if fp16 and fp64:
108-
assert res_t == dpt.result_type(dt1, dt2)
112+
assert res_t == np.result_type(dt1, dt2)
109113
else:
110114
if res_t:
111-
assert res_t.kind == dpt.result_type(dt1, dt2).kind
115+
assert res_t.kind == np.result_type(dt1, dt2).kind
112116
else:
113117
# some illegal cases are covered above, but
114118
# this guarantees that _where_result_type

0 commit comments

Comments
 (0)