Implement pack/unpack helpers #1578
@@ -28,7 +28,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, arange, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -47,7 +47,7 @@
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
-from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
+from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor, take
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
@@ -2074,6 +2074,73 @@ def concat_with_broadcast(tensor_list, axis=0):
     return join(axis, *bcast_tensor_inputs)


+def pack(
+    *tensors: TensorVariable,
+) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
+    """
+    Ravel tensors of varying shapes and dimensions and concatenate them into a single 1d vector.
+
+    Parameters
+    ----------
+    tensors: TensorVariable
+        Tensors to be packed into a single vector.
+
+    Returns
+    -------
+    flat_tensor: TensorVariable
+        A new symbolic variable representing the concatenated 1d vector of all tensor inputs.
+    packed_shapes: list of tuples of TensorVariable
+        A list of tuples, where each tuple contains the shape of one original tensor.
+    """
+    if not tensors:
+        raise ValueError("Cannot pack an empty list of tensors.")
+
+    # Record each input's shape: the static shape as a plain tuple when it is
+    # fully known, otherwise the symbolic shape
+    packed_shapes = [
+        t.type.shape if not any(s is None for s in t.type.shape) else t.shape
+        for t in tensors
+    ]
+
+    # Flatten each tensor and concatenate them into a single 1D vector
+    flat_tensor = join(0, *[t.ravel() for t in tensors])
+
+    return flat_tensor, packed_shapes
+
+
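For orientation, a minimal sketch of how `pack` behaves under the implementation above; the import path is an assumption about where these helpers would land once merged:

    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import pack  # assumed final location

    x = pt.tensor("x", shape=(2, 3, 4))  # fully static shape
    y = pt.vector("y")                   # length unknown at graph-build time

    flat, shapes = pack(x, y)
    # flat is the 1d concatenation of x.ravel() and y
    # shapes[0] == (2, 3, 4): a plain tuple, since x's static shape is fully known
    # shapes[1] is y.shape:   symbolic, since y's length is unknown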
+def unpack(
+    flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int]]
+) -> tuple[TensorVariable, ...]:
+    """
+    Unpack a flat tensor into its original shapes based on the provided packed shapes.
+
+    Parameters
+    ----------
+    flat_tensor: TensorVariable
+        A 1D tensor that contains the concatenated values of the original tensors.
+    packed_shapes: list of tuples of TensorVariable
+        A list of tuples, where each tuple contains the shape of one original tensor.
+
+    Returns
+    -------
+    unpacked_tensors: tuple of TensorVariable
+        A tuple containing the unpacked tensors with their original shapes.
+    """
+    if not packed_shapes:
+        raise ValueError("Cannot unpack an empty list of shapes.")
+
+    start = 0
+    unpacked_tensors = []
+    for shape in packed_shapes:
+        size = prod(shape, no_zeros_in_input=True)
Review thread on `prod(shape, no_zeros_in_input=True)`:
- Why no zeros in input? The shape doesn't show up in gradients if that's what you were worried about.
- JAX needs it as well iirc
- I don't see why you can't have zeros in the shapes.
- ok ok ok i'll fix it
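A note on the thread above: as I understand pytensor's `Prod`, `no_zeros_in_input=True` only licenses a cheaper gradient formula (invalid when the input may contain zeros); the forward product is unchanged. Since shape tuples can legitimately contain 0 (empty tensors) and shapes never receive gradients, plain `prod` should suffice. A small illustration, assuming `pt.prod` is the ordinary product reduction:

    import pytensor.tensor as pt

    shape = pt.as_tensor_variable((0, 3))  # shape of an empty (0, 3) tensor
    size = pt.prod(shape)                  # element count: 0 * 3 == 0
    print(size.eval())                     # -> 0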
+        end = start + size
+        unpacked_tensors.append(
Review thread on the `take(...)` call below:
- Take uses advanced indexing. Actually the best here is probably split. Join and split are the inverses of each other, and will be easier to rewrite away.
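For reference, a hedged sketch of the split-based alternative suggested above; `unpack_via_split` is a hypothetical name, and it assumes `split(x, splits_size, n_splits, axis)` accepts a list of (possibly symbolic) split sizes:

    from pytensor.tensor.basic import split
    from pytensor.tensor.math import prod

    def unpack_via_split(flat_tensor, packed_shapes):
        # One contiguous slice per packed tensor; sizes are the element counts
        sizes = [prod(shape) for shape in packed_shapes]
        pieces = split(flat_tensor, sizes, n_splits=len(packed_shapes), axis=0)
        return tuple(
            piece.reshape(shape) for piece, shape in zip(pieces, packed_shapes)
        )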
+            take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape)
+        )
+        start = end
+
+    return tuple(unpacked_tensors)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2096,4 +2163,6 @@ def concat_with_broadcast(tensor_list, axis=0):
     "logspace",
     "linspace",
     "broadcast_arrays",
+    "pack",
+    "unpack",
 ]
Review thread on the PR:
- Would be nice to have a docstring (doctested) example.
- yeah i will of course add docstrings. This was just to get a PR on the board and see your reaction to the 3 issues I raised. I didn't want to document everything before we decided on the final API.
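For concreteness, a sketch of the kind of doctested round-trip example being requested, assuming the helpers land in pytensor.tensor.extra_ops with the API from this diff:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import pack, unpack  # assumed final location

    x = pt.matrix("x")
    y = pt.vector("y")

    flat, shapes = pack(x, y)      # 1d concatenation plus per-tensor shapes
    x2, y2 = unpack(flat, shapes)  # round-trip back to the original shapes

    x_val = np.arange(6.0).reshape(2, 3)
    y_val = np.array([7.0, 8.0])
    np.testing.assert_allclose(x2.eval({x: x_val, y: y_val}), x_val)
    np.testing.assert_allclose(y2.eval({x: x_val, y: y_val}), y_val)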