Skip to content

Commit 4c9d413

Browse files
Fixed error in arange shape computation
dpt.arange(4, 0, -1) was producing array of wrong size. Bug fixed, test to be added.
1 parent ae86ab5 commit 4c9d413

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

dpctl/tensor/_ctors.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ def empty(
454454

455455

456456
def _coerce_and_infer_dt(*args, dt):
457+
"Deduce arange type from sequence spec"
457458
nd, seq_dt, d = _array_info_sequence(args)
458459
if d != _host_set or nd != (len(args),):
459460
raise ValueError("start, stop and step must be Python scalars")
@@ -470,6 +471,24 @@ def _coerce_and_infer_dt(*args, dt):
470471
raise ValueError(f"Data type {dt} is not supported")
471472

472473

474+
def _get_arange_length(start, stop, step):
475+
"Compute length of arange sequence"
476+
span = stop - start
477+
if type(step) in [int, float] and type(span) in [int, float]:
478+
offset = -1 if step > 0 else 1
479+
tmp = 1 + (span + offset) / step
480+
return tmp
481+
tmp = span / step
482+
if type(tmp) is complex and tmp.imag == 0:
483+
tmp = tmp.real
484+
else:
485+
return tmp
486+
k = int(tmp)
487+
if k > 0 and float(k) < tmp:
488+
tmp = tmp + 1
489+
return tmp
490+
491+
473492
def arange(
474493
start,
475494
/,
@@ -511,9 +530,7 @@ def arange(
511530
step,
512531
), dt = _coerce_and_infer_dt(start, stop, step, dt=dtype)
513532
try:
514-
tmp = 1 + (stop - start - 1) / step
515-
if type(tmp) is complex and tmp.imag == 0.0:
516-
tmp = tmp.real
533+
tmp = _get_arange_length(start, stop, step)
517534
sh = int(tmp)
518535
if sh < 0:
519536
sh = 0
@@ -529,6 +546,7 @@ def arange(
529546
buffer_ctor_kwargs={"queue": sycl_queue},
530547
)
531548
_step = (start + step) - start
549+
_step = dt.type(_step)
532550
hev, _ = ti._linspace_step(start, _step, res, sycl_queue)
533551
hev.wait()
534552
return res

0 commit comments

Comments
 (0)