Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dpctl/tensor/_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_all_data_types,
_can_cast,
_is_weak_dtype,
_result_type_fn_impl,
_strong_dtype_num_kind,
_to_device_supported_dtype,
_weak_type_num_kind,
Expand Down Expand Up @@ -95,7 +96,13 @@ def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):


def _where_result_type(dt1, dt2, dev):
res_dtype = dpt.result_type(dt1, dt2)
res_dtype = _result_type_fn_impl(
(
dt1,
dt2,
),
sycl_device=dev,
)
fp16 = dev.has_aspect_fp16
fp64 = dev.has_aspect_fp64

Expand Down
58 changes: 33 additions & 25 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
import dpctl.tensor._tensor_impl as ti


def _supported_dtype(dtypes):
for dtype in dtypes:
if dtype.char not in "?bBhHiIlLqQefdFD":
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
return True


def _all_data_types(_fp16, _fp64):
_non_fp_types = [
dpt.bool,
Expand Down Expand Up @@ -708,26 +715,11 @@ def can_cast(from_, to, /, *, casting="safe") -> bool:
return _can_cast(dtype_from, dtype_to, True, True, casting=casting)


def result_type(*arrays_and_dtypes):
"""
result_type(*arrays_and_dtypes)

Returns the dtype that results from applying the Type Promotion Rules to \
the arguments.

Args:
arrays_and_dtypes (Union[usm_ndarray, dtype]):
An arbitrary length sequence of usm_ndarray objects or dtypes.

Returns:
dtype:
The dtype resulting from an operation involving the
input arrays and dtypes.
"""
def _result_type_fn_impl(arrays_and_dtypes_tuple, sycl_device=None):
dtypes = []
devices = []
devices = [] if sycl_device is None else [sycl_device]
weak_dtypes = []
for arg_i in arrays_and_dtypes:
for arg_i in arrays_and_dtypes_tuple:
if isinstance(arg_i, dpt.usm_ndarray):
devices.append(arg_i.sycl_device)
dtypes.append(arg_i.dtype)
Expand Down Expand Up @@ -766,6 +758,10 @@ def result_type(*arrays_and_dtypes):
has_fp64 = d.has_aspect_fp64
target_dev = d
inspected = True
else:
raise ValueError(
"At least only argument must have type `dpctl.tensor.usm_ndarray`"
)

if not (has_fp16 and has_fp64):
for dt in dtypes:
Expand All @@ -788,6 +784,25 @@ def result_type(*arrays_and_dtypes):
return res_dt


def result_type(*arrays_and_dtypes):
"""
result_type(*arrays_and_dtypes)

Returns the dtype that results from applying the Type Promotion Rules to \
the arguments.

Args:
arrays_and_dtypes (Union[usm_ndarray, dtype]):
An arbitrary length sequence of usm_ndarray objects or dtypes.

Returns:
dtype:
The dtype resulting from an operation involving the
input arrays and dtypes.
"""
return _result_type_fn_impl(arrays_and_dtypes)


def iinfo(dtype, /):
"""iinfo(dtype)

Expand Down Expand Up @@ -855,13 +870,6 @@ def finfo(dtype, /):
return finfo_object(dtype)


def _supported_dtype(dtypes):
for dtype in dtypes:
if dtype.char not in "?bBhHiIlLqQefdFD":
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
return True


def isdtype(dtype, kind):
"""isdtype(dtype, kind)

Expand Down
9 changes: 6 additions & 3 deletions dpctl/tests/helper/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ def skip_if_dtype_not_supported(dt, q_or_dev):
import dpctl.tensor as dpt

dt = dpt.dtype(dt)
if type(q_or_dev) is dpctl.SyclQueue:
if isinstance(q_or_dev, dpctl.SyclQueue):
dev = q_or_dev.sycl_device
elif type(q_or_dev) is dpctl.SyclDevice:
dev = q_or_dev
else:
dev = q_or_dev

if not hasattr(dev, "has_aspect_fp16") or not hasattr(
dev, "has_aspect_fp64"
):
raise TypeError(
"Expected dpctl.SyclQueue or dpctl.SyclDevice, "
f"got {type(q_or_dev)}"
Expand Down
9 changes: 4 additions & 5 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,11 +998,6 @@ def test_result_type():

assert dpt.result_type(*X) == np.result_type(*X_np)

X = [dpt.int32, "int64", 2]
X_np = [np.int32, "int64", 2]

assert dpt.result_type(*X) == np.result_type(*X_np)

X = [usm_ar, dpt.int32, "int64", 2.0]
X_np = [np_ar, np.int32, "int64", 2.0]

Expand All @@ -1013,6 +1008,10 @@ def test_result_type():

assert dpt.result_type(*X).kind == np.result_type(*X_np).kind

X = [dpt.int32, "int64", 2]
with pytest.raises(ValueError):
dpt.result_type(*X)


def test_swapaxes_1d():
get_queue_or_skip()
Expand Down
8 changes: 6 additions & 2 deletions dpctl/tests/test_usm_ndarray_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

class mock_device:
def __init__(self, fp16, fp64):
self.name = "Mock device"
self.has_aspect_fp16 = fp16
self.has_aspect_fp64 = fp64

Expand Down Expand Up @@ -101,14 +102,17 @@ def test_where_result_types(dt1, dt2, fp16, fp64):
dev = mock_device(fp16, fp64)

dt1 = dpt.dtype(dt1)
skip_if_dtype_not_supported(dt1, dev)
dt2 = dpt.dtype(dt2)
skip_if_dtype_not_supported(dt2, dev)

res_t = _where_result_type(dt1, dt2, dev)

if fp16 and fp64:
assert res_t == dpt.result_type(dt1, dt2)
assert res_t == np.result_type(dt1, dt2)
else:
if res_t:
assert res_t.kind == dpt.result_type(dt1, dt2).kind
assert res_t.kind == np.result_type(dt1, dt2).kind
else:
# some illegal cases are covered above, but
# this guarantees that _where_result_type
Expand Down
Loading