@@ -453,14 +453,12 @@ def empty(
453
453
return res
454
454
455
455
456
- def _coerce_and_infer_dt (* args , dt ):
456
+ def _coerce_and_infer_dt (* args , dt , sycl_queue ):
457
457
"Deduce arange type from sequence spec"
458
458
nd , seq_dt , d = _array_info_sequence (args )
459
459
if d != _host_set or nd != (len (args ),):
460
460
raise ValueError ("start, stop and step must be Python scalars" )
461
- if dt is None :
462
- dt = seq_dt
463
- dt = np .dtype (dt )
461
+ dt = _get_dtype (dt , sycl_queue , ref_type = seq_dt )
464
462
if np .issubdtype (dt , np .integer ):
465
463
return tuple (int (v ) for v in args ), dt
466
464
elif np .issubdtype (dt , np .floating ):
@@ -526,20 +524,22 @@ def arange(
526
524
if stop is None :
527
525
stop = start
528
526
start = 0
527
+ dpctl .utils .validate_usm_type (usm_type , allow_none = False )
528
+ sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
529
529
(
530
530
start ,
531
531
stop ,
532
532
step ,
533
- ), dt = _coerce_and_infer_dt (start , stop , step , dt = dtype )
533
+ ), dt = _coerce_and_infer_dt (
534
+ start , stop , step , dt = dtype , sycl_queue = sycl_queue
535
+ )
534
536
try :
535
537
tmp = _get_arange_length (start , stop , step )
536
538
sh = int (tmp )
537
539
if sh < 0 :
538
540
sh = 0
539
541
except TypeError :
540
542
sh = 0
541
- dpctl .utils .validate_usm_type (usm_type , allow_none = False )
542
- sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
543
543
res = dpt .usm_ndarray (
544
544
(sh ,),
545
545
dtype = dt ,
0 commit comments