Add squeeze for labeled tensors #1434
Conversation
… ExpandDims op and rewrite rule to not add a new dimension when dim is None - Update tests to verify behavior matches xarray
…and streamlining logic.
@ricardoV94 Please take a look at squeeze. expand_dims is still WIP
pytensor/xtensor/shape.py
Outdated
```python
    XTensorVariable
        A new tensor with the specified dimension removed
    """
    return Squeeze(dim=dim)(x)
```
Better not to have None in the Op. Do the conversion here and pass explicit dims to the Op. The reason for this has to do with PyTensor constraints: our Squeeze Op should always know which explicit dims are to be dropped, because the input could change subtly during rewrites, and we may then find out a dimension has length 1 after all, which we didn't know before. Reapplying the same Op would then change the output type, which is not allowed during rewrites.
Another note: xarray's squeeze seems to accept an axis argument for positional squeezing. We should allow that and convert it to dims: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.squeeze.html#xarray-dataarray-squeeze
It's better to always check the docs of the xarray method we're trying to emulate, to be aware of special arguments. You may need to experiment a bit with what xarray does if you specify both, or specify invalid dims/axis, and emulate that behavior on our side as much as is reasonable.
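A minimal sketch of what that normalization could look like in the user-facing helper, assuming the Squeeze Op from this PR and an input whose type exposes dims/shape; the helper signature and error behavior are illustrative, not the final API:

```python
def squeeze(x, dim=None, axis=None):
    """Hypothetical user-facing squeeze: normalize dim/axis before the Op."""
    if dim is not None and axis is not None:
        raise ValueError("Cannot pass both dim and axis")  # assumed to mirror xarray
    if axis is not None:
        # Positional squeeze, as xarray allows: convert axis -> dim names
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        dim = tuple(x.type.dims[i] for i in axes)
    elif dim is None:
        # Squeeze every dimension *statically known* to be length 1
        dim = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
    elif isinstance(dim, str):
        dim = (dim,)
    # The Op itself only ever sees an explicit tuple of dims
    return Squeeze(dims=tuple(dim))(x)
```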
@ricardoV94 I have restored the version with tests that validate against xarray behavior. I think squeeze is ready for review. Again, ignore expand_dims for now. I have a question about the case where a dimension specifier is symbolic -- is the implementation here correct?
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    if not isinstance(node.op, ExpandDims):
        return False
```
This check isn't needed, the node_rewriter argument is already used to preselect such nodes
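For reference, a sketch of why the check is redundant: the tracks list passed to node_rewriter already preselects matching nodes. The rewrite function name here is illustrative; ExpandDims is the Op defined in pytensor/xtensor/shape.py in this PR:

```python
from pytensor.graph.rewriting.basic import node_rewriter

@node_rewriter([ExpandDims])  # only ExpandDims nodes ever reach the body
def lower_expand_dims(fgraph, node):
    # node.op is guaranteed to be an ExpandDims instance here,
    # so no isinstance check is needed
    ...
```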
I'll note that for the next iteration, but expand_dims is not ready for review.
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    # If dim is None, don't add a new dimension (matching xarray behavior)
    if dim is None:
        return [x]
```
We don't need to support this at the Op level; just make x.expand_dims(None) return self, if we even want to support that.
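A sketch of handling that at the variable-method level instead (method placement and keyword handling are assumed for illustration):

```python
class XTensorVariable:
    def expand_dims(self, dim=None, **dim_sizes):
        # Match xarray: expanding nothing is a no-op, so the Op
        # never has to represent the None case
        if dim is None and not dim_sizes:
            return self
        ...  # otherwise normalize dim(s) and apply the ExpandDims Op
```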
pytensor/xtensor/rewriting/shape.py
Outdated
```python
        return [x]

    # Create new dimensions list with the new dimension at the beginning
    new_dims = [dim, *list(x.type.dims)]
```
We should support multiple expand_dims, not only one?
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    x = node.inputs[0]
    dim = node.op.dim
    size = getattr(node.op, "size", 1)
```
size should be a symbolic input (or multiple, if we have multiple dims) to the node, so you'll have `x, *sizes = node.inputs`. This way the sizes can be arbitrary symbolic expressions and not just constants. Check how unstack does it.
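A sketch of the suggested shape, with the dim names stored on the Op and one symbolic size per dim travelling as node inputs (names assumed, shown as a fragment inside the lowering rewrite):

```python
# Sizes are node inputs, not Op attributes, so each one can be
# an arbitrary scalar expression rather than a constant
x, *sizes = node.inputs
for dim, size in zip(node.op.dims, sizes):
    # use `size` symbolically when building the lowered graph
    ...
```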
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    if not isinstance(node.op, Squeeze):
        return False
```
Not needed
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    if not isinstance(node.op, Squeeze):
        return False

    x = node.inputs[0]
```
Nitpick: I like to do `[x] = node.inputs` to be explicit that this is a single-input node.
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    dim = node.op.dim

    # Convert single dimension to iterable for consistent handling
    dims_to_remove = [dim] if isinstance(dim, str) else dim
```
This sort of normalization should be done at the time the Op/node is defined. The earlier we normalize things, the easier it is to work downstream.
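For instance, a minimal sketch of normalizing once in the Op constructor (the base class and exact prop name are assumptions drawn from the surrounding discussion):

```python
from pytensor.graph.op import Op

class Squeeze(Op):
    __props__ = ("dims",)

    def __init__(self, dims):
        # Normalize here, once: a single dim name becomes a tuple, and
        # sorting makes the prop canonical and order-independent
        if isinstance(dims, str):
            dims = (dims,)
        self.dims = tuple(sorted(dims))
```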
pytensor/xtensor/rewriting/shape.py
Outdated
```python
    else:
        # Find all dimensions of size 1
        dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1]
        if not dim_indices:
            return False
```
This shouldn't happen at rewrite time. Decide at the time you create the Op/node which dimensions will be dropped and stick to those. This is a case where PyTensor deviates from numpy/xarray, due to its non-eager nature and the ability to work with unknown shapes.
You can see this happening in pytensor like this:

```python
import numpy as np
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 2, 1, 2, None))
y = x.squeeze()
assert y.eval({x: np.zeros((1, 2, 1, 2, 1))}).shape == (1, 2, 2, 1)
```

Only the dimension we knew to be length 1 when x.squeeze() was called was dropped. We never try to update which dimensions we drop, because y is bound to its type y.type, which cannot change during rewrites (well, shape can go from None -> int, but ndim cannot change).
pytensor/xtensor/rewriting/shape.py
Outdated
```python
        return False

    # Create new dimensions list
    new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices]
```
Just reuse `node.outputs[0].type.dims`, since you already did the work of figuring out the output dims in `make_node`.
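Combined with the earlier `[x] = node.inputs` nitpick, the start of the lowering body could reduce to something like this sketch:

```python
# The output type already records the squeezed dims computed in make_node
[x] = node.inputs
new_dims = node.outputs[0].type.dims
```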
I replied in the comment. It's not correct. Symbolic inputs have to show up in the node's inputs:

```python
import numpy as np
import pytensor
import pytensor.xtensor as px

x = px.xtensor("x", shape=(2,), dims=("a",))
b_size = px.xtensor("b_size", shape=(), dims=())
y = x.expand_dims(b=b_size)
y.eval({x: np.array([0, 1]), b_size: np.array(10)})
```

If you try this right now you may get an error or a silent bug, because b_size is not an input of the ExpandDims node.
@ricardoV94 squeeze is ready for another look. expand_dims is still not ready for review
ricardoV94 left a comment:
I did another pass on the Squeeze functionality.
I suggested removing the None case from the Op level and left some other minor comments about it.
I think the tests right now are a bit overkill / redundant / messy. I suggest grouping the following things into separate functions:
- Tests with explicit squeeze dim (single, multiple, order-independent)
- Tests with implicit None dim (including the case that deviates from xarray at runtime, as documented)
- Tests for errors raised by the Op at creation or runtime
pytensor/xtensor/shape.py
Outdated
```python
    to be size 1 at runtime.
    """

    __props__ = ("dim",)
```
Nit: call the Op prop `dims` instead of `dim` (still use `dim` in the user-facing functions).
Done, but it means that dim and dims are all over the place now. Worth it?
What do you mean they are all over the place now? Why is that?
@ricardoV94 squeeze is ready for another look, and expand_dims is ready, too
ricardoV94 left a comment:
I reviewed the Squeeze changes; mostly nitpicks and code simplification at this point. I suggest we split off the ExpandDims changes, because they don't yet allow symbolic shapes AFAICT.
```python
# Order independence
x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
y3a = squeeze(x3, ["b", "c"])
y3b = squeeze(x3, ["c", "b"])
fn3a = xr_function([x3], y3a)
fn3b = xr_function([x3], y3b)
x3_test = xr_arange_like(x3)
xr_assert_allclose(fn3a(x3_test), fn3b(x3_test))
```
Combine this with the previous test, and test both of them against xarray. You don't need one function per case; the function can have two outputs, which should make for a faster test, as it only triggers the compilation machinery once.
I have too many questions about this comment. If you want to make this change after merging, that might be more efficient than explaining.
Your previous check was already testing a squeeze of multiple dimensions, so you can combine it with this one, which also checks multiple dimensions plus the fact that order doesn't matter. This test is a superset of the previous one.
The point about combining multiple outputs is to do `xr_function([x3], [y3a, y3b])` instead of defining two separate functions. It's a small optimization, although it has the side benefit of also testing that multiple outputs aren't messed up.
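A sketch of the combined test, assuming the helpers already used in the test module (xtensor, squeeze, xr_function, xr_arange_like, xr_assert_allclose) and that a multi-output xr_function returns one result per output; the test name is illustrative:

```python
def test_squeeze_explicit_dims():
    x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
    y3a = squeeze(x3, ["b", "c"])
    y3b = squeeze(x3, ["c", "b"])
    # One compiled function with both outputs: compiles once,
    # and also exercises the multi-output path
    fn3 = xr_function([x3], [y3a, y3b])
    x3_test = xr_arange_like(x3)
    res_a, res_b = fn3(x3_test)
    # Validate against xarray, which is order-independent too
    xr_assert_allclose(res_a, x3_test.squeeze(["b", "c"]))
    xr_assert_allclose(res_b, x3_test.squeeze(["c", "b"]))
```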
@ricardoV94 I've removed expand_dims from this PR, and I think I've addressed all comments on squeeze except the one I noted about combining tests.
Last round of nits; some are things we already discussed but were reverted / ignored. The only critical one is that pytest.importorskip must be called before any xarray import, otherwise the CI will fail for the jobs where we don't install xarray.
And the name of `dim` in `Squeeze.__init__`.
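Concretely, the top of the test module needs this ordering (a short sketch; only the xarray dependency itself is taken from the comment above):

```python
import pytest

# Must run before anything imports xarray: on CI jobs without xarray,
# this skips the whole module instead of raising ImportError
pytest.importorskip("xarray")

import xarray as xr  # noqa: E402
```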
@ricardoV94 Ready for another look
Thanks for persisting with all my nits!
Of course. Hopefully it means fewer iterations on the next PR.
Adding squeeze and expand_dims
📚 Documentation preview 📚: https://pytensor--1434.org.readthedocs.build/en/1434/