@@ -2188,7 +2188,12 @@ def test_join_negative_axis_rewrite(self):
21882188 # Create join with negative axis
21892189 s = join (- 1 , a , b )
21902190
2191- assert s .owner .op .axis == 1
2191+ # Get the actual Join op node from the graph
2192+ f = pytensor .function ([], [s ], mode = self .mode )
2193+
2194+ # Directly access the Join node from the output's owner
2195+ join_node = f .maker .fgraph .outputs [0 ].owner
2196+ assert isinstance (join_node .op , Join ), "Expected output node to be a Join op"
21922197
21932198 # Check that the axis input has been converted to a constant with value 1 (not -1)
21942199 axis_input = join_node .inputs [0 ]
@@ -2201,11 +2206,8 @@ def test_join_negative_axis_rewrite(self):
22012206 s2 = join (- 2 , a , b )
22022207 f2 = pytensor .function ([], [s2 ], mode = self .mode )
22032208
2204- join_nodes = [
2205- node for node in f2 .maker .fgraph .toposort () if isinstance (node .op , Join )
2206- ]
2207- assert len (join_nodes ) == 1 , "Expected exactly one Join node in the graph"
2208- join_node = join_nodes [0 ]
2209+ join_node = f2 .maker .fgraph .outputs [0 ].owner
2210+ assert isinstance (join_node .op , Join ), "Expected output node to be a Join op"
22092211
22102212 # Check that the axis input has been converted to a constant with value 0 (not -2)
22112213 axis_input = join_node .inputs [0 ]
0 commit comments