@@ -125,13 +125,6 @@ def params_type(self):
125125 input_ndim = int64 ,
126126 )
127127
128- @property
129- def _new_order (self ):
130- # Param for C code.
131- # self.new_order may contain 'x', which is not a valid integer value.
132- # We replace it with -1.
133- return [(- 1 if x == "x" else x ) for x in self .new_order ]
134-
135128 def __init__ (self , * , input_ndim : int , new_order : Sequence [int | Literal ["x" ]]):
136129 super ().__init__ ([self .c_func_file ], self .c_func_name )
137130
@@ -140,6 +133,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
140133
141134 self .input_ndim = input_ndim
142135 self .new_order = tuple (new_order )
136+ self ._new_order = [(- 1 if x == "x" else x ) for x in self .new_order ]
143137
144138 for i , j in enumerate (new_order ):
145139 if j != "x" :
@@ -231,22 +225,15 @@ def __str__(self):
231225 return f"DimShuffle{{order=[{ ',' .join (map (str , self .new_order ))} ]}}"
232226
233227 def perform (self , node , inp , out ):
234- (res ,) = inp
235- (storage ,) = out
236-
237- if not isinstance (res , np .ndarray | np .memmap ):
238- raise TypeError (res )
239-
240- # Put dropped axis at end
241- res = res .transpose (self .transposition )
242-
243- # Define new shape without dropped axis and including new ones
244- new_shape = list (res .shape [: len (self .shuffle )])
245- for augm in self .augment :
246- new_shape .insert (augm , 1 )
247- res = res .reshape (new_shape )
248-
249- storage [0 ] = np .asarray (res )
228+ (inp ,) = inp
229+ new_order = self ._new_order
230+ old_shape = inp .shape
231+ old_strides = inp .strides
232+
233+ res = inp .view ()
234+ res .shape = [1 if i == - 1 else old_shape [i ] for i in new_order ]
235+ res .strides = [0 if i == - 1 else old_strides [i ] for i in new_order ]
236+ out [0 ][0 ] = res
250237
251238 def infer_shape (self , fgraph , node , shapes ):
252239 (ishp ,) = shapes
0 commit comments