Skip to content

Commit 3cc8e68

Browse files
committed
Move iota tests to test_basic.py
1 parent 3d93067 commit 3cc8e68

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

tests/tensor/test_basic.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
identity_like,
5959
infer_static_shape,
6060
inverse_permutation,
61+
iota,
6162
join,
6263
make_vector,
6364
mgrid,
@@ -980,6 +981,29 @@ def test_static_output_type(self):
980981
assert eye(1, l, 3).type.shape == (1, None)
981982

982983

984+
def test_iota():
985+
mode = Mode(linker="py", optimizer=None)
986+
np.testing.assert_allclose(
987+
iota((4, 8), 0).eval(mode=mode),
988+
[
989+
[0, 0, 0, 0, 0, 0, 0, 0],
990+
[1, 1, 1, 1, 1, 1, 1, 1],
991+
[2, 2, 2, 2, 2, 2, 2, 2],
992+
[3, 3, 3, 3, 3, 3, 3, 3],
993+
],
994+
)
995+
996+
np.testing.assert_allclose(
997+
iota((4, 8), 1).eval(mode=mode),
998+
[
999+
[0, 1, 2, 3, 4, 5, 6, 7],
1000+
[0, 1, 2, 3, 4, 5, 6, 7],
1001+
[0, 1, 2, 3, 4, 5, 6, 7],
1002+
[0, 1, 2, 3, 4, 5, 6, 7],
1003+
],
1004+
)
1005+
1006+
9831007
class TestTriangle:
9841008
def test_tri(self):
9851009
def check(dtype, N, M_=None, k=0):

tests/tensor/test_einsum.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.graph.op import HasInnerGraph
1111
from pytensor.tensor.basic import moveaxis
1212
from pytensor.tensor.blockwise import Blockwise
13-
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
13+
from pytensor.tensor.einsum import _delta, _general_dot, einsum
1414
from pytensor.tensor.shape import Reshape
1515
from pytensor.tensor.type import tensor
1616

@@ -38,29 +38,6 @@ def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
3838
assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op)
3939

4040

41-
def test_iota():
42-
mode = Mode(linker="py", optimizer=None)
43-
np.testing.assert_allclose(
44-
_iota((4, 8), 0).eval(mode=mode),
45-
[
46-
[0, 0, 0, 0, 0, 0, 0, 0],
47-
[1, 1, 1, 1, 1, 1, 1, 1],
48-
[2, 2, 2, 2, 2, 2, 2, 2],
49-
[3, 3, 3, 3, 3, 3, 3, 3],
50-
],
51-
)
52-
53-
np.testing.assert_allclose(
54-
_iota((4, 8), 1).eval(mode=mode),
55-
[
56-
[0, 1, 2, 3, 4, 5, 6, 7],
57-
[0, 1, 2, 3, 4, 5, 6, 7],
58-
[0, 1, 2, 3, 4, 5, 6, 7],
59-
[0, 1, 2, 3, 4, 5, 6, 7],
60-
],
61-
)
62-
63-
6441
def test_delta():
6542
mode = Mode(linker="py", optimizer=None)
6643
np.testing.assert_allclose(

0 commit comments

Comments
 (0)