Skip to content

Commit 4312d8c

Browse files
CopilotricardoV94jessegrabowski
authored
Fix pt.flip to handle negative axis correctly using normalize_axis_tuple (#1628)
* Initial plan * Fix pt.flip to handle negative axis correctly using normalize_axis_tuple Co-authored-by: ricardoV94 <[email protected]> * Expand existing test_flip function to include negative axis tests Co-authored-by: ricardoV94 <[email protected]> * Fix mypy error by using separate variable for normalized axis Co-authored-by: ricardoV94 <[email protected]> * Fix ruff formatting issues Co-authored-by: jessegrabowski <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: ricardoV94 <[email protected]> Co-authored-by: jessegrabowski <[email protected]>
1 parent b05acfd commit 4312d8c

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

pytensor/tensor/subtensor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pytensor.graph.utils import MethodNotDefined
1919
from pytensor.link.c.op import COp
2020
from pytensor.link.c.params_type import ParamsType
21-
from pytensor.npy_2_compat import numpy_version, using_numpy_2
21+
from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2
2222
from pytensor.printing import Printer, pprint, set_precedence
2323
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
2424
from pytensor.tensor import (
@@ -3369,11 +3369,12 @@ def flip(
33693369
if axis is None:
33703370
index = ((slice(None, None, -1)),) * arr.ndim
33713371
else:
3372-
if isinstance(axis, int):
3373-
axis = (axis,)
3372+
normalized_axis = normalize_axis_tuple(axis, arr.ndim)
33743373
index = tuple(
33753374
[
3376-
slice(None, None, -1) if i in axis else slice(None, None, None)
3375+
slice(None, None, -1)
3376+
if i in normalized_axis
3377+
else slice(None, None, None)
33773378
for i in range(arr.ndim)
33783379
]
33793380
)
@@ -3382,9 +3383,9 @@ def flip(
33823383

33833384

33843385
__all__ = [
3385-
"take",
33863386
"flip",
3387-
"slice_at_axis",
33883387
"inc_subtensor",
33893388
"set_subtensor",
3389+
"slice_at_axis",
3390+
"take",
33903391
]

tests/tensor/test_subtensor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,6 +3147,27 @@ def test_flip(size: tuple[int]):
31473147
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
31483148
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
31493149

3150+
# Test single negative axis
3151+
for axis in range(-x.ndim, 0):
3152+
expected = np.flip(x, axis=axis)
3153+
z = flip(x_pt, axis=axis)
3154+
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
3155+
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
3156+
3157+
# Test tuple with negative axes
3158+
if x.ndim > 1:
3159+
expected = np.flip(x, axis=(-1, -2))
3160+
z = flip(x_pt, axis=(-1, -2))
3161+
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
3162+
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
3163+
3164+
# Test mixed positive and negative axes
3165+
if x.ndim >= 2:
3166+
expected = np.flip(x, axis=(0, -1))
3167+
z = flip(x_pt, axis=(0, -1))
3168+
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
3169+
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
3170+
31503171

31513172
class TestBenchmarks:
31523173
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)