-
Notifications
You must be signed in to change notification settings - Fork 145
Converts negative constant axis to positive if present in Join(COp)
#1527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
a1f99bf
a8de454
ccd812d
0b23a66
8657a64
26071f7
1071abc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2179,6 +2179,41 @@ def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark): | |
assert fn(*test_values).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6) | ||
benchmark(fn, *test_values) | ||
|
||
def test_join_negative_axis_rewrite(self): | ||
"""Test that constant negative axis is rewritten to positive axis during canonicalization.""" | ||
v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX) | ||
a = self.shared(v) | ||
b = as_tensor_variable(v) | ||
|
||
# Create join with negative axis | ||
s = join(-1, a, b) | ||
|
||
assert isinstance( | ||
s.owner.outputs[0].owner.op, Join | ||
), "Expected output node to be a Join op" | ||
|
||
assert isinstance( | ||
s.owner.inputs[0], ptb.Constant | ||
), "Expected axis to be a Constant" | ||
assert ( | ||
s.owner.inputs[0].data == 1 | ||
), f"Expected axis to be normalized to 1, got {s.owner.inputs[0].data}" | ||
|
||
|
||
# Now test with axis -2 which should be rewritten to 0 | ||
s2 = join(-2, a, b) | ||
|
||
assert isinstance( | ||
s2.owner.outputs[0].owner.op, Join | ||
), "Expected output node to be a Join op" | ||
|
||
# Check that the axis input has been converted to a constant with value 0 (not -2) | ||
assert isinstance( | ||
s2.owner.inputs[0], ptb.Constant | ||
), "Expected axis to be a Constant" | ||
assert ( | ||
s2.owner.inputs[0].data == 0 | ||
), f"Expected axis to be normalized to 0, got {s2.owner.inputs[0].data}" | ||
|
||
|
||
def test_TensorFromScalar(): | ||
s = ps.constant(56) | ||
|
Uh oh!
There was an error while loading. Please reload this page.