diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index a6eafcf485..5444e764eb 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -28,7 +28,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, arange, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -47,7 +47,7 @@
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
-from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
+from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor, take
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
@@ -2074,6 +2074,73 @@ def concat_with_broadcast(tensor_list, axis=0):
     return join(axis, *bcast_tensor_inputs)
 
 
+def pack(
+    *tensors: TensorVariable,
+) -> tuple[TensorVariable, list[tuple[int, ...] | TensorVariable]]:
+    """
+    Ravel and concatenate tensors of varying shapes and dimensions into a single 1d vector.
+
+    Parameters
+    ----------
+    tensors: TensorVariable
+        Tensors to be packed into a single vector.
+
+    Returns
+    -------
+    flat_tensor: TensorVariable
+        A new symbolic variable representing the concatenated 1d vector of all tensor inputs.
+    packed_shapes: list of tuple of int or TensorVariable
+        One entry per input: its static shape tuple if fully known, otherwise its symbolic shape vector.
+    """
+    if not tensors:
+        raise ValueError("Cannot pack an empty list of tensors.")
+
+    # Record each input's shape: the static tuple when fully known, else the symbolic shape
+    packed_shapes = [
+        t.type.shape if not any(s is None for s in t.type.shape) else t.shape
+        for t in tensors
+    ]
+
+    # Flatten each tensor and concatenate them into a single 1D vector
+    flat_tensor = join(0, *[t.ravel() for t in tensors])
+
+    return flat_tensor, packed_shapes
+
+
+def unpack(
+    flat_tensor: TensorVariable, packed_shapes: list[tuple[int, ...] | TensorVariable]
+) -> tuple[TensorVariable, ...]:
+    """
+    Unpack a flat tensor into its original shapes based on the provided packed shapes.
+
+    Parameters
+    ----------
+    flat_tensor: TensorVariable
+        A 1D tensor that contains the concatenated values of the original tensors.
+    packed_shapes: list of tuple of int or TensorVariable
+        One entry per original tensor: its static shape tuple or its symbolic shape vector, as returned by ``pack``.
+
+    Returns
+    -------
+    unpacked_tensors: tuple of TensorVariable
+        A tuple containing the unpacked tensors with their original shapes.
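+
+    Examples
+    --------
+    A minimal round-trip sketch, mirroring the tests added below:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        from pytensor.tensor.extra_ops import pack, unpack
+
+        x = pt.tensor("x", shape=())
+        y = pt.tensor("y", shape=(5,))
+
+        flat, shapes = pack(x, y)
+        x_new, y_new = unpack(flat, shapes)  # x_new and y_new have the shapes of x and y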
+ """ + if not packed_shapes: + raise ValueError("Cannot unpack an empty list of shapes.") + + start = 0 + unpacked_tensors = [] + for shape in packed_shapes: + size = prod(shape, no_zeros_in_input=True) + end = start + size + unpacked_tensors.append( + take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape) + ) + start = end + + return tuple(unpacked_tensors) + + __all__ = [ "searchsorted", "cumsum", @@ -2096,4 +2163,6 @@ def concat_with_broadcast(tensor_list, axis=0): "logspace", "linspace", "broadcast_arrays", + "pack", + "unpack", ] diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 8274ddbcea..64e53e19d1 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph import rewrite_graph -from pytensor.graph.basic import Constant, applys_between, equal_computations +from pytensor.graph.basic import Constant, Variable, applys_between, equal_computations from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert from pytensor.tensor import alloc @@ -37,11 +37,13 @@ diff, fill_diagonal, fill_diagonal_offset, + pack, ravel_multi_index, repeat, searchsorted, squeeze, to_one_hot, + unpack, unravel_index, ) from pytensor.tensor.type import ( @@ -1378,3 +1380,72 @@ def test_concat_with_broadcast(): a = pt.tensor("a", shape=(1, 3, 5)) b = pt.tensor("b", shape=(3, 5)) pt.concat_with_broadcast([a, b], axis=1) + + +@pytest.mark.parametrize( + "shapes, expected_flat_shape", + [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)], + ids=["static", "symbolic"], +) +def test_pack(shapes, expected_flat_shape): + rng = np.random.default_rng() + + x = pt.tensor("x", shape=shapes[0]) + y = pt.tensor("y", shape=shapes[1]) + z = pt.tensor("z", shape=shapes[2]) + + has_static_shape = [not any(s is None for s in shape) for shape in shapes] + + flat_packed, packed_shapes = pack(x, y, z) + + assert flat_packed.type.shape[0] == expected_flat_shape + + for i, (packed_shape, has_static) in enumerate( + zip(packed_shapes, has_static_shape) + ): + if has_static: + assert packed_shape == shapes[i] + else: + assert isinstance(packed_shape, Variable) + + new_outputs = unpack(flat_packed, packed_shapes) + + assert len(new_outputs) == 3 + assert all( + out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z]) + ) + + fn = function([x, y, z], new_outputs, mode="FAST_COMPILE") + + input_vals = [ + rng.normal(size=shape).astype(config.floatX) + for var, shape in zip([x, y, z], [(), (5,), (3, 3)]) + ] + new_output_vals = fn(*input_vals) + for input, output in zip(input_vals, new_output_vals): + np.testing.assert_allclose(input, output) + + +def test_make_replacements_with_pack_unpack(): + rng = np.random.default_rng() + + x = pt.tensor("x", shape=()) + y = pt.tensor("y", shape=(5,)) + z = pt.tensor("z", shape=(3, 3)) + + loss = (x + y.sum() + z.sum()) ** 2 + + flat_packed, packed_shapes = pack(x, y, z) + new_input = flat_packed.type() + new_outputs = unpack(new_input, packed_shapes) + + loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs))) + fn = pytensor.function([new_input], loss, mode="FAST_COMPILE") + + input_vals = [ + rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z] + ] + flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0) + output_val = fn(flat_inputs) + + assert np.allclose(output_val, sum([input.sum() 
for input in input_vals]) ** 2)
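
Beyond the patch itself, here is a brief usage sketch of the pattern the second test exercises: replacing several graph inputs with a single flat vector, as needed e.g. by optimizers that expect one parameter vector. The variable names (`mu`, `w`, `theta`) are illustrative and not part of the patch:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import pack, unpack

# Two parameters of different shapes, combined into one flat vector.
mu = pt.tensor("mu", shape=())
w = pt.tensor("w", shape=(3,))
loss = (mu + w.sum()) ** 2

flat, shapes = pack(mu, w)
theta = flat.type()  # fresh 1d input standing in for all parameters
mu_new, w_new = unpack(theta, shapes)

# Rebuild the loss as a function of the single flat input.
flat_loss = pytensor.graph.graph_replace(loss, {mu: mu_new, w: w_new})
fn = pytensor.function([theta], flat_loss)

print(fn(np.arange(4, dtype=pytensor.config.floatX)))  # (0 + 1 + 2 + 3) ** 2 = 36.0
```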