Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Improved performance of `tensor.sort` and `tensor.argsort` for short arrays in the range [16, 64] elements [gh-1866](https://github.com/IntelPython/dpctl/pull/1866)

### Fixed
* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877)

### Maintenance

Expand Down
3 changes: 3 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,9 @@ def result_type(*arrays_and_dtypes):
target_dev = d
inspected = True

if not dtypes and weak_dtypes:
dtypes.append(weak_dtypes[0].get())

if not (has_fp16 and has_fp64):
for dt in dtypes:
if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):
Expand Down
21 changes: 21 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# limitations under the License.


import itertools

import numpy as np
import pytest
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
Expand Down Expand Up @@ -1555,3 +1557,22 @@ def test_repeat_0_size():
res = dpt.repeat(x, repetitions, axis=1)
axis_sz = 2 * x.shape[1]
assert res.shape == (0, axis_sz, 0)


def test_result_type_bug_1874():
dts_bool = [True, np.bool_(True)]
dts_ints = [int(1), np.int64(1)]
dts_floats = [float(1), np.float64(1)]
dts_complexes = [complex(1), np.complex128(1)]

# iterate over two categories
for dts1, dts2 in itertools.product(
[dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2
):
res_dts = []
# iterate over Python scalar/NumPy scalar choices within categories
for dt1, dt2 in itertools.product(dts1, dts2):
res_dt = dpt.result_type(dt1, dt2)
res_dts.append(res_dt)
# check that all results are the same
assert res_dts and all(res_dts[0] == el for el in res_dts[1:])
Loading