diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 5defd154df..9ca211336a 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -773,14 +773,14 @@ def result_type(*arrays_and_dtypes): raise ValueError( f"Argument {dt} is not supported by the device" ) - res_dt = np.result_type(*dtypes) + res_dt = np.result_type(*dtypes) if dtypes else None res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) for wdt in weak_dtypes: pair = _resolve_weak_types(wdt, res_dt, target_dev) res_dt = np.result_type(*pair) res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) else: - res_dt = np.result_type(*dtypes) + res_dt = np.result_type(*dtypes) if dtypes else None if weak_dtypes: weak_dt_obj = [wdt.get() for wdt in weak_dtypes] res_dt = np.result_type(res_dt, *weak_dt_obj) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 882a001827..f9afcd6f0d 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -1013,6 +1013,12 @@ def test_result_type(): assert dpt.result_type(*X).kind == np.result_type(*X_np).kind + dtype = np.dtype(np.float64) + X = [dtype.type(3), dtype.type(3)] + X_np = X + + assert dpt.result_type(*X) == np.result_type(*X_np) + def test_swapaxes_1d(): get_queue_or_skip()