Implement pack/unpack helpers #1578
Conversation
The pack -> type -> unpack -> replace pattern might be common enough to merit its own helper; PyMC has tools for doing this, for example in … One other thing I forgot to mention is that this will all fail on inputs with shape 0, since that will ruin the …
Codecov Report
❌ Your patch check has failed because the patch coverage (78.94%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@ Coverage Diff @@
## main #1578 +/- ##
==========================================
- Coverage 81.54% 81.54% -0.01%
==========================================
Files 230 230
Lines 53136 53153 +17
Branches 9448 9451 +3
==========================================
+ Hits 43329 43342 +13
- Misses 7370 7372 +2
- Partials 2437 2439 +2
2 and 3. I would really like to have these; it's what I needed for the batched_dot_to_core rewrites. This isn't a simple case of vectorize, because the dims I want to pack are on both the left and the right of other dims.
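To make the shape pattern concrete, here is a small illustration of batch dims sitting on both sides of the core dims (shapes and names are my own, not from the rewrite):

import pytensor.tensor as pt

# x has batch dims b1 and b2 on either side of the core (m, k) dims: (b1, m, k, b2).
# vectorize/Blockwise only prepend batch dims on the left, so packing b1 and b2
# together would first require a DimShuffle to bring them side by side.
x = pt.tensor("x", shape=(None, 3, 4, None))   # (b1, m, k, b2)
x_moved = x.dimshuffle(0, 3, 1, 2)             # (b1, b2, m, k)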
for shape in packed_shapes:
    size = prod(shape, no_zeros_in_input=True)
    end = start + size
    unpacked_tensors.append(
Take uses advanced indexing. Actually the best here is probably split: join and split are the inverses of each other, and they will be easier to rewrite away.
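A minimal sketch of the split-based alternative suggested here (variable names and sizes are assumed, not taken from the diff):

import pytensor.tensor as pt

# flat is the packed 1-D vector; sizes holds the flat size of each original input.
# pt.split cuts the vector into consecutive segments in a single op, instead of
# one advanced-indexing subtensor per input.
flat = pt.vector("flat")
sizes = [6, 12, 4]  # e.g. the prod of each packed shape, assumed static here
segments = pt.split(flat, sizes, len(sizes), axis=0)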
start = 0
unpacked_tensors = []
for shape in packed_shapes:
    size = prod(shape, no_zeros_in_input=True)
Why no zeros in input? The shape doesn't show up in gradients if that's what you were worried about
JAX needs it as well iirc
I don't see why you can't have zeros in the shapes
ok ok ok i'll fix it
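As an aside, a plain NumPy check (my own, not from the PR) suggests zero-sized inputs simply produce empty segments when splitting by cumulative sizes:

import numpy as np

sizes = [6, 0, 4]                   # one input has a zero-sized shape
flat = np.arange(sum(sizes))        # the packed vector, length 10
segments = np.split(flat, np.cumsum(sizes)[:-1])
print([s.shape for s in segments])  # [(6,), (0,), (4,)]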
@@ -2074,6 +2074,73 @@ def concat_with_broadcast(tensor_list, axis=0):
    return join(axis, *bcast_tensor_inputs)


def pack(
def pack( |
Would be nice to have a docstring (doctested) example
Yeah, I will of course add docstrings. This was just to get a PR on the board and see your reaction to the 3 issues I raised. I didn't want to document everything before we decided on the final API.
I am inclined to make this a core op and not just a helper. It obviates most uses of reshape and is much easier to reason about, without having to worry about a pesky -1 or whether the reshape shape comes from the original input shapes or not. That would pretty much address #883. We could use OFG and/or specialize to reshape/split later; it also need not be done in this PR. It's an implementation detail as far as the user is concerned.
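For context, a small example of the reshape bookkeeping this would avoid (my own illustration, not code from the PR):

import pytensor.tensor as pt

x = pt.matrix("x")
flat = x.reshape((-1,))       # pack via reshape: the target shape is implicit
back = flat.reshape(x.shape)  # unpack via reshape: this is only the identity because
                              # the target shape happens to be x.shape, which a rewriter
                              # has to prove; a dedicated pack/unpack op would make the
                              # round-trip explicit.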
Description
Adds pt.pack and pt.unpack helpers, roughly conforming to the einops functions of the same name. These helpers are for situations where we have a ragged list of inputs that need to be raveled into a single flat vector for some intermediate step. This occurs in places like optimization.
Example usage:
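The example itself did not survive here; a minimal sketch of what the usage presumably looks like, assuming pack returns the flat vector together with packed_shapes (the exact signature is in the PR diff, not guaranteed by this sketch):

import pytensor.tensor as pt

# Three ragged inputs with statically known shapes (hypothetical example).
x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4,))
z = pt.tensor("z", shape=(5, 1, 2))

flat, packed_shapes = pt.pack(x, y, z)   # assumed return: (flat vector, list of shapes)
print(packed_shapes)                     # [(2, 3), (4,), (5, 1, 2)] when shapes are static

x2, y2, z2 = pt.unpack(flat, packed_shapes)  # undoes the packing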
Unpack simply undoes the computation, although there is no rewrite to ensure pt.unpack(*pt.pack(*inputs)) is the identity function. The use-case I foresee is creating replacements for a function of the inputs we're packing, for example:
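The concrete example is missing from the text above; a hedged sketch of the replacement pattern being described (names and the graph_replace call are my own choices, not necessarily what the PR uses):

import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace

# Three inputs with static shapes, and some scalar cost built from them.
x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4,))
z = pt.tensor("z", shape=(5,))
cost = (x**2).sum() + (y**2).sum() + (z**2).sum()

flat, packed_shapes = pt.pack(x, y, z)     # assumed signature, as above
new_input = flat.type("new_input")         # fresh variable with the packed (flat) type
x_new, y_new, z_new = pt.unpack(new_input, packed_shapes)

# Rewrite the cost to be a function of the single packed input.
new_cost = graph_replace(cost, {x: x_new, y: y_new, z: z_new})
fn = pytensor.function([new_input], new_cost)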
Note that the final compiled function depends only on new_input, but only because the shapes of the 3 packed variables were statically known. This leads to my design choices section:

1. pack will eagerly return a list of integer shapes as packed_shapes if possible. If not possible, they will be symbolic shapes. This is maybe an anti-pattern -- we might prefer a rewrite to handle this later, but it seemed easy enough to do eagerly.
2. pt.vectorize …
3. The einops API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't get himself using DimShuffle and vectorize.
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/