Skip to content

Commit c855a6d

Browse files
committed
Avoid creating useless squeezes and expand_dims
1 parent f8c0c4d commit c855a6d

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,10 @@ def squeeze(x, axis=None):
603603
except np.AxisError:
604604
raise np.AxisError(axis, ndim=_x.ndim)
605605

606+
if not axis:
607+
# Nothing to do
608+
return _x
609+
606610
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
607611

608612

pytensor/tensor/shape.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,8 @@ def shape_padleft(t, n_ones=1):
868868
869869
"""
870870
_t = at.as_tensor_variable(t)
871-
871+
if n_ones == 0:
872+
return _t
872873
pattern = ["x"] * n_ones + list(range(_t.type.ndim))
873874
return _t.dimshuffle(pattern)
874875

@@ -884,7 +885,8 @@ def shape_padright(t, n_ones=1):
884885
885886
"""
886887
_t = at.as_tensor_variable(t)
887-
888+
if n_ones == 0:
889+
return _t
888890
pattern = list(range(_t.type.ndim)) + ["x"] * n_ones
889891
return _t.dimshuffle(pattern)
890892

0 commit comments

Comments
 (0)