@@ -166,15 +166,20 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
166166 self .transposition = self .shuffle + drop
167167 # List of dimensions of the output that are broadcastable and were not
168168 # in the original input
169- self .augment = sorted (i for i , x in enumerate (new_order ) if x == "x" )
169+ self .augment = augment = sorted (i for i , x in enumerate (new_order ) if x == "x" )
170170 self .drop = drop
171171
172- self .is_left_expand_dims = self .augment and (
172+ dims_are_shuffled = sorted (self .shuffle ) != self .shuffle
173+
174+ self .is_transpose = dims_are_shuffled and not augment and not drop
175+ self .is_squeeze = drop and not dims_are_shuffled and not augment
176+ self .is_expand_dims = augment and not dims_are_shuffled and not drop
177+ self .is_left_expand_dims = self .is_expand_dims and (
173178 input_ndim == 0 or new_order [- input_ndim :] == list (range (input_ndim ))
174179 )
175- self .is_right_expand_dims = self .augment and new_order [: input_ndim ] == list (
176- range ( input_ndim )
177- )
180+ self .is_right_expand_dims = self .is_expand_dims and new_order [
181+ : input_ndim
182+ ] == list ( range ( input_ndim ) )
178183
179184 if self .inplace :
180185 self .view_map = {0 : [0 ]}
@@ -215,16 +220,15 @@ def make_node(self, inp):
215220 return Apply (self , [input ], [output ])
216221
217222 def __str__ (self ):
218- shuffle = sorted (self .shuffle ) != self .shuffle
219- if self .augment and not (shuffle or self .drop ):
223+ if self .is_expand_dims :
220224 if len (self .augment ) == 1 :
221225 return f"ExpandDims{{axis={ self .augment [0 ]} }}"
222226 return f"ExpandDims{{axes={ self .augment } }}"
223- if self .drop and not ( self . augment or shuffle ) :
227+ if self .is_squeeze :
224228 if len (self .drop ) == 1 :
225- return f"DropDims {{axis={ self .drop [0 ]} }}"
226- return f"DropDims {{axes={ self .drop } }}"
227- if shuffle and not ( self .augment or self . drop ) :
229+ return f"Squeeze {{axis={ self .drop [0 ]} }}"
230+ return f"Squeeze {{axes={ self .drop } }}"
231+ if self .is_transpose :
228232 return f"Transpose{{axes={ self .shuffle } }}"
229233 return f"DimShuffle{{order=[{ ',' .join (map (str , self .new_order ))} ]}}"
230234
0 commit comments