diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py
index 9bea7b87e8..98024f28ca 100644
--- a/pytensor/tensor/reshape.py
+++ b/pytensor/tensor/reshape.py
@@ -12,7 +12,7 @@
 from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
+from pytensor.tensor.basic import infer_static_shape, join, split
 from pytensor.tensor.math import prod
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
@@ -24,10 +24,7 @@ class JoinDims(Op):
-    __props__ = (
-        "start_axis",
-        "n_axes",
-    )
+    __props__ = ("start_axis", "n_axes")
     view_map = {0: [0]}

     def __init__(self, start_axis: int, n_axes: int):
@@ -55,6 +52,11 @@
     def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
         axis_range = self.axis_range
+        if (self.start_axis + self.n_axes) > x.type.ndim:
+            raise ValueError(
+                f"JoinDims was asked to join dimensions {self.start_axis} to {self.start_axis + self.n_axes}, "
+                f"but input {x} has only {x.type.ndim} dimensions."
+            )

         joined_shape = (
             int(np.prod([static_shapes[i] for i in axis_range]))
@@ -69,9 +71,7 @@

     def infer_shape(self, fgraph, node, shapes):
         [input_shape] = shapes
-        axis_range = self.axis_range
-
-        joined_shape = prod([input_shape[i] for i in axis_range])
+        joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int)
         return [self.output_shapes(input_shape, joined_shape)]

     def perform(self, node, inputs, outputs):
@@ -98,23 +98,24 @@ def L_op(self, inputs, outputs, output_grads):
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
     [old_x] = node.inputs
     batched_ndims = x.type.ndim - old_x.type.ndim
-    start_axis = op.start_axis
-    n_axes = op.n_axes
-
-    return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
+    return JoinDims(op.start_axis + batched_ndims, op.n_axes).make_node(x)


-def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
+def join_dims(
+    x: TensorLike, start_axis: int = 0, n_axes: int | None = None
+) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.

     Parameters
     ----------
     x : TensorLike
         The input tensor.
-    axis : int or sequence of int, optional
-        The dimensions to join. If None, all dimensions are joined.
+    start_axis : int, default 0
+        The axis from which to start joining dimensions.
+    n_axes : int, optional
+        The number of axes to join, counting from `start_axis`. If `None`, all
+        remaining axes are joined. If 0, a new dimension of length 1 is inserted.

     Returns
     -------
@@ -125,33 +126,32 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(2, 3, 4, 5))
-    >>> y = pt.join_dims(x, axis=(1, 2))
+    >>> y = pt.join_dims(x)
+    >>> y.type.shape
+    (120,)
+    >>> y = pt.join_dims(x, start_axis=1)
+    >>> y.type.shape
+    (2, 60)
+    >>> y = pt.join_dims(x, start_axis=1, n_axes=2)
     >>> y.type.shape
     (2, 12, 5)
     """
     x = as_tensor_variable(x)
+    ndim = x.type.ndim

-    if axis is None:
-        axis = list(range(x.ndim))
-    elif isinstance(axis, int):
-        axis = [axis]
-    elif not isinstance(axis, list | tuple):
-        raise TypeError("axis must be an int, a list/tuple of ints, or None")
-
-    axis = normalize_axis_tuple(axis, x.ndim)
+    if start_axis < 0:
+        # We treat scalars as if they had a single axis
+        start_axis += max(1, ndim)

-    if len(axis) <= 1:
-        return x  # type: ignore[unreachable]
-
-    if np.diff(axis).max() > 1:
-        raise ValueError(
-            f"join_dims axis must be consecutive, got normalized axis: {axis}"
+    if not 0 <= start_axis <= ndim:
+        raise IndexError(
+            f"Axis {start_axis} is out of bounds for array of dimension {ndim}"
         )

-    start_axis = min(axis)
-    n_axes = len(axis)
+    if n_axes is None:
+        n_axes = ndim - start_axis

-    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
+    return JoinDims(start_axis, n_axes)(x)  # type: ignore[return-value]


 class SplitDims(Op):
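As a quick sketch of the new calling convention (shapes assumed purely for illustration): joining a block of axes matches the reshape that the `local_join_dims_to_reshape` rewrite further down constructs.

.. code-block:: python

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4, 5))
    start_axis, n_axes = 1, 2

    joined = pt.join_dims(x, start_axis=start_axis, n_axes=n_axes)
    assert joined.type.shape == (2, 12, 5)

    # The same result as a plain reshape: keep the axes outside the block
    # and collapse the block itself into a single -1 entry.
    output_shape = [*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]]
    reference = x.reshape(output_shape)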
@@ -213,11 +213,11 @@ def connection_pattern(self, node):
     def L_op(self, inputs, outputs, output_grads):
         (x, _) = inputs
         (g_out,) = output_grads
         n_axes = g_out.ndim - x.ndim + 1
-        axis_range = list(range(self.axis, self.axis + n_axes))
-
-        return [join_dims(g_out, axis=axis_range), disconnected_type()]
+        return [
+            join_dims(g_out, start_axis=self.axis, n_axes=n_axes),
+            disconnected_type(),
+        ]


 @_vectorize_node.register(SplitDims)
@@ -230,14 +230,13 @@ def _vectorize_splitdims(op, node, x, shape):
     if as_tensor_variable(shape).type.ndim != 1:
         return vectorize_node_fallback(op, node, x, shape)

-    axis = op.axis
-    return SplitDims(axis=axis + batched_ndims).make_node(x, shape)
+    return SplitDims(axis=op.axis + batched_ndims).make_node(x, shape)


 def split_dims(
     x: TensorLike,
     shape: ShapeValueType | Sequence[ShapeValueType],
-    axis: int | None = None,
+    axis: int = 0,
 ) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.

@@ -247,8 +246,8 @@ def split_dims(
         The input tensor.
     shape : int or sequence of int
         The new shape to split the specified dimension into.
-    axis : int, optional
-        The dimension to split. If None, the input is assumed to be 1D and axis 0 is used.
+    axis : int, default 0
+        The dimension to split.

     Returns
     -------
@@ -259,22 +258,18 @@
     Examples
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(6, 4, 6))
-    >>> y = pt.split_dims(x, shape=(2, 3), axis=0)
+    >>> y = pt.split_dims(x, shape=(2, 3))
     >>> y.type.shape
     (2, 3, 4, 6)
+    >>> y = pt.split_dims(x, shape=(2, 3), axis=-1)
+    >>> y.type.shape
+    (6, 4, 2, 3)
     """
     x = as_tensor_variable(x)
-
-    if axis is None:
-        if x.type.ndim != 1:
-            raise ValueError(
-                "split_dims can only be called with axis=None for 1d inputs"
-            )
-        axis = 0
-    else:
-        axis = normalize_axis_index(axis, x.ndim)
+    axis = normalize_axis_index(axis, x.ndim)

     # Convert scalar shape to 1d tuple (shape,)
+    # Which is basically a specify_shape
     if not isinstance(shape, Sequence):
         if isinstance(shape, TensorVariable | np.ndarray):
             if shape.ndim == 0:
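A matching sketch for the inverse direction, reusing the docstring's shapes: `split_dims` expands a single axis into the given shape, and negative axes are normalized first.

.. code-block:: python

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(6, 4, 6))

    # axis defaults to 0, so this splits the leading dimension.
    y = pt.split_dims(x, shape=(2, 3))
    assert y.type.shape == (2, 3, 4, 6)

    # axis=-1 splits the trailing dimension instead.
    z = pt.split_dims(x, shape=(2, 3), axis=-1)
    assert z.type.shape == (6, 4, 2, 3)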
@@ -313,8 +308,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
     elif not isinstance(axes, Iterable):
         raise TypeError("axes must be an int, an iterable of ints, or None")

-    axes = tuple(axes)
-
     if len(axes) == 0:
         raise ValueError("axes=[] is ambiguous; use None to ravel all")
@@ -367,7 +360,7 @@ def find_gaps(s):


 def pack(
-    *tensors: TensorLike, axes: Sequence[int] | int | None = None
+    *tensors: TensorLike, keep_axes: Sequence[int] | int | None = None
 ) -> tuple[TensorVariable, list[TensorVariable]]:
     """
     Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
@@ -401,8 +394,8 @@
     Examples
     --------
-    The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent
-    to ``join(0, *[t.ravel() for t in tensors])``:
+    The easiest way to understand pack is through examples.
+    The simplest case is using the default keep_axes=None, which is equivalent to ``concatenate([t.ravel() for t in tensors])``:

     .. code-block:: python

         import pytensor.tensor as pt

         x = pt.tensor("x", shape=(2, 3))
         y = pt.tensor("y", shape=(4, 5, 6))
-        packed_tensor, packed_shapes = pt.pack(x, y, axes=None)
+        packed_tensor, packed_shapes = pt.pack(x, y)

         # packed_tensor has shape (6 + 120,) == (126,)
         # packed_shapes is [(2, 3), (4, 5, 6)]

-    If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors
-    must have the same size along the preserved axis. For example, using axes=0:
+    If we want to preserve a single axis, we can use either positive or negative indexing.
+    Notice that all tensors must have the same size along the preserved axis.
+    For example, using keep_axes=0:

     .. code-block:: python

         import pytensor.tensor as pt

         x = pt.tensor("x", shape=(2, 3))
         y = pt.tensor("y", shape=(2, 5, 6))
-        packed_tensor, packed_shapes = pt.pack(x, y, axes=0)
+        packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=0)

         # packed_tensor has shape (2, 3 + 30) == (2, 33)
         # packed_shapes is [(3,), (5, 6)]

@@ -434,7 +428,7 @@
         x = pt.tensor("x", shape=(4, 2, 3))
         y = pt.tensor("y", shape=(5, 2, 3))
-        packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1))
+        packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=(-2, -1))

         # packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3)
         # packed_shapes is [(4,), (5,)]
@@ -445,13 +439,13 @@
         x = pt.tensor("x", shape=(2, 4, 3))
         y = pt.tensor("y", shape=(2, 5, 3))
-        packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1))
+        packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=(0, -1))

         # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
         # packed_shapes is [(4,), (5,)]
     """
     tensor_list = [as_tensor_variable(t) for t in tensors]
-    n_before, n_after, min_axes = _analyze_axes_list(axes)
+    n_before, n_after, min_axes = _analyze_axes_list(keep_axes)

     reshaped_tensors: list[Variable] = []
     packed_shapes: list[TensorVariable] = []
@@ -462,24 +456,12 @@
         if n_dim < min_axes:
             raise ValueError(
                 f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
-                f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
+                f"but {keep_axes=} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
             )
-        n_after_packed = n_dim - n_after
-        packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
-
-        if n_dim == min_axes:
-            # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
-            # implied by the axes.
-            input_tensor = expand_dims(input_tensor, axis=n_before)
-            reshaped_tensors.append(input_tensor)
-            continue
-
-        # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
-        # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
-        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
-        # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
-        join_axes = range(n_before, n_after_packed)
-        joined = join_dims(input_tensor, tuple(join_axes))
+
+        n_packed = n_dim - n_after - n_before
+        packed_shapes.append(input_tensor.shape[n_before : n_before + n_packed])
+        joined = join_dims(input_tensor, n_before, n_packed)
         reshaped_tensors.append(joined)

     return join(n_before, *reshaped_tensors), packed_shapes
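For intuition, a hand-rolled version of what `pack` does with `keep_axes=0`, built from the pieces above (shapes assumed for illustration; `pack` additionally returns the `packed_shapes` needed to undo the operation).

.. code-block:: python

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3))
    y = pt.tensor("y", shape=(2, 5, 6))

    packed, packed_shapes = pt.pack(x, y, keep_axes=0)
    assert packed.type.shape == (2, 33)

    # Roughly equivalent by hand: ravel everything after the kept axis,
    # then join the pieces along the new flat axis.
    manual = pt.join(1, pt.join_dims(x, start_axis=1), pt.join_dims(y, start_axis=1))
    assert manual.type.shape == (2, 33)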
@@ -487,8 +469,8 @@
 def unpack(
     packed_input: TensorLike,
-    axes: int | Sequence[int] | None,
     packed_shapes: Sequence[ShapeValueType],
+    keep_axes: int | Sequence[int] | None = None,
 ) -> list[TensorVariable]:
     """
     Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
@@ -504,10 +486,10 @@
     ----------
     packed_input : TensorLike
         The packed tensor to be unpacked.
-    axes : int, sequence of int, or None
-        Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used.
     packed_shapes : list of ShapeValueType
         A list containing the shapes of the raveled dimensions for each output tensor.
+    keep_axes : int, sequence of int, optional
+        Axes that were preserved during packing. Default is None, in which case the
+        packed input must be 1D and is split along axis 0.

     Returns
     -------
     unpacked_tensors : list of TensorVariable
         A list of unpacked tensors with their original shapes restored.
     """
     packed_input = as_tensor_variable(packed_input)
-
-    if axes is None:
+    if keep_axes is None:
         if packed_input.ndim != 1:
             raise ValueError(
-                "unpack can only be called with keep_axis=None for 1d inputs"
+                "unpack can only be called with keep_axes=None for 1d inputs"
             )
         split_axis = 0
     else:
-        axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
+        keep_axes = normalize_axis_tuple(keep_axes, ndim=packed_input.ndim)
         try:
-            [split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
+            [split_axis] = (i for i in range(packed_input.ndim) if i not in keep_axes)
         except ValueError as err:
             raise ValueError(
-                "Unpack must have exactly one more dimension that implied by axes"
+                f"unpack input must have exactly one more dimension than implied by keep_axes. "
+                f"{packed_input} has {packed_input.type.ndim} dimensions, expected {len(keep_axes) + 1}"
             ) from err

-    split_inputs = split(
-        packed_input,
-        splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
-        n_splits=len(packed_shapes),
-        axis=split_axis,
-    )
+    n_splits = len(packed_shapes)
+    if n_splits == 1:
+        # If there is only one tensor to unpack, there is no need to split
+        split_inputs = [packed_input]
+    else:
+        split_inputs = split(
+            packed_input,
+            splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
+            n_splits=n_splits,
+            axis=split_axis,
+        )

     return [
         split_dims(inp, shape, split_axis)
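And the round trip that the tests below exercise, sketched with assumed shapes and zero-filled values.

.. code-block:: python

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3))
    y = pt.tensor("y", shape=(2, 5, 6))

    packed, packed_shapes = pt.pack(x, y, keep_axes=0)
    x2, y2 = pt.unpack(packed, packed_shapes, keep_axes=0)

    fn = pytensor.function([x, y], [x2, y2])
    x_val = np.zeros((2, 3), dtype=x.dtype)
    y_val = np.zeros((2, 5, 6), dtype=y.dtype)
    x_out, y_out = fn(x_val, y_val)
    assert x_out.shape == (2, 3) and y_out.shape == (2, 5, 6)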
diff --git a/pytensor/tensor/rewriting/reshape.py b/pytensor/tensor/rewriting/reshape.py
index ea3ff6d9c9..ab330551f3 100644
--- a/pytensor/tensor/rewriting/reshape.py
+++ b/pytensor/tensor/rewriting/reshape.py
@@ -34,8 +34,9 @@ def local_join_dims_to_reshape(fgraph, node):
     """
     (x,) = node.inputs

-    start_axis = node.op.start_axis
-    n_axes = node.op.n_axes
+    op = node.op
+    start_axis = op.start_axis
+    n_axes = op.n_axes

     output_shape = [
         *x.shape[:start_axis],
diff --git a/tests/tensor/rewriting/test_reshape.py b/tests/tensor/rewriting/test_reshape.py
index 59408f40fd..d18e0b6419 100644
--- a/tests/tensor/rewriting/test_reshape.py
+++ b/tests/tensor/rewriting/test_reshape.py
@@ -21,7 +21,7 @@ def test_local_split_dims_to_reshape():
 def test_local_join_dims_to_reshape():
     x = tensor("x", shape=(2, 2, 5, 1, 3))

-    x_join = join_dims(x, axis=(1, 2, 3))
+    x_join = join_dims(x, start_axis=1, n_axes=3)

     fg = FunctionGraph(inputs=[x], outputs=[x_join])
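For reference, a sketch of the equivalence this rewrite test checks, assuming `local_join_dims_to_reshape` is reachable from the canonicalize database and that `rewrite_graph` lives at its usual location.

.. code-block:: python

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.tensor("x", shape=(2, 2, 5, 1, 3))
    x_join = pt.join_dims(x, start_axis=1, n_axes=3)
    assert x_join.type.shape == (2, 10, 3)

    # After rewriting, the JoinDims node should be replaced by a Reshape that
    # computes [*x.shape[:1], -1, *x.shape[4:]], i.e. the same (2, 10, 3) output.
    rewritten = rewrite_graph(x_join, include=("canonicalize",))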
diff --git a/tests/tensor/test_reshape.py b/tests/tensor/test_reshape.py
index 65da41c051..67a9422a58 100644
--- a/tests/tensor/test_reshape.py
+++ b/tests/tensor/test_reshape.py
@@ -20,32 +20,37 @@ def test_join_dims():
     rng = np.random.default_rng()
     x = pt.tensor("x", shape=(2, 3, 4, 5))

-    assert join_dims(x, axis=(0, 1)).type.shape == (6, 4, 5)
-    assert join_dims(x, axis=(1, 2)).type.shape == (2, 12, 5)
-    assert join_dims(x, axis=(-1, -2)).type.shape == (2, 3, 20)
+    assert join_dims(x).type.shape == (120,)
+    assert join_dims(x, n_axes=1).type.shape == (2, 3, 4, 5)
+    assert join_dims(x, n_axes=0).type.shape == (1, 2, 3, 4, 5)

-    assert join_dims(x, axis=()).type.shape == (2, 3, 4, 5)
-    assert join_dims(x, axis=(2,)).type.shape == (2, 3, 4, 5)
+    assert join_dims(x, n_axes=2).type.shape == (6, 4, 5)
+    assert join_dims(x, start_axis=1, n_axes=2).type.shape == (2, 12, 5)
+    assert join_dims(x, start_axis=-3, n_axes=2).type.shape == (2, 12, 5)
+    assert join_dims(x, start_axis=2).type.shape == (2, 3, 20)
+
+    with pytest.raises(
+        IndexError,
+        match=r"Axis 5 is out of bounds for array of dimension 4",
+    ):
+        join_dims(x, start_axis=5)

     with pytest.raises(
         ValueError,
-        match=r"join_dims axis must be consecutive, got normalized axis: \(0, 2\)",
+        match=r"JoinDims was asked to join dimensions 0 to 5, but input x has only 4 dimensions\.",
     ):
-        _ = join_dims(x, axis=(0, 2)).type.shape == (8, 3, 5)
+        join_dims(x, n_axes=5)

-    x_joined = join_dims(x, axis=(1, 2))
     x_value = rng.normal(size=(2, 3, 4, 5)).astype(config.floatX)
-
-    fn = function([x], x_joined, mode="FAST_COMPILE")
-
-    x_joined_value = fn(x_value)
-    np.testing.assert_allclose(x_joined_value, x_value.reshape(2, 12, 5))
-
-    assert join_dims(x, axis=(1,)).eval({x: x_value}).shape == (2, 3, 4, 5)
-    assert join_dims(x, axis=()).eval({x: x_value}).shape == (2, 3, 4, 5)
+    np.testing.assert_allclose(
+        join_dims(x, start_axis=1, n_axes=2).eval({x: x_value}),
+        x_value.reshape(2, 12, 5),
+    )
+    assert join_dims(x, 1, n_axes=1).eval({x: x_value}).shape == (2, 3, 4, 5)
+    assert join_dims(x, 1, n_axes=0).eval({x: x_value}).shape == (2, 1, 3, 4, 5)

     x = pt.tensor("x", shape=(3, 5))
-    x_joined = join_dims(x, axis=(0, 1))
+    x_joined = join_dims(x)

     x_batched = pt.tensor("x_batched", shape=(10, 3, 5))
     x_joined_batched = vectorize_graph(x_joined, {x: x_batched})
@@ -54,7 +59,7 @@
     x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
     assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)

-    utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value])
+    utt.verify_grad(lambda x: join_dims(x, start_axis=1, n_axes=2), [x_value])


 @pytest.mark.parametrize(
@@ -117,9 +122,9 @@ def test_make_replacements_with_pack_unpack():
     loss = (x + y.sum() + z.sum()) ** 2

-    flat_packed, packed_shapes = pack(x, y, z, axes=None)
+    flat_packed, packed_shapes = pack(x, y, z)
     new_input = flat_packed.type()
-    new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
+    new_outputs = unpack(new_input, packed_shapes=packed_shapes)

     loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
     rewrite_graph(loss, include=("ShapeOpt", "canonicalize"))
@@ -198,7 +203,7 @@ def test_pack_basic(self):
         }

         # Simple case, reduce all axes, equivalent to einops '*'
-        packed_tensor, packed_shapes = pack(x, y, z, axes=None)
+        packed_tensor, packed_shapes = pack(x, y, z)
         assert packed_tensor.type.shape == (15,)
         for tensor, packed_shape in zip([x, y, z], packed_shapes):
             assert packed_shape.type.shape == (tensor.ndim,)
@@ -211,9 +216,9 @@
         # x is scalar, so pack will raise:
         with pytest.raises(
             ValueError,
-            match=r"Input 0 \(zero indexed\) to pack has 0 dimensions, but axes=0 assumes at least 1 dimension\.",
+            match=r"Input 0 \(zero indexed\) to pack has 0 dimensions, but keep_axes=0 assumes at least 1 dimension\.",
         ):
-            pack(x, y, z, axes=0)
+            pack(x, y, z, keep_axes=0)

         # With valid x, pack should still raise, because the axis of concatenation doesn't agree across all inputs
         x = pt.tensor("x", shape=(3,))

         with pytest.raises(
             ValueError,
             match=r"all input array dimensions other than the specified `axis` \(1\) must match exactly, or be unknown "
             r"\(None\), but along dimension 0, the inputs shapes are incompatible: \[3 5 3\]",
         ):
-            packed_tensor, packed_shapes = pack(x, y, z, axes=0)
+            packed_tensor, packed_shapes = pack(x, y, z, keep_axes=0)
             packed_tensor.eval(input_dict)

         # Valid case, preserve first axis, equivalent to einops 'i *'
         y = pt.tensor("y", shape=(3, 5))
         z = pt.tensor("z", shape=(3, 3, 3))
-        packed_tensor, packed_shapes = pack(x, y, z, axes=0)
+        packed_tensor, packed_shapes = pack(x, y, z, keep_axes=0)
         input_dict = {
             variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
             for variable in [x, y, z]
@@ -253,10 +258,10 @@
             ValueError,
             match=r"Positive axes must be contiguous",
         ):
-            pack(x, y, z, axes=[0, 3])
+            pack(x, y, z, keep_axes=[0, 3])

         z = pt.tensor("z", shape=(3, 1, 7, 2))
-        packed_tensor, packed_shapes = pack(x, y, z, axes=[0, -1])
+        packed_tensor, packed_shapes = pack(x, y, z, keep_axes=[0, -1])
         input_dict = {
             variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
             for variable in [x, y, z]
@@ -277,8 +282,8 @@ def test_pack_unpack_round_trip(self, axes):
         y = pt.tensor("y", shape=(3, 3, 5))
         z = pt.tensor("z", shape=(1, 3, 5))

-        flat_packed, packed_shapes = pack(x, y, z, axes=axes)
-        new_outputs = unpack(flat_packed, axes=axes, packed_shapes=packed_shapes)
+        flat_packed, packed_shapes = pack(x, y, z, keep_axes=axes)
+        new_outputs = unpack(flat_packed, packed_shapes=packed_shapes, keep_axes=axes)

         fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE")

@@ -289,13 +294,19 @@
         output_vals = fn(**input_dict)
         for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
-            np.testing.assert_allclose(input_val, output_val)
-
-
-def test_unpack_connection():
-    x = pt.vector("x")
-    d0 = pt.scalar("d0", dtype=int)
-    d1 = pt.scalar("d1", dtype=int)
-    x0, x1 = pt.unpack(x, axes=None, packed_shapes=[d0, d1])
-    out = x0.sum() + x1.sum()
-    assert io_connection_pattern([x, d0, d1], [out]) == [[True], [False], [False]]
+            np.testing.assert_allclose(input_val, output_val, strict=True)
+
+    def test_single_input(self):
+        x = pt.matrix("x", shape=(2, 5))
+        packed_x, packed_shapes = pt.pack(x)
+        assert packed_x.type.shape == (10,)
+        [x_again] = unpack(packed_x, packed_shapes)
+        assert x_again.type.shape == (2, 5)
+
+    def test_unpack_connection(self):
+        x = pt.vector("x")
+        d0 = pt.scalar("d0", dtype=int)
+        d1 = pt.scalar("d1", dtype=int)
+        x0, x1 = pt.unpack(x, packed_shapes=[d0, d1])
+        out = x0.sum() + x1.sum()
+        assert io_connection_pattern([x, d0, d1], [out]) == [[True], [False], [False]]