@@ -2185,34 +2185,8 @@ def test_join_negative_axis_rewrite(self):
21852185 a = self .shared (v )
21862186 b = as_tensor_variable (v )
21872187
2188- # Create join with negative axis
2189- s = join (- 1 , a , b )
2190-
2191- assert isinstance (
2192- s .owner .outputs [0 ].owner .op , Join
2193- ), "Expected output node to be a Join op"
2194-
2195- assert isinstance (
2196- s .owner .inputs [0 ], ptb .Constant
2197- ), "Expected axis to be a Constant"
2198- assert (
2199- s .owner .inputs [0 ].data == 1
2200- ), f"Expected axis to be normalized to 1, got { s .owner .inputs [0 ].data } "
2201-
2202- # Now test with axis -2 which should be rewritten to 0
2203- s2 = join (- 2 , a , b )
2204-
2205- assert isinstance (
2206- s2 .owner .outputs [0 ].owner .op , Join
2207- ), "Expected output node to be a Join op"
2208-
2209- # Check that the axis input has been converted to a constant with value 0 (not -2)
2210- assert isinstance (
2211- s2 .owner .inputs [0 ], ptb .Constant
2212- ), "Expected axis to be a Constant"
2213- assert (
2214- s2 .owner .inputs [0 ].data == 0
2215- ), f"Expected axis to be normalized to 0, got { s2 .owner .inputs [0 ].data } "
2188+ equal_computations ([join (- 1 , a , b )], [join (1 , a , b )])
2189+ equal_computations ([join (- 2 , a , b )], [join (0 , a , b )])
22162190
22172191
22182192def test_TensorFromScalar ():
0 commit comments