Commit 47f2dc9

Only require input_ndim and not input_broadcastable in DimShuffle

1 parent 6b3b818 commit 47f2dc9

24 files changed: +130 −181 lines

pytensor/sparse/sandbox/sp.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,6 @@
 from pytensor.tensor.math import dot
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.shape import reshape
-from pytensor.tensor.subtensor import DimShuffle


 def register_specialize(lopt, *tags, **kwargs):
@@ -375,7 +374,7 @@ def convolve(
         [images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
     )
     tensout = reshape(output, newshp, ndim=3)
-    output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout)
+    output = tensout.transpose(0, 2, 1)
     if flatten:
         output = pt.flatten(output, 2)

@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
     )
     out2 = reshape(out1, pshape, ndim=3)

-    out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2)
+    out3 = out2.transpose(0, 2, 1)

     return pt.flatten(out3, 2), outshp
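Both call sites now go through the user-facing helper instead of constructing the Op by hand. A minimal sketch of the equivalence (the tensor here is a hypothetical stand-in for the reshaped intermediate, not a variable from this file):

    import pytensor.tensor as pt

    tensout = pt.tensor3("tensout")
    via_op = tensout.dimshuffle(0, 2, 1)     # what DimShuffle((False,) * 3, (0, 2, 1)) used to build
    via_helper = tensout.transpose(0, 2, 1)  # same graph; ndim is inferred from the variable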

pytensor/tensor/basic.py

Lines changed: 2 additions & 2 deletions
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
         # No-op
         return _x

-    ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
+    ret = _x.dimshuffle(axes)

     if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
         ret.name = _x.name + ".T"
@@ -3518,7 +3518,7 @@ def grad(self, inp, grads):
                 newdims.append(i)
                 i += 1

-        gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
+        gx = gx.dimshuffle(newdims)
         assert gx.type.ndim == x.type.ndim
         assert all(
             s1 == s2
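`Variable.dimshuffle` derives the input's number of dimensions from the variable it is called on, so callers no longer have to compute a broadcastable pattern from `type.shape`. A small sketch of the pattern this commit converges on (the variable name is illustrative):

    import pytensor.tensor as pt

    x = pt.matrix("x")
    ret = x.dimshuffle((1, 0))  # equivalent to DimShuffle(input_ndim=2, new_order=(1, 0))(x)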

pytensor/tensor/elemwise.py

Lines changed: 64 additions & 95 deletions
@@ -1,5 +1,7 @@
+from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
+from typing import Literal

 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):

     Parameters
     ----------
-    input_broadcastable
-        The expected broadcastable pattern of the input
+    input_ndim
+        The expected number of dimensions of the input
     new_order
         A list representing the relationship between the input's
         dimensions and the output's dimensions. Each element of the
         list can either be an index or 'x'. Indices must be encoded
         as python integers, not pytensor symbolic integers.
-    inplace : bool, optional
-        If True (default), the output will be a view of the input.
+        Missing indices correspond to dropped dimensions.

     Notes
     -----
@@ -77,50 +78,45 @@ class DimShuffle(ExternalCOp):

     .. code-block:: python

-        DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
+        DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])

-    This `Op` will only work on 3d tensors with no broadcastable
-    dimensions. The first dimension will be broadcastable,
+    This `Op` will only work on 3d tensors.
+    The first dimension of the output will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
     shape (20, 30, 40), the resulting tensor will have dimensions
     (1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)

     .. code-block:: python

-        DimShuffle((True, False), [1])
+        DimShuffle(input_ndim=2, new_order=[1])

-    This `Op` will only work on 2d tensors with the first dimension
-    broadcastable.
-    The second dimension of the input tensor will be the first dimension of
-    the resulting tensor.
-    If the tensor has shape (1, 20), the resulting tensor will have shape
-    (20, ).
+    This `Op` will only work on 2d tensors with the first dimension broadcastable.
+    The second dimension of the input tensor will be the first dimension of the resulting tensor.
+    If the tensor has shape (1, 20), the resulting tensor will have shape (20,).

     Examples
     --------
     .. code-block:: python

-        DimShuffle((), ["x"])  # make a 0d (scalar) into a 1d vector
-        DimShuffle((False, False), [0, 1])  # identity
-        DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-        DimShuffle((False,), ["x", 0])  # make a row out of a 1d vector
-        # (N to 1xN)
-        DimShuffle((False,), [0, "x"])  # make a column out of a 1d vector
-        # (N to Nx1)
-        DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-        DimShuffle((False, False), [0, "x", 1])  # AxB to Ax1xB
-        DimShuffle((False, False), [1, "x", 0])  # AxB to Bx1xA
-
-    The reordering of the dimensions can be done with the numpy.transpose
-    function.
-    Adding, subtracting dimensions can be done with reshape.
+        DimShuffle(input_ndim=0, new_order=["x"])  # make a 0d (scalar) into a 1d vector
+        DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+        DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+        DimShuffle(input_ndim=1, new_order=["x", 0])  # make a row out of a 1d vector (N to 1xN)
+        DimShuffle(input_ndim=1, new_order=[0, "x"])  # make a column out of a 1d vector (N to Nx1)
+        DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+        DimShuffle(input_ndim=2, new_order=[0, "x", 1])  # AxB to Ax1xB
+        DimShuffle(input_ndim=2, new_order=[1, "x", 0])  # AxB to Bx1xA

+    Notes
+    -----
+    The Python implementation of this Op combines numpy.transpose for reordering the dimensions
+    and numpy.reshape for removing and adding broadcastable dimensions.
     """

     _f16_ok = True
     check_input = False
-    __props__ = ("input_broadcastable", "new_order", "inplace")
+    __props__ = ("input_ndim", "new_order", "inplace")
     c_func_file = "c_code/dimshuffle.c"
     c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
@@ -133,16 +129,14 @@ def params_type(self):
             inplace=scalar_bool,
         )

-    def __init__(self, input_broadcastable, new_order):
+    def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
         super().__init__([self.c_func_file], self.c_func_name)

-        self.input_broadcastable = tuple(input_broadcastable)
-        if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-            raise ValueError(
-                f"input_broadcastable must be boolean, {self.input_broadcastable}"
-            )
-        self.new_order = tuple(new_order)
+        if not isinstance(input_ndim, int):
+            raise TypeError(f"input_ndim must be an integer, got {type(input_ndim)}")

+        self.input_ndim = input_ndim
+        self.new_order = tuple(new_order)
         self.inplace = True

         for i, j in enumerate(new_order):
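The constructor is now keyword-only and parameterized by dimensionality rather than by a broadcastable pattern. A short sketch of the new calling convention (standalone, not taken from the diff):

    from pytensor.tensor.elemwise import DimShuffle

    transpose_2d = DimShuffle(input_ndim=2, new_order=(1, 0))  # matrix transpose
    make_row = DimShuffle(input_ndim=1, new_order=("x", 0))    # length-N vector -> 1xN row

    try:
        DimShuffle(input_ndim=2.0, new_order=(1, 0))           # not an int
    except TypeError as err:
        print(err)  # input_ndim must be an integer, ...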
@@ -152,10 +146,10 @@ def __init__(self, input_broadcastable, new_order):
                     "DimShuffle indices must be Python ints; got "
                     f"{j} of type {type(j)}."
                 )
-            if j >= len(input_broadcastable):
+            if j >= input_ndim:
                 raise ValueError(
                     f"new_order[{i}] is {j}, but the input only has "
-                    f"{len(input_broadcastable)} axes."
+                    f"{input_ndim} axes."
                 )
             if j in new_order[(i + 1) :]:
                 raise ValueError(
@@ -164,19 +158,7 @@ def __init__(self, input_broadcastable, new_order):
                 )

         # List of input dimensions to drop
-        drop = []
-        for i, b in enumerate(input_broadcastable):
-            if i not in new_order:
-                # We want to drop this dimension because it's not a value in
-                # `new_order`
-                if b == 1:
-                    drop.append(i)
-                else:
-                    # We cannot drop non-broadcastable dimensions
-                    raise ValueError(
-                        "Cannot drop a non-broadcastable dimension: "
-                        f"{input_broadcastable}, {new_order}"
-                    )
+        drop = [i for i in range(input_ndim) if i not in new_order]

         # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +168,6 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
         self.drop = drop

-        input_ndim = len(input_broadcastable)
         self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
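A worked example of the bookkeeping above, run outside the class for illustration: for input_ndim=3 and new_order=["x", 2, 0] the helper lists come out as

    input_ndim = 3
    new_order = ["x", 2, 0]

    drop = [i for i in range(input_ndim) if i not in new_order]       # [1]: axis 1 is dropped
    shuffle = [x for x in new_order if x != "x"]                      # [2, 0]: kept input axes, in output order
    augment = sorted(i for i, x in enumerate(new_order) if x == "x")  # [0]: output positions of new axes

Whether axis 1 may actually be dropped (i.e., has length 1) is no longer decided here; it is checked in make_node below.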
@@ -204,30 +185,29 @@ def __setstate__(self, state):
         # Let's just build the ExternalCOp.
         super().__init__([self.c_func_file], self.c_func_name)

-    def make_node(self, _input):
-        input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
+    def make_node(self, inp):
+        input = as_tensor_variable(inp)
+        if input.type.ndim != self.input_ndim:
+            raise TypeError(
+                "The number of dimensions of the input is incorrect for this op. "
+                f"Expected {self.input_ndim}, got {input.type.ndim}."
+            )
+
+        input_static_shape = input.type.shape
+
+        # Runtime check for invalid drop
+        for d in self.drop:
+            if input_static_shape[d] not in (1, None):
                 raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
                 )
-            for expected, b in zip(self.input_broadcastable, ib):
-                if expected and not b:
-                    raise TypeError(
-                        "The broadcastable pattern of the "
-                        f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                    )
-                # else, expected == b or not expected and b
-                # Both case are good.

         out_static_shape = []
         for dim_idx in self.new_order:
             if dim_idx == "x":
                 out_static_shape.append(1)
             else:
-                out_static_shape.append(input.type.shape[dim_idx])
+                out_static_shape.append(input_static_shape[dim_idx])

         output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()

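With the broadcastable pattern gone from the Op, dropping a dimension is validated against the input's static shape when the node is built, and only rejected when that shape proves the axis cannot have length 1. A sketch of the resulting behavior (variable names are illustrative):

    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import DimShuffle

    drop_first = DimShuffle(input_ndim=2, new_order=[1])

    row = pt.tensor("row", shape=(1, None))
    drop_first(row)  # accepted: the dropped axis is statically 1

    mat = pt.matrix("mat")  # static shape (None, None)
    drop_first(mat)  # also accepted at graph-build time; the length is only known at runtime

    wide = pt.tensor("wide", shape=(5, None))
    # drop_first(wide) raises TypeError: dropped dimension 0 must have length 1 but has 5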
@@ -254,12 +234,14 @@ def perform(self, node, inp, out):
         if not isinstance(res, np.ndarray | np.memmap):
             raise TypeError(res)

+        # Put dropped axes at the end
         res = res.transpose(self.transposition)

-        shape = list(res.shape[: len(self.shuffle)])
+        # Define the new shape without the dropped axes and with the new broadcastable ones
+        new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
-            shape.insert(augm, 1)
-        res = res.reshape(shape)
+            new_shape.insert(augm, 1)
+        res = res.reshape(new_shape)

         if not self.inplace:
             res = np.copy(res)
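The Python implementation is just a transpose followed by a reshape, as the new class docstring notes. A NumPy-only sketch of the same steps for new_order=["x", 1] on a (1, 20) input, with the helper lists precomputed as in __init__ (kept axes first, dropped axes last):

    import numpy as np

    x = np.zeros((1, 20))
    shuffle, drop, augment = [1], [0], [0]  # from new_order=["x", 1] with input_ndim=2
    transposition = shuffle + drop          # kept axes first, dropped axes at the end

    res = x.transpose(transposition)             # shape (20, 1)
    new_shape = list(res.shape[: len(shuffle)])  # keep only the shuffled axes -> [20]
    for augm in augment:
        new_shape.insert(augm, 1)                # insert the new broadcastable axes -> [1, 20]
    res = res.reshape(new_shape)                 # dropped axis removed, new axis added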
@@ -284,22 +266,15 @@ def R_op(self, inputs, eval_points):
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        gz = as_tensor_variable(gz)
         grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
-        # Do not make the DimShuffle inplace as an optimization at the
-        # canonicalization optimization phase will remove the inplace.
-        # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+
+        if x.type.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
-            return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                    Elemwise(scalar_identity)(gz)
-                )
-            ]
+            return [gz.dimshuffle(grad_order)]


 class DimShufflePrinter(Printer):
@@ -409,7 +384,7 @@ def __setstate__(self, d):
         self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)

-    def get_output_info(self, dim_shuffle, *inputs):
+    def get_output_info(self, *inputs):
         """Return the outputs dtype and broadcastable pattern and the
         dimshuffled inputs.

@@ -427,12 +402,7 @@ def get_output_info(self, dim_shuffle, *inputs):
             if not difference:
                 args.append(input)
             else:
-                args.append(
-                    dim_shuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(length)),
-                    )(input)
-                )
+                args.append(input.dimshuffle(["x"] * difference + list(range(length))))
         inputs = args

         # HERE: all the broadcast dims have the same length now
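In the grad hunk above, the gradient simply inverts the shuffle: each output position that came from an input axis points back to it, and everything else stays "x". A worked sketch of grad_order for new_order=["x", 1, 0] on a 2d input:

    x_ndim = 2
    new_order = ["x", 1, 0]

    grad_order = ["x"] * x_ndim
    for i, v in enumerate(new_order):
        if v != "x":
            grad_order[v] = i

    print(grad_order)  # [2, 1]: gz.dimshuffle([2, 1]) restores the shape of x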
@@ -489,7 +459,7 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
         outputs = [
             TensorType(dtype=dtype, shape=shape)()
             for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +604,7 @@ def transform(r):
             res = pytensor.tensor.basic.constant(
                 np.asarray(r.data), dtype=r.type.dtype
             )
-            return DimShuffle((), ["x"] * nd)(res)
+            return res.dimshuffle(["x"] * nd)

         new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
         if isinstance(new_r, list | tuple):
@@ -1707,13 +1677,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     if not batched_ndims:
         return node.op.make_node(x)
-    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
-    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
-    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
     new_order = list(range(batched_ndims)) + [
         "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
     ]
-    return DimShuffle(input_broadcastable, new_order).make_node(x)
+    return x.dimshuffle(new_order).owner


 def get_normalized_batch_axes(
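The order arithmetic in the rewritten comments, spelled out with the values of the first example:

    batched_ndims = 2
    op_new_order = (1, "x", 0)

    new_order = list(range(batched_ndims)) + [
        "x" if (o == "x") else (o + batched_ndims) for o in op_new_order
    ]
    print(new_order)  # [0, 1, 3, 'x', 2]: batch axes lead, original axes are shifted by 2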

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 6 deletions
@@ -41,7 +41,7 @@
 )
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import Shape_i, specify_broadcastable
+from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.variable import TensorVariable
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
         # Nothing could be squeezed
         return _x

-    # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
-    # We add a `specify_broadcastable` instead of raising.
-    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
-    _x = specify_broadcastable(_x, *non_broadcastable_axis)
-
     return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
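`squeeze` no longer needs the `specify_broadcastable` workaround, because `DimShuffle` itself now tolerates dropped axes whose length is unknown at graph-construction time. A sketch (the variable name is illustrative):

    import pytensor.tensor as pt

    x = pt.matrix("x")         # static shape (None, None)
    y = pt.squeeze(x, axis=0)  # graph builds; axis 0 is checked against length 1 at runtime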
pytensor/tensor/inplace.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from pytensor import printing
 from pytensor.printing import pprint
-from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise
+from pytensor.tensor.elemwise import scalar_elemwise


 @scalar_elemwise
@@ -429,4 +429,4 @@ def hyp2f1_inplace(a, b, c, z):
 def transpose_inplace(x, **kwargs):
     "Perform a transpose on a tensor without copying the underlying storage"
     dims = list(range(x.ndim - 1, -1, -1))
-    return DimShuffle(x.broadcastable, dims)(x)
+    return x.dimshuffle(dims)
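The same full-reversal pattern through the variable-level helper, for reference (standalone sketch):

    import pytensor.tensor as pt

    x = pt.tensor3("x")
    dims = list(range(x.ndim - 1, -1, -1))  # [2, 1, 0]
    x_T = x.dimshuffle(dims)                # full transpose; no broadcastable pattern required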
