Skip to content

Commit 8273054

Browse files
Fix gh-1038 by adding checks to dpt.empty, dpt.zeros
Both functions will now raise ValueError is data type not natively supported by device is requested.
1 parent 1916370 commit 8273054

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,35 @@ def _is_object_with_buffer_protocol(obj):
283283
return False
284284

285285

286+
def _ensure_native_dtype_device_support(dtype, dev) -> None:
287+
"""Check that dtype is natively supported by device.
288+
289+
Arg:
290+
dtype: elemental data-type
291+
dev: :class:`dpctl.SyclDevice`
292+
Return:
293+
None
294+
Raise:
295+
ValueError is device does not natively support this dtype.
296+
"""
297+
if dtype in [dpt.float64, dpt.complex128] and not dev.has_aspect_fp64:
298+
raise ValueError(
299+
f"Device {dev.name} does not provide native support "
300+
"for double-precision floating point type."
301+
)
302+
if (
303+
dtype
304+
in [
305+
dpt.float16,
306+
]
307+
and not dev.has_aspect_fp16
308+
):
309+
raise ValueError(
310+
f"Device {dev.name} does not provide native support "
311+
"for half-precision floating point type."
312+
)
313+
314+
286315
def asarray(
287316
obj,
288317
dtype=None,
@@ -474,6 +503,7 @@ def empty(
474503
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
475504
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
476505
dtype = _get_dtype(dtype, sycl_queue)
506+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
477507
res = dpt.usm_ndarray(
478508
sh,
479509
dtype=dtype,
@@ -651,6 +681,7 @@ def zeros(
651681
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
652682
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
653683
dtype = _get_dtype(dtype, sycl_queue)
684+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
654685
res = dpt.usm_ndarray(
655686
sh,
656687
dtype=dtype,

0 commit comments

Comments
 (0)