Skip to content

Commit 26071f7

Browse files
committed
Test using equal_computations
1 parent 8657a64 commit 26071f7

File tree

1 file changed

+2
-28
lines changed

1 file changed

+2
-28
lines changed

tests/tensor/test_basic.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,34 +2185,8 @@ def test_join_negative_axis_rewrite(self):
21852185
a = self.shared(v)
21862186
b = as_tensor_variable(v)
21872187

2188-
# Create join with negative axis
2189-
s = join(-1, a, b)
2190-
2191-
assert isinstance(
2192-
s.owner.outputs[0].owner.op, Join
2193-
), "Expected output node to be a Join op"
2194-
2195-
assert isinstance(
2196-
s.owner.inputs[0], ptb.Constant
2197-
), "Expected axis to be a Constant"
2198-
assert (
2199-
s.owner.inputs[0].data == 1
2200-
), f"Expected axis to be normalized to 1, got {s.owner.inputs[0].data}"
2201-
2202-
# Now test with axis -2 which should be rewritten to 0
2203-
s2 = join(-2, a, b)
2204-
2205-
assert isinstance(
2206-
s2.owner.outputs[0].owner.op, Join
2207-
), "Expected output node to be a Join op"
2208-
2209-
# Check that the axis input has been converted to a constant with value 0 (not -2)
2210-
assert isinstance(
2211-
s2.owner.inputs[0], ptb.Constant
2212-
), "Expected axis to be a Constant"
2213-
assert (
2214-
s2.owner.inputs[0].data == 0
2215-
), f"Expected axis to be normalized to 0, got {s2.owner.inputs[0].data}"
2188+
equal_computations([join(-1, a, b)], [join(1, a, b)])
2189+
equal_computations([join(-2, a, b)], [join(0, a, b)])
22162190

22172191

22182192
def test_TensorFromScalar():

0 commit comments

Comments
 (0)