+ from collections.abc import Sequence
from copy import copy
from textwrap import dedent
+ from typing import Literal

import numpy as np
from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
Parameters
----------
- input_broadcastable
-     The expected broadcastable pattern of the input
+ input_ndim
+     The expected number of dimensions of the input
new_order
    A list representing the relationship between the input's
    dimensions and the output's dimensions. Each element of the
    list can either be an index or 'x'. Indices must be encoded
    as python integers, not pytensor symbolic integers.
- inplace : bool, optional
-     If True (default), the output will be a view of the input.
+     Missing indices correspond to dropped dimensions.

Notes
-----
@@ -77,50 +78,45 @@ class DimShuffle(ExternalCOp):
.. code-block:: python

-     DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
+     DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])

- This `Op` will only work on 3d tensors with no broadcastable
- dimensions. The first dimension will be broadcastable,
+ This `Op` will only work on 3d tensors.
+ The first dimension of the output will be broadcastable,
then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has
shape (20, 30, 40), the resulting tensor will have dimensions
(1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)

.. code-block:: python

-     DimShuffle((True, False), [1])
+     DimShuffle(input_ndim=2, new_order=[1])

- This `Op` will only work on 2d tensors with the first dimension
- broadcastable.
- The second dimension of the input tensor will be the first dimension of
- the resulting tensor.
- If the tensor has shape (1, 20), the resulting tensor will have shape
- (20, ).
+ This `Op` will only work on 2d tensors with the first dimension broadcastable.
+ The second dimension of the input tensor will be the first dimension of the resulting tensor.
+ If the tensor has shape (1, 20), the resulting tensor will have shape (20,).

Examples
--------
.. code-block:: python

-     DimShuffle((), ["x"])  # make a 0d (scalar) into a 1d vector
-     DimShuffle((False, False), [0, 1])  # identity
-     DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-     DimShuffle((False,), ["x", 0])  # make a row out of a 1d vector
-     # (N to 1xN)
-     DimShuffle((False,), [0, "x"])  # make a column out of a 1d vector
-     # (N to Nx1)
-     DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-     DimShuffle((False, False), [0, "x", 1])  # AxB to Ax1xB
-     DimShuffle((False, False), [1, "x", 0])  # AxB to Bx1xA
-
- The reordering of the dimensions can be done with the numpy.transpose
- function.
- Adding, subtracting dimensions can be done with reshape.
+     DimShuffle(input_ndim=0, new_order=["x"])  # make a 0d (scalar) into a 1d vector
+     DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+     DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+     DimShuffle(input_ndim=1, new_order=["x", 0])  # make a row out of a 1d vector (N to 1xN)
+     DimShuffle(input_ndim=1, new_order=[0, "x"])  # make a column out of a 1d vector (N to Nx1)
+     DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+     DimShuffle(input_ndim=2, new_order=[0, "x", 1])  # AxB to Ax1xB
+     DimShuffle(input_ndim=2, new_order=[1, "x", 0])  # AxB to Bx1xA

+ Notes
+ -----
+ The Python implementation of this `Op` combines `numpy.transpose` to reorder the
+ dimensions and `numpy.reshape` to drop and add broadcastable dimensions.

"""
_f16_ok = True
check_input = False
- __props__ = ("input_broadcastable", "new_order", "inplace")
+ __props__ = ("input_ndim", "new_order", "inplace")
c_func_file = "c_code/dimshuffle.c"
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
@@ -133,16 +129,14 @@ def params_type(self):
    inplace=scalar_bool,
)

- def __init__(self, input_broadcastable, new_order):
+ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
    super().__init__([self.c_func_file], self.c_func_name)

-     self.input_broadcastable = tuple(input_broadcastable)
-     if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-         raise ValueError(
-             f"input_broadcastable must be boolean, {self.input_broadcastable}"
-         )
-     self.new_order = tuple(new_order)
+     if not isinstance(input_ndim, int):
+         raise TypeError(f"input_ndim must be an integer, got {type(input_ndim)}")
+
+     self.input_ndim = input_ndim
+     self.new_order = tuple(new_order)

    self.inplace = True

    for i, j in enumerate(new_order):
@@ -152,10 +146,10 @@ def __init__(self, input_broadcastable, new_order):
            "DimShuffle indices must be Python ints; got "
            f"{j} of type {type(j)}."
        )
-     if j >= len(input_broadcastable):
+     if j >= input_ndim:
        raise ValueError(
            f"new_order[{i}] is {j}, but the input only has "
-             f"{len(input_broadcastable)} axes."
+             f"{input_ndim} axes."
        )
    if j in new_order[(i + 1):]:
        raise ValueError(
@@ -164,19 +158,7 @@ def __init__(self, input_broadcastable, new_order):
        )

# List of input dimensions to drop
- drop = []
- for i, b in enumerate(input_broadcastable):
-     if i not in new_order:
-         # We want to drop this dimension because it's not a value in
-         # `new_order`
-         if b == 1:
-             drop.append(i)
-         else:
-             # We cannot drop non-broadcastable dimensions
-             raise ValueError(
-                 "Cannot drop a non-broadcastable dimension: "
-                 f"{input_broadcastable}, {new_order}"
-             )
+ drop = [i for i in range(input_ndim) if i not in new_order]

# This is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +168,6 @@ def __init__(self, input_broadcastable, new_order):
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop

- input_ndim = len(input_broadcastable)
self.is_left_expand_dims = self.augment and (
    input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
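
For orientation, a standalone sketch (with made-up arguments) of the bookkeeping fields the constructor derives from `new_order`:

    input_ndim = 3
    new_order = ["x", 2, "x", 0]  # hypothetical order that drops axis 1
    drop = [i for i in range(input_ndim) if i not in new_order]
    shuffle = [x for x in new_order if x != "x"]
    augment = sorted(i for i, x in enumerate(new_order) if x == "x")
    assert (drop, shuffle, augment) == ([1], [2, 0], [0, 2])
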
@@ -204,30 +185,29 @@ def __setstate__(self, state):
# Let's just build the ExternalCOp.
super().__init__([self.c_func_file], self.c_func_name)

- def make_node(self, _input):
-     input = as_tensor_variable(_input)
-     ib = tuple(s == 1 for s in input.type.shape)
-     if ib != self.input_broadcastable:
-         if len(ib) != len(self.input_broadcastable):
+ def make_node(self, inp):
+     input = as_tensor_variable(inp)
+     if input.type.ndim != self.input_ndim:
+         raise TypeError(
+             "The number of dimensions of the input is incorrect for this op. "
+             f"Expected {self.input_ndim}, got {input.type.ndim}."
+         )
+
+     input_static_shape = input.type.shape
+
+     # Reject dropped dimensions that are statically known to have length != 1
+     for d in self.drop:
+         if input_static_shape[d] not in (1, None):
            raise TypeError(
-                 "The number of dimensions of the "
-                 f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                 f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
            )
-         for expected, b in zip(self.input_broadcastable, ib):
-             if expected and not b:
-                 raise TypeError(
-                     "The broadcastable pattern of the "
-                     f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                 )
-             # else, expected == b or not expected and b
-             # Both cases are good.

    out_static_shape = []
    for dim_idx in self.new_order:
        if dim_idx == "x":
            out_static_shape.append(1)
        else:
-             out_static_shape.append(input.type.shape[dim_idx])
+             out_static_shape.append(input_static_shape[dim_idx])

    output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
@@ -254,12 +234,14 @@ def perform(self, node, inp, out):
if not isinstance(res, np.ndarray | np.memmap):
    raise TypeError(res)

+ # Put the dropped axes at the end
res = res.transpose(self.transposition)

- shape = list(res.shape[: len(self.shuffle)])
+ # Define the new shape without the dropped axes and with the new broadcastable ones
+ new_shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
-     shape.insert(augm, 1)
- res = res.reshape(shape)
+     new_shape.insert(augm, 1)
+ res = res.reshape(new_shape)

if not self.inplace:
    res = np.copy(res)
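
A standalone NumPy sketch of this `perform` logic for a case that drops an axis, assuming `self.transposition` is `self.shuffle` followed by `self.drop` (so the dropped axes end up last, as the comment above says):

    import numpy as np

    x = np.empty((20, 1, 40))  # input_ndim=3, new_order=[2, 0] drops axis 1
    transposition = [2, 0, 1]  # shuffle=[2, 0] followed by drop=[1]
    res = x.transpose(transposition)  # shape (40, 20, 1)
    new_shape = list(res.shape[:2])  # keep only the shuffled axes -> [40, 20]
    res = res.reshape(new_shape)  # the trailing length-1 dropped axis disappears
    assert res.shape == (40, 20)
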
@@ -284,22 +266,15 @@ def R_op(self, inputs, eval_points):
def grad(self, inp, grads):
    (x,) = inp
    (gz,) = grads
-     gz = as_tensor_variable(gz)
    grad_order = ["x"] * x.type.ndim
    for i, v in enumerate(self.new_order):
        if v != "x":
            grad_order[v] = i
-     # Do not make the DimShuffle inplace as an optimization at the
-     # canonicalization optimization phase will remove the inplace.
-     # The inplace will be reintroduced automatically later in the graph.
-     if inp[0].dtype in discrete_dtypes:
-         return [inp[0].zeros_like(dtype=config.floatX)]
+
+     if x.type.dtype in discrete_dtypes:
+         return [x.zeros_like(dtype=config.floatX)]
    else:
-         return [
-             DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                 Elemwise(scalar_identity)(gz)
-             )
-         ]
+         return [gz.dimshuffle(grad_order)]
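
A standalone sketch (with a made-up order) of the `grad_order` inversion above: each input axis is sent back to the output axis that received it, while input axes the forward op dropped keep the `"x"` placeholder:

    x_ndim = 3
    new_order = ["x", 2, "x", 0, 1]  # forward order from the docstring example
    grad_order = ["x"] * x_ndim
    for i, v in enumerate(new_order):
        if v != "x":
            grad_order[v] = i
    assert grad_order == [3, 4, 1]  # output axes 3, 4, 1 map back to input axes 0, 1, 2
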
class DimShufflePrinter(Printer):
@@ -409,7 +384,7 @@ def __setstate__(self, d):
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)

- def get_output_info(self, dim_shuffle, *inputs):
+ def get_output_info(self, *inputs):
    """Return the outputs dtype and broadcastable pattern and the
    dimshuffled inputs.
@@ -427,12 +402,7 @@ def get_output_info(self, dim_shuffle, *inputs):
if not difference:
    args.append(input)
else:
-     args.append(
-         dim_shuffle(
-             input.type.broadcastable,
-             ["x"] * difference + list(range(length)),
-         )(input)
-     )
+     args.append(input.dimshuffle(["x"] * difference + list(range(length))))
inputs = args

# HERE: all the broadcast dims have the same length now
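
A small sketch of the rank alignment performed here, assuming (as the snippet suggests) that `length` is the input's current ndim and `difference` is the number of leading axes it is missing:

    difference, length = 2, 1
    new_order = ["x"] * difference + list(range(length))
    assert new_order == ["x", "x", 0]  # e.g. shape (5,) is aligned to (1, 1, 5)
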
@@ -489,7 +459,7 @@ def make_node(self, *inputs):
using DimShuffle.
"""
inputs = [as_tensor_variable(i) for i in inputs]
- out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+ out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
outputs = [
    TensorType(dtype=dtype, shape=shape)()
    for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +604,7 @@ def transform(r):
res = pytensor.tensor.basic.constant(
    np.asarray(r.data), dtype=r.type.dtype
)
- return DimShuffle((), ["x"] * nd)(res)
+ return res.dimshuffle(["x"] * nd)

new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
if isinstance(new_r, list | tuple):
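
For illustration, a minimal sketch of lifting a 0d constant to `nd` broadcastable dimensions with the variable-level `dimshuffle` method used above (a hypothetical `nd = 2` here):

    import numpy as np
    import pytensor.tensor.basic as ptb

    res = ptb.constant(np.asarray(5.0), dtype="float64")
    lifted = res.dimshuffle(["x"] * 2)
    assert lifted.type.ndim == 2 and lifted.type.shape == (1, 1)
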
@@ -1707,13 +1677,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
    return node.op.make_node(x)
- input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
- # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
- # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+ # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
+ # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
new_order = list(range(batched_ndims)) + [
    "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
]
- return DimShuffle(input_broadcastable, new_order).make_node(x)
+ return x.dimshuffle(new_order).owner


def get_normalized_batch_axes(
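
Finally, a standalone sketch of the `new_order` shift in `vectorize_dimshuffle`, reproducing the first example comment in that hunk:

    batched_ndims = 2
    op_new_order = (1, "x", 0)
    new_order = list(range(batched_ndims)) + [
        "x" if o == "x" else o + batched_ndims for o in op_new_order
    ]
    assert new_order == [0, 1, 3, "x", 2]
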