Skip to content

Commit 6e59293

Browse files
Constructor usm_ndarray raises if dtype is not native for device
The default data type can hence raise.
1 parent 8273054 commit 6e59293

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ cdef class usm_ndarray:
178178
cdef Py_ssize_t _offset = offset
179179
cdef Py_ssize_t ary_min_displacement = 0
180180
cdef Py_ssize_t ary_max_displacement = 0
181+
cdef bint is_fp64 = False
182+
cdef bint is_fp16 = False
181183

182184
self._reset()
183185
if (not isinstance(shape, (list, tuple))
@@ -253,6 +255,16 @@ cdef class usm_ndarray:
253255
self._cleanup()
254256
raise ValueError(("buffer='{}' can not accomodate "
255257
"the requested array.").format(buffer))
258+
is_fp64 = (typenum == UAR_DOUBLE or typenum == UAR_CDOUBLE)
259+
is_fp16 = (typenum == UAR_HALF)
260+
if (is_fp64 or is_fp16):
261+
if ((is_fp64 and not _buffer.sycl_device.has_aspect_fp64) or
262+
(is_fp16 and not _buffer.sycl_device.has_aspect_fp16)
263+
):
264+
raise ValueError(
265+
f"Device {_buffer.sycl_device.name} does"
266+
f" not support {dtype} natively."
267+
)
256268
self.base_ = _buffer
257269
self.data_ = (<char *> (<size_t> _buffer._pointer)) + itemsize * _offset
258270
self.shape_ = shape_ptr

0 commit comments

Comments
 (0)