@@ -283,6 +283,35 @@ def _is_object_with_buffer_protocol(obj):
283
283
return False
284
284
285
285
286
+ def _ensure_native_dtype_device_support (dtype , dev ) -> None :
287
+ """Check that dtype is natively supported by device.
288
+
289
+ Arg:
290
+ dtype: elemental data-type
291
+ dev: :class:`dpctl.SyclDevice`
292
+ Return:
293
+ None
294
+ Raise:
295
+ ValueError is device does not natively support this dtype.
296
+ """
297
+ if dtype in [dpt .float64 , dpt .complex128 ] and not dev .has_aspect_fp64 :
298
+ raise ValueError (
299
+ f"Device { dev .name } does not provide native support "
300
+ "for double-precision floating point type."
301
+ )
302
+ if (
303
+ dtype
304
+ in [
305
+ dpt .float16 ,
306
+ ]
307
+ and not dev .has_aspect_fp16
308
+ ):
309
+ raise ValueError (
310
+ f"Device { dev .name } does not provide native support "
311
+ "for half-precision floating point type."
312
+ )
313
+
314
+
286
315
def asarray (
287
316
obj ,
288
317
dtype = None ,
@@ -474,6 +503,7 @@ def empty(
474
503
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
475
504
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
476
505
dtype = _get_dtype (dtype , sycl_queue )
506
+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
477
507
res = dpt .usm_ndarray (
478
508
sh ,
479
509
dtype = dtype ,
@@ -651,6 +681,7 @@ def zeros(
651
681
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
652
682
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
653
683
dtype = _get_dtype (dtype , sycl_queue )
684
+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
654
685
res = dpt .usm_ndarray (
655
686
sh ,
656
687
dtype = dtype ,
0 commit comments