@@ -2188,33 +2188,31 @@ def test_join_negative_axis_rewrite(self):
21882188 # Create join with negative axis
21892189 s = join (- 1 , a , b )
21902190
2191- # Get the actual Join op node from the graph
2192- f = pytensor .function ([], [s ], mode = self .mode )
2191+ assert isinstance (
2192+ s .owner .outputs [0 ].owner .op , Join
2193+ ), "Expected output node to be a Join op"
21932194
2194- # Directly access the Join node from the output's owner
2195- join_node = f .maker .fgraph .outputs [0 ].owner
2196- assert isinstance (join_node .op , Join ), "Expected output node to be a Join op"
2197-
2198- # Check that the axis input has been converted to a constant with value 1 (not -1)
2199- axis_input = join_node .inputs [0 ]
2200- assert isinstance (axis_input , ptb .Constant ), "Expected axis to be a Constant"
2195+ assert isinstance (
2196+ s .owner .inputs [0 ], ptb .Constant
2197+ ), "Expected axis to be a Constant"
22012198 assert (
2202- axis_input .data == 1
2203- ), f"Expected axis to be normalized to 1, got { axis_input .data } "
2199+ s . owner . inputs [ 0 ] .data == 1
2200+ ), f"Expected axis to be normalized to 1, got { s . owner . inputs [ 0 ] .data } "
22042201
22052202 # Now test with axis -2 which should be rewritten to 0
22062203 s2 = join (- 2 , a , b )
2207- f2 = pytensor .function ([], [s2 ], mode = self .mode )
22082204
2209- join_node = f2 .maker .fgraph .outputs [0 ].owner
2210- assert isinstance (join_node .op , Join ), "Expected output node to be a Join op"
2205+ assert isinstance (
2206+ s2 .owner .outputs [0 ].owner .op , Join
2207+ ), "Expected output node to be a Join op"
22112208
22122209 # Check that the axis input has been converted to a constant with value 0 (not -2)
2213- axis_input = join_node .inputs [0 ]
2214- assert isinstance (axis_input , ptb .Constant ), "Expected axis to be a Constant"
2210+ assert isinstance (
2211+ s2 .owner .inputs [0 ], ptb .Constant
2212+ ), "Expected axis to be a Constant"
22152213 assert (
2216- axis_input .data == 0
2217- ), f"Expected axis to be normalized to 0, got { axis_input .data } "
2214+ s2 . owner . inputs [ 0 ] .data == 0
2215+ ), f"Expected axis to be normalized to 0, got { s2 . owner . inputs [ 0 ] .data } "
22182216
22192217
22202218def test_TensorFromScalar ():
0 commit comments