Implement pack/unpack helpers #1578

Open · wants to merge 1 commit into main
73 changes: 71 additions & 2 deletions pytensor/tensor/extra_ops.py
@@ -28,7 +28,7 @@
from pytensor.scalar import upcast
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import alloc, join, second
from pytensor.tensor.basic import alloc, arange, join, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
@@ -47,7 +47,7 @@
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor, take
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.utils import normalize_reduce_axis
from pytensor.tensor.variable import TensorVariable
@@ -2074,6 +2074,73 @@ def concat_with_broadcast(tensor_list, axis=0):
    return join(axis, *bcast_tensor_inputs)


[Review thread on `def pack(`]
Member: Would be nice to have a docstring (doctested) example.
Member (author): Yeah, I will of course add docstrings. This was just to get a PR on the board and see your reaction to the 3 issues I raised. I didn't want to document everything before we decided on the final API.

def pack(
    *tensors: TensorVariable,
) -> tuple[TensorVariable, list[tuple[int, ...] | TensorVariable]]:
"""
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.

Parameters
----------
tensors: TensorVariable
Tensors to be packed into a single vector.

Returns
-------
flat_tensor: TensorVariable
A new symbolic variable representing the concatenated 1d vector of all tensor inputs
packed_shapes: list of tuples of TensorVariable
A list of tuples, where each tuple contains the symbolic shape of the original tensors.
"""
if not tensors:
raise ValueError("Cannot pack an empty list of tensors.")

# Get the shapes of the input tensors
packed_shapes = [
t.type.shape if not any(s is None for s in t.type.shape) else t.shape
for t in tensors
]

# Flatten each tensor and concatenate them into a single 1D vector
flat_tensor = join(0, *[t.ravel() for t in tensors])

return flat_tensor, packed_shapes
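As requested in the thread above, a doctest-style example could eventually go in the docstring. In the meantime, a minimal usage sketch of the proposed behaviour (assuming the API stays as written here; `unpack` is the companion helper defined below):

import pytensor.tensor as pt
from pytensor.tensor.extra_ops import pack, unpack

x = pt.tensor("x", shape=())       # scalar
y = pt.tensor("y", shape=(5,))     # vector
z = pt.tensor("z", shape=(3, 3))   # matrix

flat, shapes = pack(x, y, z)
print(flat.type.shape)             # (15,): 1 + 5 + 9, since all shapes are static
x2, y2, z2 = unpack(flat, shapes)  # inverse: reshapes each slice back
print(z2.type.shape)               # (3, 3)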


def unpack(
    flat_tensor: TensorVariable, packed_shapes: list[tuple[int, ...] | TensorVariable]
) -> tuple[TensorVariable, ...]:
    """
    Unpack a flat tensor into its original shapes based on the provided packed shapes.

    Parameters
    ----------
    flat_tensor: TensorVariable
        A 1D tensor that contains the concatenated values of the original tensors.
    packed_shapes: list of tuple of int or TensorVariable
        The shape of each original tensor, as returned by ``pack``.

    Returns
    -------
    unpacked_tensors: tuple of TensorVariable
        A tuple containing the unpacked tensors with their original shapes.
    """
    if not packed_shapes:
        raise ValueError("Cannot unpack an empty list of shapes.")

    start = 0
    unpacked_tensors = []
    for shape in packed_shapes:
        size = prod(shape, no_zeros_in_input=True)
[Review thread on the `prod(...)` line]
Member: Why no zeros in input? The shape doesn't show up in gradients, if that's what you were worried about.
Member (author): JAX needs it as well, iirc.
Member: I don't see why you can't have zeros in the shapes.
Member (author): OK, OK, OK, I'll fix it.

        end = start + size
        unpacked_tensors.append(
            take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape)
        )
        start = end

    return tuple(unpacked_tensors)

[Review comment on the `take(...)` line]
Member: Take uses advanced indexing. Actually, the best here is probably split. Join and split are the inverses of each other, and will be easier to rewrite away.
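For reference, a minimal sketch of the split-based variant suggested in the comment above. This is an assumption-laden sketch, not part of the PR: it assumes `split(x, splits_size, n_splits, axis)` from `pytensor.tensor.basic`, and it uses plain `prod` without the `no_zeros_in_input` flag, which also keeps zero-sized shapes working per the earlier thread:

from pytensor.tensor.basic import split
from pytensor.tensor.math import prod


def unpack_via_split(flat_tensor, packed_shapes):
    # split is the exact inverse of the join used by pack, so a rewrite can
    # cancel a pack -> unpack round trip more easily than take + arange
    sizes = [prod(shape, dtype="int64") for shape in packed_shapes]
    pieces = split(flat_tensor, sizes, n_splits=len(packed_shapes), axis=0)
    return tuple(
        piece.reshape(shape) for piece, shape in zip(pieces, packed_shapes)
    )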


__all__ = [
    "searchsorted",
    "cumsum",
@@ -2096,4 +2163,6 @@
"logspace",
"linspace",
"broadcast_arrays",
"pack",
"unpack",
]
73 changes: 72 additions & 1 deletion tests/tensor/test_extra_ops.py
@@ -9,7 +9,7 @@
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Constant, applys_between, equal_computations
from pytensor.graph.basic import Constant, Variable, applys_between, equal_computations
from pytensor.npy_2_compat import old_np_unique
from pytensor.raise_op import Assert
from pytensor.tensor import alloc
@@ -37,11 +37,13 @@
    diff,
    fill_diagonal,
    fill_diagonal_offset,
    pack,
    ravel_multi_index,
    repeat,
    searchsorted,
    squeeze,
    to_one_hot,
    unpack,
    unravel_index,
)
from pytensor.tensor.type import (
@@ -1378,3 +1380,72 @@ def test_concat_with_broadcast():
a = pt.tensor("a", shape=(1, 3, 5))
b = pt.tensor("b", shape=(3, 5))
pt.concat_with_broadcast([a, b], axis=1)


@pytest.mark.parametrize(
    "shapes, expected_flat_shape",
    [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)],
    ids=["static", "symbolic"],
)
def test_pack(shapes, expected_flat_shape):
    rng = np.random.default_rng()

    x = pt.tensor("x", shape=shapes[0])
    y = pt.tensor("y", shape=shapes[1])
    z = pt.tensor("z", shape=shapes[2])

    has_static_shape = [not any(s is None for s in shape) for shape in shapes]

    flat_packed, packed_shapes = pack(x, y, z)

    assert flat_packed.type.shape[0] == expected_flat_shape

    for i, (packed_shape, has_static) in enumerate(
        zip(packed_shapes, has_static_shape)
    ):
        if has_static:
            assert packed_shape == shapes[i]
        else:
            assert isinstance(packed_shape, Variable)

    new_outputs = unpack(flat_packed, packed_shapes)

    assert len(new_outputs) == 3
    assert all(
        out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z])
    )

    fn = function([x, y, z], new_outputs, mode="FAST_COMPILE")

    input_vals = [
        rng.normal(size=shape).astype(config.floatX) for shape in [(), (5,), (3, 3)]
    ]
    new_output_vals = fn(*input_vals)
    for input_val, output_val in zip(input_vals, new_output_vals):
        np.testing.assert_allclose(input_val, output_val)


def test_make_replacements_with_pack_unpack():
    rng = np.random.default_rng()

    x = pt.tensor("x", shape=())
    y = pt.tensor("y", shape=(5,))
    z = pt.tensor("z", shape=(3, 3))

    loss = (x + y.sum() + z.sum()) ** 2

    flat_packed, packed_shapes = pack(x, y, z)
    new_input = flat_packed.type()
    new_outputs = unpack(new_input, packed_shapes)

    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
    fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")

    input_vals = [
        rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z]
    ]
    flat_inputs = np.concatenate([input_val.ravel() for input_val in input_vals], axis=0)
    output_val = fn(flat_inputs)

    assert np.allclose(output_val, sum(input_val.sum() for input_val in input_vals) ** 2)
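The replacement pattern in this test is presumably the intended workflow: re-express a loss over many differently-shaped inputs as a function of one flat vector. A hedged follow-on sketch (not part of this PR, and assuming `pt.grad` flows through `take` and `reshape` as usual) showing that the gradient then also comes back as a single flat vector:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import pack, unpack

x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
loss = (x + y.sum()) ** 2

flat_packed, packed_shapes = pack(x, y)
flat_input = flat_packed.type()                # fresh flat input variable
new_x, new_y = unpack(flat_input, packed_shapes)
flat_loss = pytensor.graph.graph_replace(loss, {x: new_x, y: new_y})

grad = pt.grad(flat_loss, flat_input)          # one (6,)-shaped gradient vector
fn = pytensor.function([flat_input], [flat_loss, grad])
value, gradient = fn(np.zeros(6, dtype=pytensor.config.floatX))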