1
+ from collections .abc import Sequence
1
2
from copy import copy
2
3
from textwrap import dedent
4
+ from typing import Literal
3
5
4
6
import numpy as np
5
7
from numpy .core .numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
54
56
55
57
Parameters
56
58
----------
57
- input_broadcastable
58
- The expected broadcastable pattern of the input
59
+ input_ndim
60
+ The expected number of dimensions of the input
59
61
new_order
60
62
A list representing the relationship between the input's
61
63
dimensions and the output's dimensions. Each element of the
62
64
list can either be an index or 'x'. Indices must be encoded
63
65
as python integers, not pytensor symbolic integers.
64
- inplace : bool, optional
65
- If True (default), the output will be a view of the input.
66
+ Input dimensions whose indices are missing from the list are dropped; they must have length 1.
66
67
67
68
Notes
68
69
-----
@@ -77,50 +78,45 @@ class DimShuffle(ExternalCOp):
77
78
78
79
.. code-block:: python
79
80
80
- DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
81
+ DimShuffle(input_ndim=3, new_order= ['x', 2, 'x', 0, 1])
81
82
82
- This `Op` will only work on 3d tensors with no broadcastable
83
- dimensions. The first dimension will be broadcastable,
83
+ This `Op` will only work on 3d tensors.
84
+ The first dimension of the output will be broadcastable,
84
85
then we will have the third dimension of the input tensor as
85
86
the second of the resulting tensor, etc. If the tensor has
86
87
shape (20, 30, 40), the resulting tensor will have dimensions
87
88
(1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)
88
89
89
90
.. code-block:: python
90
91
91
- DimShuffle((True, False), [1])
92
+ DimShuffle(input_ndim=2, new_order= [1])
92
93
93
- This `Op` will only work on 2d tensors with the first dimension
94
- broadcastable.
95
- The second dimension of the input tensor will be the first dimension of
96
- the resulting tensor.
97
- If the tensor has shape (1, 20), the resulting tensor will have shape
98
- (20, ).
94
+ This `Op` will only work on 2d tensors with the first dimension broadcastable.
95
+ The second dimension of the input tensor will be the first dimension of the resulting tensor.
96
+ If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
99
97
100
98
Examples
101
99
--------
102
100
.. code-block:: python
103
101
104
- DimShuffle((), ['x']) # make a 0d (scalar) into a 1d vector
105
- DimShuffle((False, False), [0, 1]) # identity
106
- DimShuffle((False, False), [1, 0]) # inverts the 1st and 2nd dimensions
107
- DimShuffle((False,), ['x', 0]) # make a row out of a 1d vector
108
- # (N to 1xN)
109
- DimShuffle((False,), [0, 'x']) # make a column out of a 1d vector
110
- # (N to Nx1)
111
- DimShuffle((False, False, False), [2, 0, 1]) # AxBxC to CxAxB
112
- DimShuffle((False, False), [0, 'x', 1]) # AxB to Ax1xB
113
- DimShuffle((False, False), [1, 'x', 0]) # AxB to Bx1xA
114
-
115
- The reordering of the dimensions can be done with the numpy.transpose
116
- function.
117
- Adding, subtracting dimensions can be done with reshape.
102
+ DimShuffle(input_ndim=0, new_order=['x']) # make a 0d (scalar) into a 1d vector
103
+ DimShuffle(input_ndim=2, new_order=[0, 1]) # identity
104
+ DimShuffle(input_ndim=2, new_order=[1, 0]) # transposition
105
+ DimShuffle(input_ndim=1, new_order=['x', 0]) # make a row out of a 1d vector (N to 1xN)
106
+ DimShuffle(input_ndim=1, new_order=[0, 'x']) # make a column out of a 1d vector (N to Nx1)
107
+ DimShuffle(input_ndim=3, new_order=[2, 0, 1]) # AxBxC to CxAxB
108
+ DimShuffle(input_ndim=2, new_order=[0, 'x', 1]) # AxB to Ax1xB
109
+ DimShuffle(input_ndim=2, new_order=[1, 'x', 0]) # AxB to Bx1xA
118
110
111
+ Notes
112
+ -----
113
+ The python implementation of this Op combines numpy.transpose for reordering of the dimensions
114
+ and numpy.reshape for dropping and adding broadcastable dimensions.
119
115
"""
120
116
121
117
_f16_ok = True
122
118
check_input = False
123
- __props__ = ("input_broadcastable " , "new_order" , "inplace" )
119
+ __props__ = ("input_ndim " , "new_order" , "inplace" )
124
120
c_func_file = "c_code/dimshuffle.c"
125
121
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
126
122
@@ -133,16 +129,11 @@ def params_type(self):
133
129
inplace = scalar_bool ,
134
130
)
135
131
136
- def __init__ (self , input_broadcastable , new_order ):
132
+ def __init__ (self , * , input_ndim : int , new_order : Sequence [ int | Literal [ "x" ]] ):
137
133
super ().__init__ ([self .c_func_file ], self .c_func_name )
138
134
139
- self .input_broadcastable = tuple (input_broadcastable )
140
- if not all (isinstance (bs , bool | np .bool_ ) for bs in self .input_broadcastable ):
141
- raise ValueError (
142
- f"input_broadcastable must be boolean, { self .input_broadcastable } "
143
- )
135
+ self .input_ndim = input_ndim
144
136
self .new_order = tuple (new_order )
145
-
146
137
self .inplace = True
147
138
148
139
for i , j in enumerate (new_order ):
@@ -152,10 +143,10 @@ def __init__(self, input_broadcastable, new_order):
152
143
"DimShuffle indices must be Python ints; got "
153
144
f"{ j } of type { type (j )} ."
154
145
)
155
- if j >= len ( input_broadcastable ) :
146
+ if j >= input_ndim :
156
147
raise ValueError (
157
148
f"new_order[{ i } ] is { j } , but the input only has "
158
- f"{ len ( input_broadcastable ) } axes."
149
+ f"{ input_ndim } axes."
159
150
)
160
151
if j in new_order [(i + 1 ) :]:
161
152
raise ValueError (
@@ -164,19 +155,7 @@ def __init__(self, input_broadcastable, new_order):
164
155
)
165
156
166
157
# List of input dimensions to drop
167
- drop = []
168
- for i , b in enumerate (input_broadcastable ):
169
- if i not in new_order :
170
- # We want to drop this dimension because it's not a value in
171
- # `new_order`
172
- if b == 1 :
173
- drop .append (i )
174
- else :
175
- # We cannot drop non-broadcastable dimensions
176
- raise ValueError (
177
- "Cannot drop a non-broadcastable dimension: "
178
- f"{ input_broadcastable } , { new_order } "
179
- )
158
+ drop = [i for i in range (input_ndim ) if i not in new_order ]
180
159
181
160
# This is the list of the original dimensions that we keep
182
161
self .shuffle = [x for x in new_order if x != "x" ]
@@ -186,7 +165,6 @@ def __init__(self, input_broadcastable, new_order):
186
165
self .augment = sorted (i for i , x in enumerate (new_order ) if x == "x" )
187
166
self .drop = drop
188
167
189
- input_ndim = len (input_broadcastable )
190
168
self .is_left_expand_dims = self .augment and (
191
169
input_ndim == 0 or new_order [- input_ndim :] == list (range (input_ndim ))
192
170
)
@@ -204,30 +182,29 @@ def __setstate__(self, state):
204
182
# Let's just build the ExternalCOp.
205
183
super ().__init__ ([self .c_func_file ], self .c_func_name )
206
184
207
- def make_node (self , _input ):
208
- input = as_tensor_variable (_input )
209
- ib = tuple (s == 1 for s in input .type .shape )
210
- if ib != self .input_broadcastable :
211
- if len (ib ) != len (self .input_broadcastable ):
185
+ def make_node (self , inp ):
186
+ input = as_tensor_variable (inp )
187
+ if input .type .ndim != self .input_ndim :
188
+ raise TypeError (
189
+ "The number of dimensions of the input is incorrect for this op. "
190
+ f"Expected { self .input_ndim } , got { input .type .ndim } ."
191
+ )
192
+
193
+ input_static_shape = input .type .shape
194
+
195
+ # Runtime check for invalid drop
196
+ for d in self .drop :
197
+ if input_static_shape [d ] not in (1 , None ):
212
198
raise TypeError (
213
- "The number of dimensions of the "
214
- f"input is incorrect for this op. Expected { self .input_broadcastable } , got { ib } ."
199
+ f"Input dropped dimension { d } must have length 1 but has { input_static_shape [d ]} "
215
200
)
216
- for expected , b in zip (self .input_broadcastable , ib ):
217
- if expected and not b :
218
- raise TypeError (
219
- "The broadcastable pattern of the "
220
- f"input is incorrect for this op. Expected { self .input_broadcastable } , got { ib } ."
221
- )
222
- # else, expected == b or not expected and b
223
- # Both case are good.
224
201
225
202
out_static_shape = []
226
203
for dim_idx in self .new_order :
227
204
if dim_idx == "x" :
228
205
out_static_shape .append (1 )
229
206
else :
230
- out_static_shape .append (input . type . shape [dim_idx ])
207
+ out_static_shape .append (input_static_shape [dim_idx ])
231
208
232
209
output = TensorType (dtype = input .type .dtype , shape = out_static_shape )()
233
210
@@ -254,12 +231,14 @@ def perform(self, node, inp, out):
254
231
if not isinstance (res , np .ndarray | np .memmap ):
255
232
raise TypeError (res )
256
233
234
+ # Put dropped axis at end
257
235
res = res .transpose (self .transposition )
258
236
259
- shape = list (res .shape [: len (self .shuffle )])
237
+ # Define new shape without dropped axis and including new ones
238
+ new_shape = list (res .shape [: len (self .shuffle )])
260
239
for augm in self .augment :
261
- shape .insert (augm , 1 )
262
- res = res .reshape (shape )
240
+ new_shape .insert (augm , 1 )
241
+ res = res .reshape (new_shape )
263
242
264
243
if not self .inplace :
265
244
res = np .copy (res )
@@ -284,22 +263,15 @@ def R_op(self, inputs, eval_points):
284
263
def grad (self , inp , grads ):
285
264
(x ,) = inp
286
265
(gz ,) = grads
287
- gz = as_tensor_variable (gz )
288
266
grad_order = ["x" ] * x .type .ndim
289
267
for i , v in enumerate (self .new_order ):
290
268
if v != "x" :
291
269
grad_order [v ] = i
292
- # Do not make the DimShuffle inplace as an optimization at the
293
- # canonicalization optimization phase will remove the inplace.
294
- # The inplace will be reintroduced automatically later in the graph.
295
- if inp [0 ].dtype in discrete_dtypes :
296
- return [inp [0 ].zeros_like (dtype = config .floatX )]
270
+
271
+ if x .type .dtype in discrete_dtypes :
272
+ return [x .zeros_like (dtype = config .floatX )]
297
273
else :
298
- return [
299
- DimShuffle (tuple (s == 1 for s in gz .type .shape ), grad_order )(
300
- Elemwise (scalar_identity )(gz )
301
- )
302
- ]
274
+ return [gz .dimshuffle (grad_order )]
303
275
304
276
305
277
class DimShufflePrinter (Printer ):
@@ -409,7 +381,7 @@ def __setstate__(self, d):
409
381
self .nfunc = None
410
382
self .inplace_pattern = frozendict (self .inplace_pattern )
411
383
412
- def get_output_info (self , dim_shuffle , * inputs ):
384
+ def get_output_info (self , * inputs ):
413
385
"""Return the outputs dtype and broadcastable pattern and the
414
386
dimshuffled inputs.
415
387
@@ -427,12 +399,7 @@ def get_output_info(self, dim_shuffle, *inputs):
427
399
if not difference :
428
400
args .append (input )
429
401
else :
430
- args .append (
431
- dim_shuffle (
432
- input .type .broadcastable ,
433
- ["x" ] * difference + list (range (length )),
434
- )(input )
435
- )
402
+ args .append (input .dimshuffle (["x" ] * difference + list (range (length ))))
436
403
inputs = args
437
404
438
405
# HERE: all the broadcast dims have the same length now
@@ -489,7 +456,7 @@ def make_node(self, *inputs):
489
456
using DimShuffle.
490
457
"""
491
458
inputs = [as_tensor_variable (i ) for i in inputs ]
492
- out_dtypes , out_shapes , inputs = self .get_output_info (DimShuffle , * inputs )
459
+ out_dtypes , out_shapes , inputs = self .get_output_info (* inputs )
493
460
outputs = [
494
461
TensorType (dtype = dtype , shape = shape )()
495
462
for dtype , shape in zip (out_dtypes , out_shapes )
@@ -634,7 +601,7 @@ def transform(r):
634
601
res = pytensor .tensor .basic .constant (
635
602
np .asarray (r .data ), dtype = r .type .dtype
636
603
)
637
- return DimShuffle ((), ["x" ] * nd )( res )
604
+ return res . dimshuffle ( ["x" ] * nd )
638
605
639
606
new_r = Elemwise (node .op , {})(* [transform (ipt ) for ipt in node .inputs ])
640
607
if isinstance (new_r , list | tuple ):
@@ -1707,13 +1674,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
1707
1674
batched_ndims = x .type .ndim - node .inputs [0 ].type .ndim
1708
1675
if not batched_ndims :
1709
1676
return node .op .make_node (x )
1710
- input_broadcastable = x .type .broadcastable [:batched_ndims ] + op .input_broadcastable
1711
- # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
1712
- # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
1677
+ # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
1678
+ # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
1713
1679
new_order = list (range (batched_ndims )) + [
1714
1680
"x" if (o == "x" ) else (o + batched_ndims ) for o in op .new_order
1715
1681
]
1716
- return DimShuffle ( input_broadcastable , new_order ).make_node ( x )
1682
+ return x . dimshuffle ( new_order ).owner
1717
1683
1718
1684
1719
1685
def get_normalized_batch_axes (
0 commit comments