Skip to content

Commit 4b252e4

Browse files
committed
Int32 is too small for geomspace test
1 parent 8f79b16 commit 4b252e4

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/tensor/test_extra_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,11 +1301,13 @@ def test_broadcast_arrays():
13011301
["linspace", "logspace", "geomspace"],
13021302
ids=["linspace", "logspace", "geomspace"],
13031303
)
1304-
@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
1304+
@pytest.mark.parametrize(
1305+
"dtype", [None, "int64", "floatX"], ids=[None, "int64", "floatX"]
1306+
)
13051307
@pytest.mark.parametrize(
13061308
"start, stop, num_samples, endpoint, axis",
13071309
[
1308-
(1, 10, 50, True, 0),
1310+
(1, 10, 40, True, 0),
13091311
(1, 10, 1, True, 0),
13101312
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 0),
13111313
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 1),
@@ -1317,7 +1319,7 @@ def test_broadcast_arrays():
13171319
def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
13181320
pt_func = getattr(pt, op)
13191321
np_func = getattr(np, op)
1320-
dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
1322+
dtype = dtype if dtype != "floatX" else config.floatX
13211323
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)
13221324

13231325
numpy_res = np_func(

0 commit comments

Comments
 (0)