Skip to content

Commit 9e18d3c

Browse files
committed
Cache sub-type of DimShuffle
1 parent ebb06ff commit 9e18d3c

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

pytensor/tensor/elemwise.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,20 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
166166
self.transposition = self.shuffle + drop
167167
# List of dimensions of the output that are broadcastable and were not
168168
# in the original input
169-
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
169+
self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
170170
self.drop = drop
171171

172-
self.is_left_expand_dims = self.augment and (
172+
dims_are_shuffled = sorted(self.shuffle) != self.shuffle
173+
174+
self.is_transpose = dims_are_shuffled and not augment and not drop
175+
self.is_squeeze = drop and not dims_are_shuffled and not augment
176+
self.is_expand_dims = augment and not dims_are_shuffled and not drop
177+
self.is_left_expand_dims = self.is_expand_dims and (
173178
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
174179
)
175-
self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
176-
range(input_ndim)
177-
)
180+
self.is_right_expand_dims = self.is_expand_dims and new_order[
181+
:input_ndim
182+
] == list(range(input_ndim))
178183

179184
if self.inplace:
180185
self.view_map = {0: [0]}
@@ -215,16 +220,15 @@ def make_node(self, inp):
215220
return Apply(self, [input], [output])
216221

217222
def __str__(self):
218-
shuffle = sorted(self.shuffle) != self.shuffle
219-
if self.augment and not (shuffle or self.drop):
223+
if self.is_expand_dims:
220224
if len(self.augment) == 1:
221225
return f"ExpandDims{{axis={self.augment[0]}}}"
222226
return f"ExpandDims{{axes={self.augment}}}"
223-
if self.drop and not (self.augment or shuffle):
227+
if self.is_squeeze:
224228
if len(self.drop) == 1:
225-
return f"DropDims{{axis={self.drop[0]}}}"
226-
return f"DropDims{{axes={self.drop}}}"
227-
if shuffle and not (self.augment or self.drop):
229+
return f"Squeeze{{axis={self.drop[0]}}}"
230+
return f"Squeeze{{axes={self.drop}}}"
231+
if self.is_transpose:
228232
return f"Transpose{{axes={self.shuffle}}}"
229233
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
230234

0 commit comments

Comments
 (0)