diff --git a/CHANGELOG.md b/CHANGELOG.md index 97c06affac..c7092ad4e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +* Change to `tensor.result_type` to raise `ValueError` unless at least one argument is a `tensor.usm_ndarray` instance [gh-1876](https://github.com/IntelPython/dpctl/pull/1876) + ### Maintenance * Update black version used in Python code style workflow [gh-1828](https://github.com/IntelPython/dpctl/pull/1828) diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py index 4a6c32c3f4..545aa21fb3 100644 --- a/dpctl/tensor/_search_functions.py +++ b/dpctl/tensor/_search_functions.py @@ -35,6 +35,7 @@ _all_data_types, _can_cast, _is_weak_dtype, + _result_type_fn_impl, _strong_dtype_num_kind, _to_device_supported_dtype, _weak_type_num_kind, @@ -95,7 +96,13 @@ def _resolve_two_weak_types(o1_dtype, o2_dtype, dev): def _where_result_type(dt1, dt2, dev): - res_dtype = dpt.result_type(dt1, dt2) + res_dtype = _result_type_fn_impl( + ( + dt1, + dt2, + ), + sycl_device=dev, + ) fp16 = dev.has_aspect_fp16 fp64 = dev.has_aspect_fp64 diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 5defd154df..0b23b8155d 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -21,6 +21,13 @@ import dpctl.tensor._tensor_impl as ti +def _supported_dtype(dtypes): + for dtype in dtypes: + if dtype.char not in "?bBhHiIlLqQefdFD": + raise ValueError(f"Dpctl doesn't support dtype {dtype}.") + return True + + def _all_data_types(_fp16, _fp64): _non_fp_types = [ dpt.bool, @@ -708,26 +715,11 @@ def can_cast(from_, to, /, *, casting="safe") -> bool: return _can_cast(dtype_from, dtype_to, True, True, casting=casting) -def result_type(*arrays_and_dtypes): - """ - result_type(*arrays_and_dtypes) - - Returns the dtype that results from applying the Type Promotion Rules to \ - the arguments. - - Args: - arrays_and_dtypes (Union[usm_ndarray, dtype]): - An arbitrary length sequence of usm_ndarray objects or dtypes. - - Returns: - dtype: - The dtype resulting from an operation involving the - input arrays and dtypes. - """ +def _result_type_fn_impl(arrays_and_dtypes_tuple, sycl_device=None): dtypes = [] - devices = [] + devices = [] if sycl_device is None else [sycl_device] weak_dtypes = [] - for arg_i in arrays_and_dtypes: + for arg_i in arrays_and_dtypes_tuple: if isinstance(arg_i, dpt.usm_ndarray): devices.append(arg_i.sycl_device) dtypes.append(arg_i.dtype) @@ -766,6 +758,10 @@ def result_type(*arrays_and_dtypes): has_fp64 = d.has_aspect_fp64 target_dev = d inspected = True + else: + raise ValueError( + "At least one argument must have type `dpctl.tensor.usm_ndarray`" + ) if not (has_fp16 and has_fp64): for dt in dtypes: @@ -788,6 +784,25 @@ def result_type(*arrays_and_dtypes): return res_dt +def result_type(*arrays_and_dtypes): + """ + result_type(*arrays_and_dtypes) + + Returns the dtype that results from applying the Type Promotion Rules to \ + the arguments. + + Args: + arrays_and_dtypes (Union[usm_ndarray, dtype]): + An arbitrary length sequence of usm_ndarray objects or dtypes. + + Returns: + dtype: + The dtype resulting from an operation involving the + input arrays and dtypes. + """ + return _result_type_fn_impl(arrays_and_dtypes) + + def iinfo(dtype, /): """iinfo(dtype) @@ -855,13 +870,6 @@ def finfo(dtype, /): return finfo_object(dtype) -def _supported_dtype(dtypes): - for dtype in dtypes: - if dtype.char not in "?bBhHiIlLqQefdFD": - raise ValueError(f"Dpctl doesn't support dtype {dtype}.") - return True - - def isdtype(dtype, kind): """isdtype(dtype, kind) diff --git a/dpctl/tests/helper/_helper.py b/dpctl/tests/helper/_helper.py index 654c197ddc..962e5f1cb6 100644 --- a/dpctl/tests/helper/_helper.py +++ b/dpctl/tests/helper/_helper.py @@ -55,11 +55,14 @@ def skip_if_dtype_not_supported(dt, q_or_dev): import dpctl.tensor as dpt dt = dpt.dtype(dt) - if type(q_or_dev) is dpctl.SyclQueue: + if isinstance(q_or_dev, dpctl.SyclQueue): dev = q_or_dev.sycl_device - elif type(q_or_dev) is dpctl.SyclDevice: - dev = q_or_dev else: + dev = q_or_dev + + if not hasattr(dev, "has_aspect_fp16") or not hasattr( + dev, "has_aspect_fp64" + ): raise TypeError( "Expected dpctl.SyclQueue or dpctl.SyclDevice, " f"got {type(q_or_dev)}" diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 882a001827..3e29c88153 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -998,11 +998,6 @@ def test_result_type(): assert dpt.result_type(*X) == np.result_type(*X_np) - X = [dpt.int32, "int64", 2] - X_np = [np.int32, "int64", 2] - - assert dpt.result_type(*X) == np.result_type(*X_np) - X = [usm_ar, dpt.int32, "int64", 2.0] X_np = [np_ar, np.int32, "int64", 2.0] @@ -1013,6 +1008,10 @@ def test_result_type(): assert dpt.result_type(*X).kind == np.result_type(*X_np).kind + X = [dpt.int32, "int64", 2] + with pytest.raises(ValueError): + dpt.result_type(*X) + def test_swapaxes_1d(): get_queue_or_skip() diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index a646f4cde1..350cb5af86 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -48,6 +48,7 @@ class mock_device: def __init__(self, fp16, fp64): + self.name = "Mock device" self.has_aspect_fp16 = fp16 self.has_aspect_fp64 = fp64 @@ -101,14 +102,17 @@ def test_where_result_types(dt1, dt2, fp16, fp64): dev = mock_device(fp16, fp64) dt1 = dpt.dtype(dt1) + skip_if_dtype_not_supported(dt1, dev) dt2 = dpt.dtype(dt2) + skip_if_dtype_not_supported(dt2, dev) + res_t = _where_result_type(dt1, dt2, dev) if fp16 and fp64: - assert res_t == dpt.result_type(dt1, dt2) + assert res_t == np.result_type(dt1, dt2) else: if res_t: - assert res_t.kind == dpt.result_type(dt1, dt2).kind + assert res_t.kind == np.result_type(dt1, dt2).kind else: # some illegal cases are covered above, but # this guarantees that _where_result_type