Skip to content

Commit 75ad9ce

Browse files
author
Vahid Tavanashad
committed
fix an issue with result_type function
1 parent 9eb8f03 commit 75ad9ce

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

dpctl/tensor/_type_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,14 +773,14 @@ def result_type(*arrays_and_dtypes):
773773
raise ValueError(
774774
f"Argument {dt} is not supported by the device"
775775
)
776-
res_dt = np.result_type(*dtypes)
776+
res_dt = np.result_type(*dtypes) if dtypes else None
777777
res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
778778
for wdt in weak_dtypes:
779779
pair = _resolve_weak_types(wdt, res_dt, target_dev)
780780
res_dt = np.result_type(*pair)
781781
res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
782782
else:
783-
res_dt = np.result_type(*dtypes)
783+
res_dt = np.result_type(*dtypes) if dtypes else None
784784
if weak_dtypes:
785785
weak_dt_obj = [wdt.get() for wdt in weak_dtypes]
786786
res_dt = np.result_type(res_dt, *weak_dt_obj)

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,12 @@ def test_result_type():
10131013

10141014
assert dpt.result_type(*X).kind == np.result_type(*X_np).kind
10151015

1016+
dtype = np.dtype(np.float64)
1017+
X = [dtype.type(3), dtype.type(3)]
1018+
X_np = X
1019+
1020+
assert dpt.result_type(*X) == np.result_type(*X_np)
1021+
10161022

10171023
def test_swapaxes_1d():
10181024
get_queue_or_skip()

0 commit comments

Comments
 (0)