
Fix unpack sharp edges #1835

@jessegrabowski

Description

Some things that have come up in usage:

  • `axes` should be renamed to `keep_axes` or `axes_to_keep`, to make clear what the argument actually does.
  • The `packed_shapes` argument should come before `axes`, to match the NumPy convention of putting the axis argument last.
  • When `len(packed_shapes) == 1`, `unpack` currently raises an error:

```python
import pytensor.tensor as pt

x = pt.dvector('x')
x_packed, packed_shapes = pt.pack(x)
# len(packed_shapes) == 1
pt.unpack(x_packed, axes=None, packed_shapes=packed_shapes)
```
Traceback:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 pt.unpack(x_packed, None, packed_shapes)

File ~/Documents/Python/pytensor/pytensor/tensor/reshape.py:545, in unpack(packed_input, axes, packed_shapes)
    532         raise ValueError(
    533             "Unpack must have exactly one more dimension that implied by axes"
    534         ) from err
    536 split_inputs = split(
    537     packed_input,
    538     splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
    539     n_splits=len(packed_shapes),
    540     axis=split_axis,
    541 )
    543 return [
    544     split_dims(inp, shape, split_axis)
--> 545     for inp, shape in zip(split_inputs, packed_shapes, strict=True)
    546 ]

File ~/Documents/Python/pytensor/pytensor/tensor/variable.py:611, in _tensor_py_operators.__iter__(self)
    609 def __iter__(self):
    610     try:
--> 611         for i in range(pt.basic.get_vector_length(self)):
    612             yield self[i]
    613     except TypeError:
    614         # This prevents accidental iteration via sum(self)

File ~/Documents/Python/pytensor/pytensor/tensor/__init__.py:88, in get_vector_length(v)
     85 if static_shape is not None:
     86     return static_shape
---> 88 return _get_vector_length(getattr(v.owner, "op", v), v)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/tensor/__init__.py:94, in _get_vector_length(op, var)
     91 @singledispatch
     92 def _get_vector_length(op: Op | Variable, var: Variable) -> int:
     93     """`Op`-based dispatch for `get_vector_length`."""
---> 94     raise ValueError(f"Length of {var} cannot be determined")

ValueError: Length of Split{1}.0 cannot be determined
```

I'll add more as they come up.
