Skip to content

Commit 0b23a66

Browse files
committed
Test without compiling functions
1 parent ccd812d commit 0b23a66

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

tests/tensor/test_basic.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,33 +2188,31 @@ def test_join_negative_axis_rewrite(self):
21882188
# Create join with negative axis
21892189
s = join(-1, a, b)
21902190

2191-
# Get the actual Join op node from the graph
2192-
f = pytensor.function([], [s], mode=self.mode)
2191+
assert isinstance(
2192+
s.owner.outputs[0].owner.op, Join
2193+
), "Expected output node to be a Join op"
21932194

2194-
# Directly access the Join node from the output's owner
2195-
join_node = f.maker.fgraph.outputs[0].owner
2196-
assert isinstance(join_node.op, Join), "Expected output node to be a Join op"
2197-
2198-
# Check that the axis input has been converted to a constant with value 1 (not -1)
2199-
axis_input = join_node.inputs[0]
2200-
assert isinstance(axis_input, ptb.Constant), "Expected axis to be a Constant"
2195+
assert isinstance(
2196+
s.owner.inputs[0], ptb.Constant
2197+
), "Expected axis to be a Constant"
22012198
assert (
2202-
axis_input.data == 1
2203-
), f"Expected axis to be normalized to 1, got {axis_input.data}"
2199+
s.owner.inputs[0].data == 1
2200+
), f"Expected axis to be normalized to 1, got {s.owner.inputs[0].data}"
22042201

22052202
# Now test with axis -2 which should be rewritten to 0
22062203
s2 = join(-2, a, b)
2207-
f2 = pytensor.function([], [s2], mode=self.mode)
22082204

2209-
join_node = f2.maker.fgraph.outputs[0].owner
2210-
assert isinstance(join_node.op, Join), "Expected output node to be a Join op"
2205+
assert isinstance(
2206+
s2.owner.outputs[0].owner.op, Join
2207+
), "Expected output node to be a Join op"
22112208

22122209
# Check that the axis input has been converted to a constant with value 0 (not -2)
2213-
axis_input = join_node.inputs[0]
2214-
assert isinstance(axis_input, ptb.Constant), "Expected axis to be a Constant"
2210+
assert isinstance(
2211+
s2.owner.inputs[0], ptb.Constant
2212+
), "Expected axis to be a Constant"
22152213
assert (
2216-
axis_input.data == 0
2217-
), f"Expected axis to be normalized to 0, got {axis_input.data}"
2214+
s2.owner.inputs[0].data == 0
2215+
), f"Expected axis to be normalized to 0, got {s2.owner.inputs[0].data}"
22182216

22192217

22202218
def test_TensorFromScalar():

0 commit comments

Comments
 (0)