Skip to content

Commit ccd812d

Browse files
committed
Simplify test
1 parent a8de454 commit ccd812d

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tests/tensor/test_basic.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,7 +2188,12 @@ def test_join_negative_axis_rewrite(self):
21882188
# Create join with negative axis
21892189
s = join(-1, a, b)
21902190

2191-
assert s.owner.op.axis == 1
2191+
# Get the actual Join op node from the graph
2192+
f = pytensor.function([], [s], mode=self.mode)
2193+
2194+
# Directly access the Join node from the output's owner
2195+
join_node = f.maker.fgraph.outputs[0].owner
2196+
assert isinstance(join_node.op, Join), "Expected output node to be a Join op"
21922197

21932198
# Check that the axis input has been converted to a constant with value 1 (not -1)
21942199
axis_input = join_node.inputs[0]
@@ -2201,11 +2206,8 @@ def test_join_negative_axis_rewrite(self):
22012206
s2 = join(-2, a, b)
22022207
f2 = pytensor.function([], [s2], mode=self.mode)
22032208

2204-
join_nodes = [
2205-
node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Join)
2206-
]
2207-
assert len(join_nodes) == 1, "Expected exactly one Join node in the graph"
2208-
join_node = join_nodes[0]
2209+
join_node = f2.maker.fgraph.outputs[0].owner
2210+
assert isinstance(join_node.op, Join), "Expected output node to be a Join op"
22092211

22102212
# Check that the axis input has been converted to a constant with value 0 (not -2)
22112213
axis_input = join_node.inputs[0]

0 commit comments

Comments
 (0)