Closed
Labels: beginner friendly, bug, help wanted, maintenance
Description
Some things that have come up in usage:
- `axes` should be called `keep_axes` or `axes_to_keep`, to make it clear what the argument actually does
- The `shapes` argument should come before the `axes` argument, to match the numpy convention of putting axis last
- When `len(packed_shapes) == 1`, it currently errors out:
import pytensor.tensor as pt
x = pt.dvector('x')
x_packed, packed_shapes = pt.pack(x)
# len(packed_shapes) == 1
pt.unpack(x_packed, axes=None, packed_shapes=packed_shapes)
Traceback
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 1
----> 1 pt.unpack(x_packed, None, packed_shapes)
File ~/Documents/Python/pytensor/pytensor/tensor/reshape.py:545, in unpack(packed_input, axes, packed_shapes)
532 raise ValueError(
533 "Unpack must have exactly one more dimension that implied by axes"
534 ) from err
536 split_inputs = split(
537 packed_input,
538 splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
539 n_splits=len(packed_shapes),
540 axis=split_axis,
541 )
543 return [
544 split_dims(inp, shape, split_axis)
--> 545 for inp, shape in zip(split_inputs, packed_shapes, strict=True)
546 ]
File ~/Documents/Python/pytensor/pytensor/tensor/variable.py:611, in _tensor_py_operators.__iter__(self)
609 def __iter__(self):
610 try:
--> 611 for i in range(pt.basic.get_vector_length(self)):
612 yield self[i]
613 except TypeError:
614 # This prevents accidental iteration via sum(self)
File ~/Documents/Python/pytensor/pytensor/tensor/__init__.py:88, in get_vector_length(v)
85 if static_shape is not None:
86 return static_shape
---> 88 return _get_vector_length(getattr(v.owner, "op", v), v)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pytensor/pytensor/tensor/__init__.py:94, in _get_vector_length(op, var)
91 @singledispatch
92 def _get_vector_length(op: Op | Variable, var: Variable) -> int:
93 """`Op`-based dispatch for `get_vector_length`."""
---> 94 raise ValueError(f"Length of {var} cannot be determined")
ValueError: Length of Split{1}.0 cannot be determined
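The traceback shows `zip` calling `__iter__` on `Split{1}.0`, which suggests that with a single split, `split_inputs` is a lone variable rather than a list of variables. Below is a minimal pure-Python sketch of that failure mode and one possible guard, assuming that really is the cause. `Piece`, `toy_split`, `ensure_list`, and `toy_unpack` are hypothetical stand-ins for illustration, not pytensor APIs:

```python
class Piece:
    """Hypothetical stand-in for a symbolic variable: holds data, not iterable."""

    def __init__(self, data):
        self.data = list(data)


def toy_split(values, sizes):
    """Split `values` into consecutive chunks of the given sizes.

    Assumption: like the call in the traceback, a single output comes back
    unwrapped (a bare Piece) instead of as a one-element list.
    """
    pieces, start = [], 0
    for size in sizes:
        pieces.append(Piece(values[start:start + size]))
        start += size
    return pieces[0] if len(pieces) == 1 else pieces


def ensure_list(outputs):
    """Wrap a lone output in a list so callers can always iterate over it."""
    return list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]


def toy_unpack(values, sizes):
    """Unpack with the guard applied before zipping pieces with their sizes."""
    pieces = ensure_list(toy_split(values, sizes))
    return [piece.data for piece, _ in zip(pieces, sizes)]


# Without the guard, zipping the single-split result fails because a bare
# Piece is not iterable; with the guard, both cases behave the same way.
single = toy_unpack([1, 2, 3], [3])
double = toy_unpack([1, 2, 3], [1, 2])
```

If the same unwrapping happens in `unpack`, normalizing `split_inputs` to a list before the `zip` would be one way to handle the `len(packed_shapes) == 1` case.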
I'll add more as they come up.