diff --git a/CHANGELOG.md b/CHANGELOG.md index 74b3569030..cb45db859d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.18.2] - Nov. XX, 2024 + +### Fixed +* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1904](https://github.com/IntelPython/dpctl/pull/1904) + ## [0.18.1] - Oct. 11, 2024 ### Changed diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 5defd154df..f279052f94 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -767,6 +767,9 @@ def result_type(*arrays_and_dtypes): target_dev = d inspected = True + if not dtypes and weak_dtypes: + dtypes.append(weak_dtypes[0].get()) + if not (has_fp16 and has_fp64): for dt in dtypes: if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64): diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 17262e2141..e9b121f8a1 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -15,6 +15,8 @@ # limitations under the License. +import itertools + import numpy as np import pytest from numpy.testing import assert_, assert_array_equal, assert_raises_regex @@ -1531,3 +1533,26 @@ def test_repeat_0_size(): res = dpt.repeat(x, repetitions, axis=1) axis_sz = 2 * x.shape[1] assert res.shape == (0, axis_sz, 0) + + +def test_result_type_bug_1874(): + py_sc = True + np_sc = np.asarray([py_sc])[0] + dts_bool = [py_sc, np_sc] + py_sc = int(1) + np_sc = np.asarray([py_sc])[0] + dts_ints = [py_sc, np_sc] + dts_floats = [float(1), np.float64(1)] + dts_complexes = [complex(1), np.complex128(1)] + + # iterate over two categories + for dts1, dts2 in itertools.product( + [dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2 + ): + res_dts = [] + # iterate over Python scalar/NumPy scalar choices within categories + for dt1, dt2 in itertools.product(dts1, dts2): + res_dt = dpt.result_type(dt1, dt2) + res_dts.append(res_dt) + # check that all results are the same + assert res_dts and all(res_dts[0] == el for el in res_dts[1:])