From 54c9e69eda927718efe7a32fd245c8b28cf5628b Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 21 Dec 2025 11:35:35 -0600
Subject: [PATCH 1/2] Implement `L_op` for `join_dims` and `split_dims`

Improve type hints for `join_dims` and `split_dims`
---
 pytensor/tensor/reshape.py   | 52 +++++++++++++++++++++++++++---------
 tests/tensor/test_reshape.py |  7 ++++-
 2 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py
index c9a80ec244..5c2800a994 100644
--- a/pytensor/tensor/reshape.py
+++ b/pytensor/tensor/reshape.py
@@ -1,18 +1,19 @@
 from collections.abc import Iterable, Sequence
 from itertools import pairwise
-from typing import cast as type_cast
 
 import numpy as np
 from numpy.lib._array_utils_impl import normalize_axis_tuple
 
 from pytensor import Variable
+from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
+from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import prod
-from pytensor.tensor.shape import ShapeValueType
+from pytensor.tensor.shape import ShapeValueType, shape
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
 
@@ -80,6 +81,18 @@ def perform(self, node, inputs, outputs):
 
         out[0] = x.reshape(output_shape)
 
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        (x,) = inputs
+        (g_out,) = output_grads
+
+        packed_shape = shape(x)[list(self.axis_range)]
+        return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
+
 
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
@@ -92,7 +105,7 @@ def _vectorize_joindims(op, node, x):
     return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
 
 
-def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
+def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> Variable:
     """Join consecutive dimensions of a tensor into a single dimension.
 
     Parameters
@@ -137,10 +150,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
     start_axis = min(axis)
     n_axes = len(axis)
 
-    return type_cast(
-        TensorVariable,
-        JoinDims(start_axis=start_axis, n_axes=n_axes)(x),
-    )
+    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
 
 
 class SplitDims(Op):
@@ -191,6 +201,23 @@ def perform(self, node, inputs, outputs):
 
         out[0] = x.reshape(output_shape)
 
+    def connection_pattern(self, node):
+        return [[True], [False]]
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        (x, _) = inputs
+        (g_out,) = output_grads
+
+        n_axes = g_out.ndim - x.ndim + 1  # type: ignore[attr-defined]
+        axis_range = list(range(self.axis, self.axis + n_axes))
+
+        return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
+
 
 @_vectorize_node.register(SplitDims)
 def _vectorize_splitdims(op, node, x, shape):
@@ -210,7 +237,7 @@ def split_dims(
     x: TensorLike,
     shape: ShapeValueType | Sequence[ShapeValueType],
     axis: int | None = None,
-) -> TensorVariable:
+) -> Variable:
     """Split a dimension of a tensor into multiple dimensions.
 
     Parameters
@@ -253,13 +280,12 @@ def split_dims(
         # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
         # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
         # (3, ) and (3, 3) to (3, 4))
-        return type_cast(TensorVariable, x.squeeze(axis=axis))
+        return squeeze(x, axis=axis)  # type: ignore[no-any-return]
 
     [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
     shape = as_tensor_variable(shape, dtype="int64", ndim=1)  # type: ignore[arg-type]
 
-    split_op = SplitDims(axis=axis)
-    return type_cast(TensorVariable, split_op(x, shape))
+    return SplitDims(axis=axis)(x, shape)  # type: ignore[return-value]
 
 
 def _analyze_axes_list(axes) -> tuple[int, int, int]:
@@ -430,7 +456,7 @@ def pack(
 
     n_before, n_after, min_axes = _analyze_axes_list(axes)
 
-    reshaped_tensors: list[TensorVariable] = []
+    reshaped_tensors: list[Variable] = []
     packed_shapes: list[ShapeValueType] = []
 
     for i, input_tensor in enumerate(tensor_list):
@@ -466,7 +492,7 @@ def unpack(
     packed_input: TensorLike,
     axes: int | Sequence[int] | None,
     packed_shapes: list[ShapeValueType],
-) -> list[TensorVariable]:
+) -> list[Variable]:
     """
     Unpack a packed tensor into multiple tensors by splitting along the
     specified axes and reshaping.
diff --git a/tests/tensor/test_reshape.py b/tests/tensor/test_reshape.py
index 2eb8d0748f..c68eb7c31a 100644
--- a/tests/tensor/test_reshape.py
+++ b/tests/tensor/test_reshape.py
@@ -2,6 +2,7 @@
 import pytest
 
 import pytensor
+import tests.unittest_tools as utt
 from pytensor import config, function
 from pytensor import tensor as pt
 from pytensor.graph import rewrite_graph, vectorize_graph
@@ -52,6 +53,8 @@ def test_join_dims():
     x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
     assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)
 
+    utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value])
+
 
 @pytest.mark.parametrize(
     "axis, shape, expected_shape",
@@ -77,6 +80,8 @@ def test_split_dims(axis, shape, expected_shape):
     x_split_value = fn(x_value)
     np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape))
 
+    utt.verify_grad(lambda x: split_dims(x, shape=shape, axis=axis), [x_value])
+
     x = pt.tensor("x", shape=(10,))
     x_split = split_dims(x, shape=(5, 2), axis=0)
     x_batched = pt.tensor("x_batched", shape=(3, 10))
@@ -115,7 +120,7 @@ def test_make_replacements_with_pack_unpack():
     new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
     loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
 
-    rewrite_graph(loss, include=("ShapeOpt", "specialize"))
+    rewrite_graph(loss, include=("ShapeOpt", "canonicalize"))
 
     fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")

From b79fa6315c933f289dd73c37d33d9cf99ed0733b Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Fri, 2 Jan 2026 15:36:57 -0600
Subject: [PATCH 2/2] Feedback

---
 pytensor/tensor/reshape.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py
index 5c2800a994..f556af2a75 100644
--- a/pytensor/tensor/reshape.py
+++ b/pytensor/tensor/reshape.py
@@ -90,7 +90,8 @@ def L_op(
         (x,) = inputs
         (g_out,) = output_grads
 
-        packed_shape = shape(x)[list(self.axis_range)]
+        x_shape = shape(x)
+        packed_shape = [x_shape[i] for i in self.axis_range]
         return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
 
 
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
@@ -105,19 +106,19 @@ def _vectorize_joindims(op, node, x):
     return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
 
 
-def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> Variable:
+def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.
 
     Parameters
     ----------
-    x : Variable
+    x : TensorLike
         The input tensor.
     axis : int or sequence of int, optional
         The dimensions to join. If None, all dimensions are joined.
 
     Returns
     -------
-    joined_x : Variable
+    joined_x : TensorVariable
         The reshaped tensor with joined dimensions.
 
     Examples
@@ -237,7 +238,7 @@ def split_dims(
     x: TensorLike,
     shape: ShapeValueType | Sequence[ShapeValueType],
     axis: int | None = None,
-) -> Variable:
+) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.
 
     Parameters
@@ -251,7 +252,7 @@ def split_dims(
 
     Returns
     -------
-    split_x : Variable
+    split_x : TensorVariable
         The reshaped tensor with split dimensions.
 
     Examples
@@ -384,7 +385,7 @@ def pack(
 
     Returns
     -------
-    packed_tensor : TensorLike
+    packed_tensor : TensorVariable
         The packed tensor with specified axes preserved and others raveled.
     packed_shapes : list of ShapeValueType
         A list containing the shapes of the raveled dimensions for each input tensor.
@@ -492,7 +493,7 @@ def unpack(
     packed_input: TensorLike,
     axes: int | Sequence[int] | None,
     packed_shapes: list[ShapeValueType],
-) -> list[Variable]:
+) -> list[TensorVariable]:
     """
     Unpack a packed tensor into multiple tensors by splitting along the
     specified axes and reshaping.
@@ -514,7 +515,7 @@ def unpack(
 
     Returns
     -------
-    unpacked_tensors : list of TensorLike
+    unpacked_tensors : list of TensorVariable
         A list of unpacked tensors with their original shapes restored.
     """
     packed_input = as_tensor_variable(packed_input)
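
Note on the new gradients (editor's sketch, not part of the patch): `JoinDims.L_op`
implements the gradient of `join_dims` as `split_dims` applied to the output
gradient, recovering the shape of the joined axes from the input. A minimal
round-trip check, with illustrative shapes and variable names:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.reshape import join_dims

    x = pt.tensor("x", shape=(4, 3, 5))
    y = join_dims(x, axis=(1, 2))  # joins axes 1 and 2 -> shape (4, 15)
    cost = (y**2).sum()

    # L_op turns the (4, 15) output gradient back into a (4, 3, 5) gradient
    # via split_dims(g_out, shape=x.shape[1:3], axis=1).
    g = pytensor.grad(cost, x)
    fn = pytensor.function([x], g)

    x_val = np.random.normal(size=(4, 3, 5)).astype(pytensor.config.floatX)
    np.testing.assert_allclose(fn(x_val), 2 * x_val)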
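
The `SplitDims` gradient is the mirror image: its `L_op` joins the split axes
back together with `join_dims`, and `connection_pattern` returning
[[True], [False]] declares the symbolic `shape` input disconnected, so a
`DisconnectedType` placeholder is returned for it and no gradient flows into
the shape. A sketch under the same assumptions as above:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.reshape import split_dims

    z = pt.tensor("z", shape=(10,))
    w = split_dims(z, shape=(5, 2), axis=0)  # shape (5, 2)

    # The (5, 2) output gradient is joined back to shape (10,); only z is
    # differentiated, the shape argument receives no gradient.
    g_z = pytensor.grad((w**3).sum(), z)

    z_val = np.arange(10).astype(pytensor.config.floatX)
    np.testing.assert_allclose(g_z.eval({z: z_val}), 3 * z_val**2)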
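
One subtlety from the empty-shape branch of `split_dims`: when `shape` is
empty, the function squeezes the dummy dimension instead of building a
`SplitDims` node, so `connection_pattern`/`L_op` never see that case. A
hypothetical illustration (assuming an empty tuple is a valid `shape` argument
here, as the comment in that branch suggests):

    import pytensor.tensor as pt
    from pytensor.tensor.reshape import split_dims

    v = pt.tensor("v", shape=(3, 1))
    # Falls back to squeeze(v, axis=1), giving shape (3,), rather than a SplitDims node.
    u = split_dims(v, shape=(), axis=1)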