
Conversation

@AllenDowney
Contributor

@AllenDowney AllenDowney commented May 30, 2025

Adding squeeze and expand_dims


📚 Documentation preview 📚: https://pytensor--1434.org.readthedocs.build/en/1434/

… ExpandDims op and rewrite rule to not add a new dimension when dim is None - Update tests to verify behavior matches xarray
@AllenDowney
Contributor Author

@ricardoV94 Please take a look at squeeze. expand_dims is still WIP

XTensorVariable
A new tensor with the specified dimension removed
"""
return Squeeze(dim=dim)(x)
Member

@ricardoV94 ricardoV94 May 30, 2025


Better not to have None in the Op. Do the conversion here and pass explicit dims to the Op. The reason for this has to do with PyTensor constraints.

Our Squeeze Op should always know which explicit dims are to be dropped: the input could change subtly during rewrites, and we might only then find out that a dimension has length 1 after all, which we didn't know before. Reapplying the same Op would then change the output type, which is not allowed during rewrites.


Another note: xarray's squeeze seems to accept an axis argument for positional squeeze; we should allow that and convert it to dims: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.squeeze.html#xarray-dataarray-squeeze

Better to always check the docs of the xarray method we're trying to emulate, to be aware of special arguments.

You may need to experiment a bit with what xarray does if you specify both, or specify invalid dims/axis, so we can emulate the behavior on our side as closely as is reasonable for us to do.
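
A minimal sketch of how that conversion could look at the user-facing level, assuming the helper signature and the Squeeze(dims=...) constructor below (both illustrative, not the PR's actual code):

def squeeze(x, dim=None, axis=None):
    # Resolve dim/axis to explicit dim names here, so the Op itself
    # only ever sees a concrete tuple of dims.
    if axis is not None:
        if dim is not None:
            raise ValueError("Cannot pass both `dim` and `axis`")
        # Positional squeeze, as xarray allows: map axes to dim names
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        dim = tuple(x.type.dims[i] for i in axes)
    if dim is None:
        # Resolve now, from statically known shapes, never inside the Op
        dim = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
    elif isinstance(dim, str):
        dim = (dim,)
    return Squeeze(dims=tuple(dim))(x)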

@AllenDowney AllenDowney changed the title from "Add expand dims squeeze" to "Add squeeze" on Jun 2, 2025
@AllenDowney
Contributor Author

@ricardoV94 I have restored the version with tests that validate against xarray behavior. I think squeeze is ready for review. Again, ignore expand_dims for now.

I have a question about the case where a dimension specifier is symbolic -- is the implementation here correct?

Comment on lines 128 to 129
if not isinstance(node.op, ExpandDims):
return False
Member


This check isn't needed; the node_rewriter argument is already used to preselect such nodes.
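
For context, a sketch of what that preselection looks like with PyTensor's rewriter decorator (the rewrite name and body are illustrative):

from pytensor.graph.rewriting.basic import node_rewriter

@node_rewriter([ExpandDims])
def lower_expand_dims(fgraph, node):
    # node.op is guaranteed to be an ExpandDims here: the tracker list
    # passed to @node_rewriter preselects matching nodes, so no
    # isinstance check is needed inside the body.
    [x] = node.inputs
    ...  # build and return the replacement, or None to decline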

Contributor Author


I'll note that for the next iteration, but expand_dims is not ready for review.

Comment on lines 135 to 137
# If dim is None, don't add a new dimension (matching xarray behavior)
if dim is None:
return [x]
Member


We don't need to support this at the Op level; just make it return self when x.expand_dims(None) is called, if we even want to support that.
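
In other words, something like this at the method level (a sketch; the signature and Op call are illustrative):

def expand_dims(self, dim=None, **dim_kwargs):
    if dim is None and not dim_kwargs:
        # x.expand_dims(None) is a no-op: return self, don't build an Op
        return self
    ...  # normal path: normalize dims and apply ExpandDims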

return [x]

# Create new dimensions list with the new dimension at the beginning
new_dims = [dim, *list(x.type.dims)]
Member


We should support multiple expand_dims, not only one?


x = node.inputs[0]
dim = node.op.dim
size = getattr(node.op, "size", 1)
Member


size should be a symbolic input (or multiple if we have multiple dims) to the node, so you'll have x, *sizes = node.inputs. This way they can be arbitrary symbolic expressions and not just constants. Check how unstack does it.
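
A rough sketch of that shape, modeled on how unstack threads its symbolic lengths through the node (the base class name and make_node body are assumptions, not the PR's code):

class ExpandDims(XOp):
    __props__ = ("dims",)

    def make_node(self, x, *sizes):
        # One symbolic size input per new dim. Because sizes are node
        # inputs rather than Op attributes, they can be arbitrary
        # symbolic expressions, not just constants.
        ...

# and in the lowering rewrite, unpack them alongside x:
x, *sizes = node.inputs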

Comment on lines 158 to 159
if not isinstance(node.op, Squeeze):
return False
Member


Not needed

if not isinstance(node.op, Squeeze):
return False

x = node.inputs[0]
Member


Nitpick: I like to do [x] = node.inputs to be explicit that this is a single-input node.

Comment on lines 162 to 165
dim = node.op.dim

# Convert single dimension to iterable for consistent handling
dims_to_remove = [dim] if isinstance(dim, str) else dim
Member


This sort of normalization should be done at the time the Op/node is defined. The earlier we normalize stuff, the easier it is to work downstream.
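
E.g., a minimal sketch of normalizing once in __init__ (the base class name is assumed; sorting is one way to also make __props__ equality order-independent, not something the PR prescribes):

class Squeeze(XOp):
    __props__ = ("dims",)

    def __init__(self, dims):
        # Normalize here, once: downstream code (make_node, rewrites)
        # can then rely on self.dims always being a tuple of str.
        self.dims = tuple(sorted(dims))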

Comment on lines 178 to 182
else:
# Find all dimensions of size 1
dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1]
if not dim_indices:
return False
Member


This shouldn't happen at rewrite time. Decide at the time you create the Op/node what dimensions will be dropped and stick to those. This is a case where PyTensor deviates from numpy/xarray, due to its non-eager nature and the ability to work with unknown shapes.

You can see this happening in pytensor like this:

import pytensor.tensor as pt
import numpy as np

x = pt.tensor("x", shape=(None, 2, 1, 2, None))
y = x.squeeze()
assert y.eval({x: np.zeros((1, 2, 1, 2, 1))}).shape == (1, 2, 2, 1)

Only the dimension we knew to be length 1 when x.squeeze() was called was dropped. We never try to update which dimensions we drop, because y is bound to its type y.type, which cannot change during rewrites (well, shape entries can go from None -> int, but ndim cannot change).

return False

# Create new dimensions list
new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices]
Member


Just reuse node.outputs[0].type.dims since you already did the work of figuring out the output dims in make_node
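
I.e., something like this in the rewrite (sketch):

# The output type already carries the squeezed dims computed in
# make_node, so read them back instead of recomputing them.
[x] = node.inputs
[out] = node.outputs
new_dims = out.type.dims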

@ricardoV94
Member

ricardoV94 commented Jun 2, 2025

I have a question about the case where a dimension specifier is symbolic -- is the implementation here correct?

I replied in the comment. It's not correct. Symbolic inputs have to show up in make_node, not __init__. You can try to create such a graph like this (adapt it to a test format):

import numpy as np

import pytensor
import pytensor.xtensor as px

x = px.xtensor("x", shape=(2,), dims=("a",))
b_size = px.xtensor("b_size", shape=(), dims=())
y = x.expand_dims(b=b_size)

y.eval({x: np.array([0, 1]), b_size: np.array(10)})

If you try this right now you may get an error or a silent bug, because b_size is not part of the symbolic graph of y as far as PyTensor can tell.

@AllenDowney
Contributor Author

@ricardoV94 squeeze is ready for another look.

expand_dims is still not ready for review

Member

@ricardoV94 ricardoV94 left a comment


I did another pass on the Squeeze functionality.

I suggested removing the None case from the Op level and left some other minor comments about it.

I think the tests right now are a bit overkill / redundant / messy. I suggest grouping the following things into separate functions:

  1. Tests with explicit squeeze dim (single, multiple, order independent)
  2. Tests with implicit None dim (including the case that deviates from xarray at runtime, as documented)
  3. Tests for errors raised by the Op at creation or runtime
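
Something like this skeleton (function names illustrative):

def test_squeeze_explicit_dims():
    # single dim, multiple dims, and order independence, vs. xarray
    ...

def test_squeeze_implicit_none():
    # dim=None, including the documented runtime deviation from xarray
    ...

def test_squeeze_errors():
    # invalid dims at Op creation; non-unit dims at runtime
    ...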

to be size 1 at runtime.
"""

__props__ = ("dim",)
Member


Nit: call the Op prop dims instead of dim (still use dim in the user-facing functions).

Contributor Author


Done, but it means that dim and dims are all over the place now. Worth it?

Member


What do you mean they are all over the place now? Why is that?

@AllenDowney
Contributor Author

@ricardoV94 squeeze is ready for another look, and expand_dims is ready, too

@twiecki twiecki changed the title from "Add squeeze" to "Add squeeze for labeled tensors" on Jun 4, 2025
Member

@ricardoV94 ricardoV94 left a comment


I reviewed the Squeeze changes; mostly nitpicks and code simplification at this point. I suggest we split out ExpandDims, because it's not yet allowing symbolic shapes AFAICT.

Comment on lines +439 to +446
# Order independence
x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
y3a = squeeze(x3, ["b", "c"])
y3b = squeeze(x3, ["c", "b"])
fn3a = xr_function([x3], y3a)
fn3b = xr_function([x3], y3b)
x3_test = xr_arange_like(x3)
xr_assert_allclose(fn3a(x3_test), fn3b(x3_test))
Member


Combine this with the previous test. Test both of them against xarray. You don't need one function per case; the function can have two outputs, which should be a faster test, as it only triggers the compilation machinery once.

Contributor Author


I have too many questions about this comment. If you want to make this change after merging, that might be more efficient than explaining.

Member


Your previous check was already testing a squeeze of multiple dimensions, so you can combine it with this one, which also checks multiple dimensions plus the fact that order doesn't matter. This test is a superset of the previous one.

Then the point about combining multiple outputs is to do xr_function([x3], [y3a, y3b]) instead of defining two separate functions. It's a small optimization, although it has the side benefit of testing that multiple outputs aren't messed up either.
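
Concretely, a sketch of the combined test (assuming xr_function returns the outputs as a sequence when given a list of outputs):

x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
y3a = squeeze(x3, ["b", "c"])
y3b = squeeze(x3, ["c", "b"])
fn = xr_function([x3], [y3a, y3b])  # one compilation, two outputs
x3_test = xr_arange_like(x3)
res_a, res_b = fn(x3_test)
expected = x3_test.squeeze(["b", "c"])  # order-independent reference
xr_assert_allclose(res_a, expected)
xr_assert_allclose(res_b, expected)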

@AllenDowney
Contributor Author

@ricardoV94 I've removed expand_dims from this PR, and I think I've addressed all comments on squeeze except the one I noted about combining tests.

Member

@ricardoV94 ricardoV94 left a comment


Last round of nits; some are things we already discussed but were reverted / ignored. The only critical one is that pytest.importorskip must be called before any xarray import, otherwise the CI will fail for the jobs where we don't install xarray.

And the name of dim in Squeeze.__init__
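
For the importorskip point, the pattern at the very top of the test module would be something like:

import pytest

# Must run before any xarray import, so the module is skipped
# (not errored) on CI jobs where xarray isn't installed.
pytest.importorskip("xarray")

import xarray as xr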

@AllenDowney
Contributor Author

@ricardoV94 Ready for another look

@ricardoV94
Member

Thanks for persisting with all my nits!

@AllenDowney
Contributor Author

Of course. Hopefully it means fewer iterations on the next PR.

@ricardoV94 ricardoV94 merged commit 7b1009e into pymc-devs:labeled_tensors Jun 6, 2025
4 of 5 checks passed
@AllenDowney AllenDowney deleted the add_expand_dims_squeeze branch June 6, 2025 18:55