diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 98e48261ae..3099c07e77 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -2254,18 +2254,19 @@ def infer_shape(self, fgraph, node, in_shapes):
             out_shapes.append(temp)
         return out_shapes
 
+    def connection_pattern(self, node):
+        n_out = len(node.outputs)
+        return [
+            [True] * n_out,
+            [True] * n_out,
+            [False] * n_out,
+        ]
+
     def L_op(self, inputs, outputs, g_outputs):
         """Join the gradients along the axis that was used to split x."""
-        _x, axis, n = inputs
+        _x, axis, _n = inputs
 
-        # If all the output gradients are disconnected, then so are the inputs
-        if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
-            return [
-                DisconnectedType()(),
-                grad_undefined(self, 1, axis),
-                grad_undefined(self, 2, n),
-            ]
-        # Else, we have to make them zeros before joining them
+        # We have to convert disconnected outputs to zeros before joining them
         new_g_outputs = []
         for o, g in zip(outputs, g_outputs, strict=True):
             if isinstance(g.type, DisconnectedType):
@@ -2276,7 +2277,7 @@ def L_op(self, inputs, outputs, g_outputs):
         return [
             join(axis, *new_g_outputs),
             grad_undefined(self, 1, axis),
-            grad_undefined(self, 2, n),
+            DisconnectedType()(),
         ]
 
     def R_op(self, inputs, eval_points):
diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py
index f556af2a75..78ac838d5c 100644
--- a/pytensor/tensor/reshape.py
+++ b/pytensor/tensor/reshape.py
@@ -1,23 +1,28 @@
 from collections.abc import Iterable, Sequence
 from itertools import pairwise
+from typing import TypeAlias
 
 import numpy as np
-from numpy.lib._array_utils_impl import normalize_axis_tuple
+from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple
 
 from pytensor import Variable
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
+from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
-from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import prod
-from pytensor.tensor.shape import ShapeValueType, shape
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
 
+ShapeValueType: TypeAlias = (
+    int | np.integer | ScalarVariable | TensorVariable | np.ndarray
+)
+
+
 class JoinDims(Op):
     __props__ = (
         "start_axis",
@@ -81,16 +86,11 @@ def perform(self, node, inputs, outputs):
 
         out[0] = x.reshape(output_shape)
 
-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
         (x,) = inputs
         (g_out,) = output_grads
 
-        x_shape = shape(x)
+        x_shape = x.shape
         packed_shape = [x_shape[i] for i in self.axis_range]
 
         return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
@@ -163,13 +163,18 @@ def __init__(self, axis: int):
             raise ValueError("SplitDims axis must be non-negative")
         self.axis = axis
 
-    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+    def make_node(self, x, shape):
         x = as_tensor_variable(x)
-        shape = as_tensor_variable(shape, dtype=int, ndim=1)
+        shape = as_tensor_variable(shape, dtype=int)
 
         if shape.type.numpy_dtype.kind not in "iu":
            raise TypeError("shape must be an integer tensor")
 
+        if shape.type.ndim != 1:
+            raise TypeError(
+                f"shape must be a 1-D tensor, got {shape} with {shape.type.ndim} dimensions"
+            )
+
        axis = self.axis
 
         _, constant_shape = infer_static_shape(shape)
@@ -205,16 +210,11 @@ def perform(self, node, inputs, outputs):
     def connection_pattern(self, node):
         return [[True], [False]]
 
-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
         (x, _) = inputs
         (g_out,) = output_grads
 
-        n_axes = g_out.ndim - x.ndim + 1  # type: ignore[attr-defined]
+        n_axes = g_out.ndim - x.ndim + 1
         axis_range = list(range(self.axis, self.axis + n_axes))
 
         return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
@@ -266,25 +266,21 @@ def split_dims(
 
     x = as_tensor_variable(x)
 
     if axis is None:
-        if x.ndim != 1:
+        if x.type.ndim != 1:
             raise ValueError(
                 "split_dims can only be called with axis=None for 1d inputs"
             )
         axis = 0
-
-    if isinstance(shape, int):
-        shape = [shape]
     else:
-        shape = list(shape)  # type: ignore[arg-type]
-
-    if not shape:
-        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
-        # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
-        # (3, ) and (3, 3) to (3, 4)
-        return squeeze(x, axis=axis)  # type: ignore[no-any-return]
+        axis = normalize_axis_index(axis, x.ndim)
 
-    [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
-    shape = as_tensor_variable(shape, dtype="int64", ndim=1)  # type: ignore[arg-type]
+    # Convert scalar shape to 1d tuple (shape,)
+    if not isinstance(shape, Sequence):
+        if isinstance(shape, TensorVariable | np.ndarray):
+            if shape.ndim == 0:
+                shape = (shape,)
+        elif isinstance(shape, int | np.integer | ScalarVariable):
+            shape = (shape,)
 
     return SplitDims(axis=axis)(x, shape)  # type: ignore[return-value]
@@ -372,7 +368,7 @@ def find_gaps(s):
 
 def pack(
     *tensors: TensorLike, axes: Sequence[int] | int | None = None
-) -> tuple[TensorVariable, list[ShapeValueType]]:
+) -> tuple[TensorVariable, list[TensorVariable]]:
     """
     Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
 
@@ -458,7 +454,7 @@ def pack(
     n_before, n_after, min_axes = _analyze_axes_list(axes)
 
     reshaped_tensors: list[Variable] = []
-    packed_shapes: list[ShapeValueType] = []
+    packed_shapes: list[TensorVariable] = []
 
     for i, input_tensor in enumerate(tensor_list):
         n_dim = input_tensor.ndim
@@ -492,7 +488,7 @@ def pack(
 def unpack(
     packed_input: TensorLike,
     axes: int | Sequence[int] | None,
-    packed_shapes: list[ShapeValueType],
+    packed_shapes: Sequence[ShapeValueType],
 ) -> list[TensorVariable]:
     """
     Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
diff --git a/tests/tensor/test_reshape.py b/tests/tensor/test_reshape.py
index c68eb7c31a..65da41c051 100644
--- a/tests/tensor/test_reshape.py
+++ b/tests/tensor/test_reshape.py
@@ -6,6 +6,7 @@
 from pytensor import config, function
 from pytensor import tensor as pt
 from pytensor.graph import rewrite_graph, vectorize_graph
+from pytensor.graph.op import io_connection_pattern
 from pytensor.tensor.reshape import (
     _analyze_axes_list,
     join_dims,
@@ -61,9 +62,10 @@ def test_join_dims():
     [
         (0, pt.as_tensor([2, 3]), (2, 3, 4, 6)),
         (2, [2, 3], (6, 4, 2, 3)),
+        (-1, pt.as_tensor(6), (6, 4, 6)),
         (-1, 6, (6, 4, 6)),
     ],
-    ids=["tensor", "list", "integer"],
+    ids=["tensor list", "integer list", "tensor", "integer"],
 )
 def test_split_dims(axis, shape, expected_shape):
     rng = np.random.default_rng()
@@ -95,7 +97,7 @@ def test_split_size_zero_shape():
 
     x = pt.tensor("x", shape=(1, 4, 6))
 
-    x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,))))
+    x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,), dtype="int32")))
     assert x_split.type.shape == (4, 6)
 
     x_value = np.empty((1, 4, 6), dtype=config.floatX)
@@ -288,3 +290,12 @@ def test_pack_unpack_round_trip(self, axes):
 
         for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
             np.testing.assert_allclose(input_val, output_val)
+
+
+def test_unpack_connection():
+    x = pt.vector("x")
+    d0 = pt.scalar("d0", dtype=int)
+    d1 = pt.scalar("d1", dtype=int)
+    x0, x1 = pt.unpack(x, axes=None, packed_shapes=[d0, d1])
+    out = x0.sum() + x1.sum()
+    assert io_connection_pattern([x, d0, d1], [out]) == [[True], [False], [False]]
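
A usage sketch (not part of the patch) of what the new connection-pattern behaviour means in practice: the packed shapes only determine how the data is split, so a gradient of an unpacked expression is taken with respect to the packed data alone. Variable names mirror the new test_unpack_connection test; the all-ones result is an assumption about this particular graph, not something the patch asserts.

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
d0 = pt.scalar("d0", dtype=int)
d1 = pt.scalar("d1", dtype=int)

# Unpack the flat vector into two pieces whose (symbolic) lengths are d0 and d1
x0, x1 = pt.unpack(x, axes=None, packed_shapes=[d0, d1])
out = x0.sum() + x1.sum()

# d0 and d1 only carry shape information, so they are disconnected from `out`;
# the gradient is requested for the packed data alone.
g_x = pytensor.grad(out, x)
f = pytensor.function([x, d0, d1], g_x)

print(f(np.arange(5, dtype=pytensor.config.floatX), 2, 3))  # should be all ones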