Skip to content

Commit cab399d

Browse files
Fixed a bug in dpt.arange
A bug would manifest itself if the starting value is outside of range of the array data type. Fixed it by coercing the starting value to the data type, like is currently done for the increment. The test was modified to cover such an input.
1 parent 2522187 commit cab399d

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

dpctl/tensor/_ctors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,8 @@ def arange(
547547
)
548548
_step = (start + step) - start
549549
_step = dt.type(_step)
550-
hev, _ = ti._linspace_step(start, _step, res, sycl_queue)
550+
_start = dt.type(start)
551+
hev, _ = ti._linspace_step(_start, _step, res, sycl_queue)
551552
hev.wait()
552553
return res
553554

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -985,11 +985,13 @@ def test_arange(dt):
985985
elif np.issubdtype(dt, np.complexfloating):
986986
assert complex(X[47]) == 47.0 + 0.0j
987987

988-
X1 = dpt.arange(4, dtype=dt, sycl_queue=q)
989-
assert X1.shape == (4,)
988+
# choose size larger than maximal value that u1/u2 can accomodate
989+
sz = int(np.iinfo(np.int16).max) + 1
990+
X1 = dpt.arange(sz, dtype=dt, sycl_queue=q)
991+
assert X1.shape == (sz,)
990992

991-
X2 = dpt.arange(4, 0, -1, dtype=dt, sycl_queue=q)
992-
assert X2.shape == (4,)
993+
X2 = dpt.arange(sz, 0, -1, dtype=dt, sycl_queue=q)
994+
assert X2.shape == (sz,)
993995

994996

995997
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)