We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6cd90ee commit ae0132fCopy full SHA for ae0132f
tests/tensor/test_extra_ops.py
@@ -1301,7 +1301,9 @@ def test_broadcast_arrays():
1301
["linspace", "logspace", "geomspace"],
1302
ids=["linspace", "logspace", "geomspace"],
1303
)
1304
-@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
+@pytest.mark.parametrize(
1305
+ "dtype", [None, "int64", "floatX"], ids=[None, "int64", "floatX"]
1306
+)
1307
@pytest.mark.parametrize(
1308
"start, stop, num_samples, endpoint, axis",
1309
[
@@ -1317,7 +1319,7 @@ def test_broadcast_arrays():
1317
1319
def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
1318
1320
pt_func = getattr(pt, op)
1321
np_func = getattr(np, op)
- dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
1322
+ dtype = dtype if dtype != "floatX" else config.floatX
1323
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)
1324
1325
numpy_res = np_func(
0 commit comments