Commit 278faf6

Only require input_ndim and not input_broadcastable in DimShuffle

1 parent: a9ed164
File tree: 5 files changed, +73 −106 lines

pytensor/tensor/elemwise.py
pytensor/tensor/extra_ops.py
pytensor/tensor/math.py
pytensor/tensor/variable.py
tests/tensor/test_fft.py

pytensor/tensor/elemwise.py

Lines changed: 61 additions & 95 deletions
@@ -1,5 +1,7 @@
+from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
+from typing import Literal
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
 
     Parameters
     ----------
-    input_broadcastable
-        The expected broadcastable pattern of the input
+    input_ndim
+        The expected number of dimensions of the input
     new_order
         A list representing the relationship between the input's
         dimensions and the output's dimensions. Each element of the
         list can either be an index or 'x'. Indices must be encoded
         as python integers, not pytensor symbolic integers.
-    inplace : bool, optional
-        If True (default), the output will be a view of the input.
+        Missing indices correspond to dropped dimensions.
 
     Notes
     -----
@@ -77,50 +78,45 @@ class DimShuffle(ExternalCOp):
 
     .. code-block:: python
 
-        DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
+        DimShuffle(input_ndim=3, new_order=['x', 2, 'x', 0, 1])
 
-    This `Op` will only work on 3d tensors with no broadcastable
-    dimensions. The first dimension will be broadcastable,
+    This `Op` will only work on 3d tensors.
+    The first dimension of the output will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
     shape (20, 30, 40), the resulting tensor will have dimensions
     (1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)
 
     .. code-block:: python
 
-        DimShuffle((True, False), [1])
+        DimShuffle(input_ndim=2, new_order=[1])
 
-    This `Op` will only work on 2d tensors with the first dimension
-    broadcastable.
-    The second dimension of the input tensor will be the first dimension of
-    the resulting tensor.
-    If the tensor has shape (1, 20), the resulting tensor will have shape
-    (20, ).
+    This `Op` will only work on 2d tensors with the first dimension broadcastable.
+    The second dimension of the input tensor will be the first dimension of the resulting tensor.
+    If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
 
     Examples
     --------
     .. code-block:: python
 
-        DimShuffle((), ['x'])  # make a 0d (scalar) into a 1d vector
-        DimShuffle((False, False), [0, 1])  # identity
-        DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-        DimShuffle((False,), ['x', 0])  # make a row out of a 1d vector
-        # (N to 1xN)
-        DimShuffle((False,), [0, 'x'])  # make a column out of a 1d vector
-        # (N to Nx1)
-        DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-        DimShuffle((False, False), [0, 'x', 1])  # AxB to Ax1xB
-        DimShuffle((False, False), [1, 'x', 0])  # AxB to Bx1xA
-
-    The reordering of the dimensions can be done with the numpy.transpose
-    function.
-    Adding, subtracting dimensions can be done with reshape.
+        DimShuffle(input_ndim=0, new_order=['x'])  # make a 0d (scalar) into a 1d vector
+        DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+        DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+        DimShuffle(input_ndim=1, new_order=['x', 0])  # make a row out of a 1d vector (N to 1xN)
+        DimShuffle(input_ndim=1, new_order=[0, 'x'])  # make a column out of a 1d vector (N to Nx1)
+        DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+        DimShuffle(input_ndim=2, new_order=[0, 'x', 1])  # AxB to Ax1xB
+        DimShuffle(input_ndim=2, new_order=[1, 'x', 0])  # AxB to Bx1xA
 
+    Notes
+    -----
+    The Python implementation of this Op combines numpy.transpose for reordering the dimensions
+    and numpy.reshape for removing and adding broadcastable dimensions.
     """
 
     _f16_ok = True
     check_input = False
-    __props__ = ("input_broadcastable", "new_order", "inplace")
+    __props__ = ("input_ndim", "new_order", "inplace")
     c_func_file = "c_code/dimshuffle.c"
     c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
 
@@ -133,16 +129,11 @@ def params_type(self):
             inplace=scalar_bool,
         )
 
-    def __init__(self, input_broadcastable, new_order):
+    def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
         super().__init__([self.c_func_file], self.c_func_name)
 
-        self.input_broadcastable = tuple(input_broadcastable)
-        if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-            raise ValueError(
-                f"input_broadcastable must be boolean, {self.input_broadcastable}"
-            )
+        self.input_ndim = input_ndim
         self.new_order = tuple(new_order)
-
         self.inplace = True
 
         for i, j in enumerate(new_order):
@@ -152,10 +143,10 @@ def __init__(self, input_broadcastable, new_order):
                     "DimShuffle indices must be Python ints; got "
                     f"{j} of type {type(j)}."
                 )
-            if j >= len(input_broadcastable):
+            if j >= input_ndim:
                 raise ValueError(
                     f"new_order[{i}] is {j}, but the input only has "
-                    f"{len(input_broadcastable)} axes."
+                    f"{input_ndim} axes."
                 )
             if j in new_order[(i + 1) :]:
                 raise ValueError(
@@ -164,19 +155,7 @@ def __init__(self, input_broadcastable, new_order):
                 )
 
         # List of input dimensions to drop
-        drop = []
-        for i, b in enumerate(input_broadcastable):
-            if i not in new_order:
-                # We want to drop this dimension because it's not a value in
-                # `new_order`
-                if b == 1:
-                    drop.append(i)
-                else:
-                    # We cannot drop non-broadcastable dimensions
-                    raise ValueError(
-                        "Cannot drop a non-broadcastable dimension: "
-                        f"{input_broadcastable}, {new_order}"
-                    )
+        drop = [i for i in range(input_ndim) if i not in new_order]
 
         # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +165,6 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
         self.drop = drop
 
-        input_ndim = len(input_broadcastable)
         self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
@@ -204,30 +182,29 @@ def __setstate__(self, state):
         # Let's just build the ExternalCOp.
         super().__init__([self.c_func_file], self.c_func_name)
 
-    def make_node(self, _input):
-        input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
+    def make_node(self, inp):
+        input = as_tensor_variable(inp)
+        if input.type.ndim != self.input_ndim:
+            raise TypeError(
+                "The number of dimensions of the input is incorrect for this op. "
+                f"Expected {self.input_ndim}, got {input.type.ndim}."
+            )
+
+        input_static_shape = input.type.shape
+
+        # Runtime check for invalid drop
+        for d in self.drop:
+            if input_static_shape[d] not in (1, None):
                 raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
                 )
-            for expected, b in zip(self.input_broadcastable, ib):
-                if expected and not b:
-                    raise TypeError(
-                        "The broadcastable pattern of the "
-                        f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                    )
-                # else, expected == b or not expected and b
-                # Both case are good.
 
         out_static_shape = []
         for dim_idx in self.new_order:
             if dim_idx == "x":
                 out_static_shape.append(1)
             else:
-                out_static_shape.append(input.type.shape[dim_idx])
+                out_static_shape.append(input_static_shape[dim_idx])
 
         output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
 
@@ -254,12 +231,14 @@ def perform(self, node, inp, out):
         if not isinstance(res, np.ndarray | np.memmap):
             raise TypeError(res)
 
+        # Put dropped axes at the end
         res = res.transpose(self.transposition)
 
-        shape = list(res.shape[: len(self.shuffle)])
+        # Define the new shape without dropped axes and with the new ones
+        new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
-            shape.insert(augm, 1)
-        res = res.reshape(shape)
+            new_shape.insert(augm, 1)
+        res = res.reshape(new_shape)
 
         if not self.inplace:
             res = np.copy(res)
@@ -284,22 +263,15 @@ def R_op(self, inputs, eval_points):
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        gz = as_tensor_variable(gz)
         grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
-        # Do not make the DimShuffle inplace as an optimization at the
-        # canonicalization optimization phase will remove the inplace.
-        # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+
+        if x.type.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
-            return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                    Elemwise(scalar_identity)(gz)
-                )
-            ]
+            return [gz.dimshuffle(grad_order)]
 
 
 class DimShufflePrinter(Printer):
@@ -409,7 +381,7 @@ def __setstate__(self, d):
         self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)
 
-    def get_output_info(self, dim_shuffle, *inputs):
+    def get_output_info(self, *inputs):
         """Return the outputs dtype and broadcastable pattern and the
         dimshuffled inputs.
 
@@ -427,12 +399,7 @@ def get_output_info(self, dim_shuffle, *inputs):
             if not difference:
                 args.append(input)
             else:
-                args.append(
-                    dim_shuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(length)),
-                    )(input)
-                )
+                args.append(input.dimshuffle(["x"] * difference + list(range(length))))
         inputs = args
 
         # HERE: all the broadcast dims have the same length now
@@ -489,7 +456,7 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
         outputs = [
             TensorType(dtype=dtype, shape=shape)()
             for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +601,7 @@ def transform(r):
             res = pytensor.tensor.basic.constant(
                 np.asarray(r.data), dtype=r.type.dtype
            )
-            return DimShuffle((), ["x"] * nd)(res)
+            return res.dimshuffle(["x"] * nd)
 
         new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
         if isinstance(new_r, list | tuple):
@@ -1707,13 +1674,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     if not batched_ndims:
        return node.op.make_node(x)
-    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
-    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
-    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
     new_order = list(range(batched_ndims)) + [
         "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
     ]
-    return DimShuffle(input_broadcastable, new_order).make_node(x)
+    return x.dimshuffle(new_order).owner


def get_normalized_batch_axes(
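
As a quick sanity check of the new interface (a sketch, not part of the commit), the keyword-only constructor and the `dimshuffle` helper now build the same graph from nothing more than the input's ndim:

```python
import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

x = pt.tensor3("x")  # 3d input, static shape (None, None, None)

# Direct construction with the new keyword-only signature
op = DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])
y1 = op(x)

# The helper method now only needs x.type.ndim, not the broadcast pattern
y2 = x.dimshuffle(["x", 2, "x", 0, 1])

print(y1.type.shape)  # (1, None, 1, None, None)
print(y2.type.shape)  # (1, None, 1, None, None)
```

Since `__props__` is now `(input_ndim, new_order, inplace)`, two instances built from inputs with different static broadcast patterns but the same ndim compare equal.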

pytensor/tensor/extra_ops.py

Lines changed: 0 additions & 6 deletions
@@ -41,7 +41,6 @@
 )
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import specify_broadcastable
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.variable import TensorVariable
@@ -609,11 +608,6 @@ def squeeze(x, axis=None):
         # Nothing could be squeezed
         return _x
 
-    # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
-    # We add a `specify_broadcastable` instead of raising.
-    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
-    _x = specify_broadcastable(_x, *non_broadcastable_axis)
-
     return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
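
The deleted workaround existed because `DimShuffle` used to raise at graph-construction time when asked to drop an axis that was not statically broadcastable; with the check moved into `make_node` (the dropped axis must have static length 1 or unknown), it is no longer needed. A minimal sketch of the post-commit behavior, assuming `squeeze` is reachable as `pt.squeeze`:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")         # static shape (None, None): axis 0 not known to be length 1
y = pt.squeeze(x, axis=0)  # graph construction now succeeds; length 1 is checked at runtime

f = pytensor.function([x], y)
print(f(np.zeros((1, 5), dtype=x.dtype)).shape)  # (5,)
```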

pytensor/tensor/math.py

Lines changed: 1 addition & 3 deletions
@@ -33,7 +33,6 @@
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
     CAReduce,
-    DimShuffle,
     Elemwise,
     get_normalized_batch_axes,
     scalar_elemwise,
@@ -2338,8 +2337,7 @@ def L_op(self, inp, out, grads):
             else:
                 new_dims.append(i)
                 i += 1
-        ds_op = DimShuffle(gz.type.broadcastable, new_dims)
-        gx = Elemwise(ps.second)(x, ds_op(gz))
+        gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
         return [gx]
 
     def R_op(self, inputs, eval_points):
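
For context, the `L_op` hunk above (which appears to belong to the `Max` Op, judging by the surrounding reduction logic; that attribution is an assumption) now routes the gradient's broadcast-restoring reshuffle through `gz.dimshuffle(new_dims)` instead of constructing the `DimShuffle` Op by hand. A rough sketch exercising that path:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
g = pytensor.grad(pt.max(x), x)  # L_op re-expands the reduced axes via gz.dimshuffle(new_dims)

f = pytensor.function([x], g)
print(f(np.array([[1.0, 3.0], [2.0, 0.0]], dtype=x.dtype)))  # 1.0 at the max entry, 0.0 elsewhere
```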

pytensor/tensor/variable.py

Lines changed: 2 additions & 2 deletions
@@ -344,8 +344,8 @@ def dimshuffle(self, *pattern):
         """
         if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
             pattern = pattern[0]
-        op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern)
-        return op(self)
+        ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
+        return ds_op(self)
 
     def flatten(self, ndim=1):
         return pt.basic.flatten(self, ndim)
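
`dimshuffle` still accepts the pattern either as varargs or as a single list/tuple; only the underlying Op construction changed. A minimal sketch:

```python
import pytensor.tensor as pt

x = pt.matrix("x")

# Both spellings reach the same keyword-only DimShuffle construction
y1 = x.dimshuffle(0, "x", 1)    # varargs pattern
y2 = x.dimshuffle([0, "x", 1])  # single list/tuple pattern

print(y1.type.shape)  # (None, 1, None)
print(y2.type.shape)  # (None, 1, None)
```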

tests/tensor/test_fft.py

Lines changed: 9 additions & 0 deletions
@@ -204,3 +204,12 @@ def f_irfft(inp):
             pytensor.config.floatX
         )
         utt.verify_grad(f_irfft, [inputs_val], eps=eps)
+
+    def test_rfft_expanded_dims_grad(self):
+        # Regression test for https://github.com/pymc-devs/pytensor/issues/969
+        def test_func(x):
+            return fft.rfft(x[None, :])
+
+        rng = np.random.default_rng(213)
+        inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
+        utt.verify_grad(test_func, [inputs_val], rng=rng)
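
Outside the test harness, the scenario from issue #969 looks roughly like this (a sketch; the test above checks the same gradient numerically with `utt.verify_grad`):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor import fft

x = pt.vector("x")
out = fft.rfft(x[None, :])       # expand dims, then rfft, as in the issue
g = pytensor.grad(out.sum(), x)  # previously failed inside DimShuffle.grad

f = pytensor.function([x], g)
print(f(np.ones(8, dtype=pytensor.config.floatX)).shape)  # (8,)
```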
