Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,12 @@ def test_join_negative_axis_rewrite(self):
# Create join with negative axis
s = join(-1, a, b)

assert s.owner.op.axis == 1
# Get the actual Join op node from the graph
f = pytensor.function([], [s], mode=self.mode)

# Directly access the Join node from the output's owner
join_node = f.maker.fgraph.outputs[0].owner
assert isinstance(join_node.op, Join), "Expected output node to be a Join op"

# Check that the axis input has been converted to a constant with value 1 (not -1)
axis_input = join_node.inputs[0]
Expand All @@ -2201,11 +2206,8 @@ def test_join_negative_axis_rewrite(self):
s2 = join(-2, a, b)
f2 = pytensor.function([], [s2], mode=self.mode)

join_nodes = [
node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Join)
]
assert len(join_nodes) == 1, "Expected exactly one Join node in the graph"
join_node = join_nodes[0]
join_node = f2.maker.fgraph.outputs[0].owner
assert isinstance(join_node.op, Join), "Expected output node to be a Join op"

# Check that the axis input has been converted to a constant with value 0 (not -2)
axis_input = join_node.inputs[0]
Expand Down