Skip to content

Commit 8e1cd56

Browse files
committed
Avoid no-op DimShuffle
1 parent f72d7e5 commit 8e1cd56

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

pytensor/tensor/variable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
349349
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
350350
pattern = pattern[0]
351351
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
352+
if ds_op.new_order == tuple(range(self.type.ndim)):
353+
# No-op
354+
return self
352355
return ds_op(self)
353356

354357
def flatten(self, ndim=1):

tests/tensor/random/rewriting/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def test_Dimshuffle_lift_restrictions():
950950
1e-7,
951951
),
952952
(
953-
(0, 1, 2),
953+
(0, 2, 1),
954954
True,
955955
normal,
956956
(np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),

tests/tensor/rewriting/test_elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_recursive_lift(self):
148148

149149
def test_useless_dimshuffle(self):
150150
x, *_ = inputs()
151-
e = ds(x, (0, 1))
151+
e = DimShuffle(new_order=(0, 1), input_ndim=2)(x)
152152
g = FunctionGraph([x], [e], clone=False)
153153
assert isinstance(g.outputs[0].owner.op, DimShuffle)
154154
dimshuffle_lift.rewrite(g)

0 commit comments

Comments
 (0)