Skip to content

Commit 8fc2a84

Browse files
Fixed arange(3.) to not produce float64 array on HW with fp64 support
1 parent da1bf68 commit 8fc2a84

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,14 +453,12 @@ def empty(
453453
return res
454454

455455

456-
def _coerce_and_infer_dt(*args, dt):
456+
def _coerce_and_infer_dt(*args, dt, sycl_queue):
457457
"Deduce arange type from sequence spec"
458458
nd, seq_dt, d = _array_info_sequence(args)
459459
if d != _host_set or nd != (len(args),):
460460
raise ValueError("start, stop and step must be Python scalars")
461-
if dt is None:
462-
dt = seq_dt
463-
dt = np.dtype(dt)
461+
dt = _get_dtype(dt, sycl_queue, ref_type=seq_dt)
464462
if np.issubdtype(dt, np.integer):
465463
return tuple(int(v) for v in args), dt
466464
elif np.issubdtype(dt, np.floating):
@@ -526,20 +524,22 @@ def arange(
526524
if stop is None:
527525
stop = start
528526
start = 0
527+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
528+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
529529
(
530530
start,
531531
stop,
532532
step,
533-
), dt = _coerce_and_infer_dt(start, stop, step, dt=dtype)
533+
), dt = _coerce_and_infer_dt(
534+
start, stop, step, dt=dtype, sycl_queue=sycl_queue
535+
)
534536
try:
535537
tmp = _get_arange_length(start, stop, step)
536538
sh = int(tmp)
537539
if sh < 0:
538540
sh = 0
539541
except TypeError:
540542
sh = 0
541-
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
542-
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
543543
res = dpt.usm_ndarray(
544544
(sh,),
545545
dtype=dt,

0 commit comments

Comments
 (0)