
Commit 6a5ddb6

Implement join_dims as mirror of split_dims
Mainly, joining 0 axes is equivalent to inserting a new dimension. This is the mirror of how splitting a single axis into an empty shape is equivalent to squeezing it.
1 parent bbb9ac8 commit 6a5ddb6
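
As a quick illustration of the mirrored behavior described in the commit message, here is a minimal sketch (the split_dims call with an empty shape is an assumption based on that description, not a line from this diff):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))

# Joining zero axes at position 1 inserts a new length-1 dimension there.
y = pt.join_dims(x, start_axis=1, n_axes=0)
assert y.type.shape == (2, 1, 3)

# Mirror image: splitting that length-1 axis into an empty shape squeezes it again
# (assumed behavior, following the commit message).
z = pt.split_dims(y, shape=(), axis=1)
assert z.type.shape == (2, 3)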

4 files changed (+83, -94 lines)

pytensor/tensor/reshape.py

Lines changed: 49 additions & 72 deletions

@@ -12,7 +12,7 @@
 from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
+from pytensor.tensor.basic import infer_static_shape, join, split
 from pytensor.tensor.math import prod
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
@@ -24,10 +24,7 @@


 class JoinDims(Op):
-    __props__ = (
-        "start_axis",
-        "n_axes",
-    )
+    __props__ = ("start_axis", "n_axes")
     view_map = {0: [0]}

     def __init__(self, start_axis: int, n_axes: int):
@@ -55,6 +52,8 @@ def make_node(self, x: Variable) -> Apply:  # type: ignore[override]

         static_shapes = x.type.shape
         axis_range = self.axis_range
+        if (self.start_axis + self.n_axes) > x.type.ndim:
+            raise ValueError("JoinDims requested to join axes that are not available")

         joined_shape = (
             int(np.prod([static_shapes[i] for i in axis_range]))
@@ -69,9 +68,7 @@ def make_node(self, x: Variable) -> Apply:  # type: ignore[override]

     def infer_shape(self, fgraph, node, shapes):
         [input_shape] = shapes
-        axis_range = self.axis_range
-
-        joined_shape = prod([input_shape[i] for i in axis_range], dtype=int)
+        joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int)
         return [self.output_shapes(input_shape, joined_shape)]

     def perform(self, node, inputs, outputs):
@@ -98,23 +95,24 @@ def L_op(self, inputs, outputs, output_grads):
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
     [old_x] = node.inputs
-
     batched_ndims = x.type.ndim - old_x.type.ndim
-    start_axis = op.start_axis
-    n_axes = op.n_axes
+    return JoinDims(op.start_axis + batched_ndims, op.n_axes).make_node(x)

-    return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)

-
-def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
+def join_dims(
+    x: TensorLike, start_axis: int = 0, n_axes: int | None = None
+) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.

     Parameters
     ----------
     x : TensorLike
         The input tensor.
-    axis : int or sequence of int, optional
-        The dimensions to join. If None, all dimensions are joined.
+    start_axis : int, default 0
+        The axis from which to start joining dimensions.
+    n_axes : int, optional
+        The number of axes to join starting at `start_axis`. If `None`, joins all remaining axes.
+        If 0, it inserts a new dimension of length 1.

     Returns
     -------
@@ -125,33 +123,31 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(2, 3, 4, 5))
-    >>> y = pt.join_dims(x, axis=(1, 2))
+    >>> y = pt.join_dims(x)
+    >>> y.type.shape
+    (120,)
+    >>> y = pt.join_dims(x, start_axis=1)
+    >>> y.type.shape
+    (2, 60)
+    >>> y = pt.join_dims(x, start_axis=1, n_axes=2)
     >>> y.type.shape
     (2, 12, 5)
     """
-    x = as_tensor_variable(x)
-
-    if axis is None:
-        axis = list(range(x.ndim))
-    elif isinstance(axis, int):
-        axis = [axis]
-    elif not isinstance(axis, list | tuple):
-        raise TypeError("axis must be an int, a list/tuple of ints, or None")
+    ndim = x.ndim

-    axis = normalize_axis_tuple(axis, x.ndim)
+    if start_axis < 0:
+        # We treat scalars as if they had a single axis
+        start_axis += max(1, ndim)

-    if len(axis) <= 1:
-        return x  # type: ignore[unreachable]
-
-    if np.diff(axis).max() > 1:
-        raise ValueError(
-            f"join_dims axis must be consecutive, got normalized axis: {axis}"
+    if not 0 <= start_axis <= ndim:
+        raise IndexError(
+            f"Axis {start_axis} is out of bounds for array of dimension {ndim}"
         )

-    start_axis = min(axis)
-    n_axes = len(axis)
+    if n_axes is None:
+        n_axes = ndim - start_axis

-    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
+    return JoinDims(start_axis, n_axes)(x)


 class SplitDims(Op):
@@ -213,11 +209,11 @@ def connection_pattern(self, node):
     def L_op(self, inputs, outputs, output_grads):
         (x, _) = inputs
         (g_out,) = output_grads
-
         n_axes = g_out.ndim - x.ndim + 1
-        axis_range = list(range(self.axis, self.axis + n_axes))
-
-        return [join_dims(g_out, axis=axis_range), disconnected_type()]
+        return [
+            join_dims(g_out, start_axis=self.axis, n_axes=n_axes),
+            disconnected_type(),
+        ]


 @_vectorize_node.register(SplitDims)
@@ -230,14 +226,13 @@ def _vectorize_splitdims(op, node, x, shape):
     if as_tensor_variable(shape).type.ndim != 1:
         return vectorize_node_fallback(op, node, x, shape)

-    axis = op.axis
-    return SplitDims(axis=axis + batched_ndims).make_node(x, shape)
+    return SplitDims(axis=op.axis + batched_ndims).make_node(x, shape)


 def split_dims(
     x: TensorLike,
     shape: ShapeValueType | Sequence[ShapeValueType],
-    axis: int | None = None,
+    axis: int = 0,
 ) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.

@@ -247,8 +242,8 @@ def split_dims(
         The input tensor.
     shape : int or sequence of int
         The new shape to split the specified dimension into.
-    axis : int, optional
-        The dimension to split. If None, the input is assumed to be 1D and axis 0 is used.
+    axis : int, default 0
+        The dimension to split.

     Returns
     -------
@@ -259,22 +254,18 @@ def split_dims(
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(6, 4, 6))
-    >>> y = pt.split_dims(x, shape=(2, 3), axis=0)
+    >>> y = pt.split_dims(x, shape=(2, 3))
     >>> y.type.shape
     (2, 3, 4, 6)
+    >>> y = pt.split_dims(x, shape=(2, 3), axis=-1)
+    >>> y.type.shape
+    (6, 4, 2, 3)
     """
     x = as_tensor_variable(x)
-
-    if axis is None:
-        if x.type.ndim != 1:
-            raise ValueError(
-                "split_dims can only be called with axis=None for 1d inputs"
-            )
-        axis = 0
-    else:
-        axis = normalize_axis_index(axis, x.ndim)
+    axis = normalize_axis_index(axis, x.ndim)

     # Convert scalar shape to 1d tuple (shape,)
+    # Which is basically a specify_shape
     if not isinstance(shape, Sequence):
         if isinstance(shape, TensorVariable | np.ndarray):
             if shape.ndim == 0:
@@ -313,8 +304,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
     elif not isinstance(axes, Iterable):
         raise TypeError("axes must be an int, an iterable of ints, or None")

-    axes = tuple(axes)
-
     if len(axes) == 0:
         raise ValueError("axes=[] is ambiguous; use None to ravel all")

@@ -464,22 +453,10 @@ def pack(
                 f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
                 f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
             )
-        n_after_packed = n_dim - n_after
-        packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
-
-        if n_dim == min_axes:
-            # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
-            # implied by the axes.
-            input_tensor = expand_dims(input_tensor, axis=n_before)
-            reshaped_tensors.append(input_tensor)
-            continue
-
-        # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
-        # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
-        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
-        # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
-        join_axes = range(n_before, n_after_packed)
-        joined = join_dims(input_tensor, tuple(join_axes))
+
+        n_packed = n_dim - n_after - n_before
+        packed_shapes.append(input_tensor.shape[n_before : n_before + n_packed])
+        joined = join_dims(input_tensor, n_before, n_packed)
         reshaped_tensors.append(joined)

     return join(n_before, *reshaped_tensors), packed_shapes
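
For orientation, a short usage sketch of the new join_dims signature; the shapes follow the updated doctests and tests, and the old axis-tuple call appears only as a comment for comparison:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 4, 5))

# New style: a start axis plus a count of axes to join, instead of an axis tuple.
y = pt.join_dims(x, start_axis=1, n_axes=2)  # previously join_dims(x, axis=(1, 2))
assert y.type.shape == (2, 12, 5)

# Omitting n_axes joins everything from start_axis onward; omitting both ravels x.
assert pt.join_dims(x, start_axis=2).type.shape == (2, 3, 20)
assert pt.join_dims(x).type.shape == (120,)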

pytensor/tensor/rewriting/reshape.py

Lines changed: 3 additions & 2 deletions

@@ -34,8 +34,9 @@ def local_join_dims_to_reshape(fgraph, node):
     """

     (x,) = node.inputs
-    start_axis = node.op.start_axis
-    n_axes = node.op.n_axes
+    op = node.op
+    start_axis = op.start_axis
+    n_axes = op.n_axes

     output_shape = [
         *x.shape[:start_axis],
tests/tensor/rewriting/test_reshape.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ def test_local_split_dims_to_reshape():


 def test_local_join_dims_to_reshape():
     x = tensor("x", shape=(2, 2, 5, 1, 3))
-    x_join = join_dims(x, axis=(1, 2, 3))
+    x_join = join_dims(x, start_axis=1, n_axes=3)

     fg = FunctionGraph(inputs=[x], outputs=[x_join])

tests/tensor/test_reshape.py

Lines changed: 30 additions & 19 deletions

@@ -20,32 +20,43 @@ def test_join_dims():
     rng = np.random.default_rng()

     x = pt.tensor("x", shape=(2, 3, 4, 5))
-    assert join_dims(x, axis=(0, 1)).type.shape == (6, 4, 5)
-    assert join_dims(x, axis=(1, 2)).type.shape == (2, 12, 5)
-    assert join_dims(x, axis=(-1, -2)).type.shape == (2, 3, 20)
+    assert join_dims(x).type.shape == (120,)
+    assert join_dims(x, n_axes=1).type.shape == (2, 3, 4, 5)
+    assert join_dims(x, n_axes=0).type.shape == (1, 2, 3, 4, 5)

-    assert join_dims(x, axis=()).type.shape == (2, 3, 4, 5)
-    assert join_dims(x, axis=(2,)).type.shape == (2, 3, 4, 5)
+    assert join_dims(x, n_axes=2).type.shape == (6, 4, 5)
+    assert join_dims(x, start_axis=1, n_axes=2).type.shape == (2, 12, 5)
+    assert join_dims(x, start_axis=-3, n_axes=2).type.shape == (2, 12, 5)
+    assert join_dims(x, start_axis=2).type.shape == (2, 3, 20)
+
+    with pytest.raises(
+        IndexError,
+        match=r"Axis 5 is out of bounds for array of dimension 4",
+    ):
+        join_dims(x, start_axis=5)

     with pytest.raises(
         ValueError,
-        match=r"join_dims axis must be consecutive, got normalized axis: \(0, 2\)",
+        match=r"JoinDims requested to join axes that are not available",
     ):
-        _ = join_dims(x, axis=(0, 2)).type.shape == (8, 3, 5)
+        join_dims(x, n_axes=5)

-    x_joined = join_dims(x, axis=(1, 2))
     x_value = rng.normal(size=(2, 3, 4, 5)).astype(config.floatX)
-
-    fn = function([x], x_joined, mode="FAST_COMPILE")
-
-    x_joined_value = fn(x_value)
-    np.testing.assert_allclose(x_joined_value, x_value.reshape(2, 12, 5))
-
-    assert join_dims(x, axis=(1,)).eval({x: x_value}).shape == (2, 3, 4, 5)
-    assert join_dims(x, axis=()).eval({x: x_value}).shape == (2, 3, 4, 5)
+    np.testing.assert_allclose(
+        join_dims(x, start_axis=1, n_axes=2).eval({x: x_value}),
+        x_value.reshape(2, 12, 5),
+    )
+    assert join_dims(x, start_axis=1, n_axes=1).eval({x: x_value}).shape == (2, 3, 4, 5)
+    assert join_dims(x, start_axis=1, n_axes=0).eval({x: x_value}).shape == (
+        2,
+        1,
+        3,
+        4,
+        5,
+    )

     x = pt.tensor("x", shape=(3, 5))
-    x_joined = join_dims(x, axis=(0, 1))
+    x_joined = join_dims(x)
     x_batched = pt.tensor("x_batched", shape=(10, 3, 5))
     x_joined_batched = vectorize_graph(x_joined, {x: x_batched})

@@ -54,7 +65,7 @@ def test_join_dims():
     x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
     assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)

-    utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value])
+    utt.verify_grad(lambda x: join_dims(x, start_axis=1, n_axes=2), [x_value])


 @pytest.mark.parametrize(
@@ -289,7 +300,7 @@ def test_pack_unpack_round_trip(self, axes):
         output_vals = fn(**input_dict)

         for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
-            np.testing.assert_allclose(input_val, output_val)
+            np.testing.assert_allclose(input_val, output_val, strict=True)


 def test_unpack_connection():