Skip to content

Commit 68d8dc7

Browse files
authored
Converts negative constant axis to positive if present in Join(COp) (#1527)
1 parent 4ce092f commit 68d8dc7

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,6 +2470,18 @@ def make_node(self, axis, *tensors):
24702470
if axis.type.ndim > 0:
24712471
raise TypeError(f"Axis {axis} must be 0-d.")
24722472

2473+
# Convert negative constant axis to positive during canonicalization
2474+
if isinstance(axis, Constant) and tensors:
2475+
# Get the axis value directly from the constant's data
2476+
axis_val = axis.data.item()
2477+
# Check if it's negative and needs normalization
2478+
if axis_val < 0:
2479+
ndim = tensors[0].ndim
2480+
# Convert negative axis to positive
2481+
axis_val = normalize_axis_index(axis_val, ndim)
2482+
# Replace the original axis with the normalized one
2483+
axis = constant(axis_val, dtype=axis.type.dtype)
2484+
24732485
tensors = [as_tensor_variable(x) for x in tensors]
24742486

24752487
if not builtins.all(targs.type.ndim > 0 for targs in tensors):

tests/tensor/test_basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,6 +2179,15 @@ def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
21792179
assert fn(*test_values).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
21802180
benchmark(fn, *test_values)
21812181

2182+
def test_join_negative_axis_rewrite(self):
2183+
"""Test that constant negative axis is rewritten to positive axis in make_node."""
2184+
v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX)
2185+
a = self.shared(v)
2186+
b = as_tensor_variable(v)
2187+
2188+
assert equal_computations([join(-1, a, b)], [join(1, a, b)])
2189+
assert equal_computations([join(-2, a, b)], [join(0, a, b)])
2190+
21822191

21832192
def test_TensorFromScalar():
21842193
s = ps.constant(56)

0 commit comments

Comments
 (0)