Implement new Loop and Scan operators #191
Conversation
assert input_state.type == output_state.type
...
class Loop(Op):
TODO: Add mixin HasInnerGraph so that we can see the inner graph in debug_print
Wouldn't a fill loop look something like this? (and very much need inplace rewrites for good performance...)
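Roughly, something like this in plain Python (a sketch of the pattern only, not this PR's API): the only state carried between steps is the RNG, and each draw is written into a preallocated buffer, which is why inplace rewrites matter so much here.

import numpy as np

# Sketch of a "fill" loop: the value written at step i never feeds the next step,
# so the only real loop state is the random generator.
rng = np.random.default_rng(0)
trace = np.empty(10)
for i in range(10):
    trace[i] = rng.normal()  # an inplace rewrite lets this write reuse `trace` instead of copying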
Good question...
I think one rewrite that gets easier with the if-else-do-while approach would be loop-invariant code motion. Let's say we have a loop where part of the computation does not depend on the state being updated; we could move that part out of the loop.
Well, I guess we really need those :-)
Why can't we move it even if it's empty? Sum works fine. Are you worried about Ops that we know will fail with empty inputs? About the filling Ops, yeah I don't see it as a problem anymore. Just felt awkward to create the dummy input when translating from scan to loop. I am okay with it now.
That would change the behavior. If we move it out and don't prevent it from being executed, things could fail, for instance if there's an assert somewhere, or some other error happens during its evaluation. Also, it could be potentially very costly (let's say "solve an ODE"). (somehow I accidentally edited your comment instead of writing a new one, no clue how, but fixed now)
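As a plain-Python illustration of the trade-off discussed above (toy names, not PyTensor graphs):

def expensive(a):      # stand-in for a costly computation that can also fail
    assert a > 0
    return a * 2

def step(x, y):
    return x - y

x, a = 10, 3

# Before: `expensive(a)` does not depend on the loop state `x`,
# so it is recomputed on every iteration.
while x > 0:
    y = expensive(a)
    x = step(x, y)

# After loop-invariant code motion it runs exactly once...
x = 10
y = expensive(a)
while x > 0:
    x = step(x, y)

# ...but if the loop could have run zero times, the hoisted `expensive(a)`
# (and its assert) now executes anyway, which is exactly the behavior change above.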
In my last commit, sequences are demoted from special citizens to just another constant input. I have reverted converting the constant inputs to dummies before calling the user function, which allows the example in the jacobian documentation to work, including the one that didn't work before (because both are now equivalent under the hood :)) https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html#computing-the-jacobian

I reverted too much, though: I still need to pass dummy inputs as the state variables, since it doesn't make sense for the user function to introspect the graph beyond the initial state (it's only valid for the initial state).
return last_states[1:], traces[1:]
...
def map(
What about subclassing Scan into Map(Scan), Reduce(Scan), Filter(Scan)?
It will be easier to dispatch into optimized implementations.
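A minimal sketch of the idea (Scan stands in here for the Op added in this PR; the names are only illustrative):

class Scan:          # stand-in for the Scan Op introduced in this PR
    ...

class Map(Scan):     # no state carried between steps
    ...

class Reduce(Scan):  # only the final state is of interest
    ...

class Filter(Scan):  # keep only elements satisfying a predicate
    ...

# A rewrite or backend could then dispatch on the concrete class,
# e.g. isinstance(op, Map), instead of re-deriving the pattern from the inner graph.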
We can do that later, not convinced we need that yet
if init_state is None:
    # next_state may reference idx. We replace that by the initial value,
    # so that the shape of the dummy init state does not depend on it.
    [next_state] = clone_replace(
Why not graph_replace, or using memo for FunctionGraph(memo={symbolic_idx: idx}) here?
Why is that better?
Added a simple JAX dispatcher, works in the few examples I tried.
# explicitly triggers the optimization of the inner graphs of Scan?
update_fg = op.update_fg.clone()
rewriter = get_mode("JAX").optimizer
rewriter(update_fg)
This gives an annoying Supervisor Feature missing warning... gotta clean that up
print(max_iters)
states, traces = jax.lax.scan(
    scan_fn, init=list(states), xs=None, length=max_iters
Todo: Check we are not missing performance by not having explicit sequences.
Todo: When there are multiple sequences, PyTensor defines n_steps as the shortest sequence. JAX should be able to handle this, but if not we could consider not allowing sequences/n_steps with different lengths in the PyTensor scan. Then we could pass a single shape as n_steps after asserting they are the same?
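For reference, jax.lax.scan already runs without explicit sequences when given xs=None and a fixed length, e.g.:

import jax
import jax.numpy as jnp

def step(carry, _):
    return carry + 1, carry  # (next carry, per-step output)

final, trace = jax.lax.scan(step, init=jnp.array(0), xs=None, length=5)
# final == 5, trace == [0, 1, 2, 3, 4]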
I just found out about TypedLists in PyTensor. That should allow us to trace any type of Variables, including RandomTypes 🤯 Pushed a couple of commits that rely on this.
Co-authored-by: Adrian Seyboldt <[email protected]>
This was not possible prior to the use of TypedListType for non-TensorVariable sequences, as it would otherwise not be possible to represent indexing of the last sequence state, which is needed e.g. for shared random generator updates.
Codecov Report
@@            Coverage Diff             @@
##             main     #191      +/-   ##
==========================================
+ Coverage   80.03%   80.09%   +0.06%
==========================================
  Files         170      173       +3
  Lines       45086    45435     +349
  Branches     9603     9694      +91
==========================================
+ Hits        36085    36392     +307
- Misses       6789     6818      +29
- Partials     2212     2225      +13
This Discourse thread is a great reminder of several Scan design issues that are fixed here: https://discourse.pymc.io/t/hitting-a-weird-error-to-do-with-rngs-in-scan-in-a-custom-function-inside-a-potential/13151/15 Namely:
Related to #189
This PR implements a new low level Loop Op which can be easily transpiled to Numba (the Python perform method takes 9 lines, yay to not having to support C in the future). It also implements a new higher level Scan Op which returns as outputs the last states + intermediate states of a looping operation. This Op cannot be directly evaluated, and must be rewritten as a Loop Op in Python/Numba backends. For the JAX backend it's probably fine to transpile directly from this representation into a lax.scan, as the signatures are pretty much identical. That was not done in this PR.

The reason for the two types of outputs is that they are useful in different contexts. Final states are sometimes all one needs, whereas intermediate states are generally needed for back propagation (not implemented yet). This allows us to choose which one (or both) of the outputs we want during compilation, without having to do complicated graph analysis.
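In plain-Python terms (an illustration of the output convention only, not the Op's actual interface):

def scan(step_fn, init, n_steps):
    state, trace = init, []
    for _ in range(n_steps):
        state = step_fn(state)
        trace.append(state)
    return state, trace  # final state + intermediate states

last, trace = scan(lambda x: x + 1, init=0, n_steps=5)
# last == 5; trace == [1, 2, 3, 4, 5]
# If only `last` is used downstream, the rewrite to Loop can drop the trace entirely;
# if `trace` is needed (e.g. for backprop), it is kept.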
The existing save_mem_new_scan is used to convert a general scan into a loop that only returns the last computed state. It's... pretty complicated (although it also covers cases where more than 1 but fewer than all steps are requested; OTOH it can't handle while loops #178): pytensor/pytensor/scan/rewriting.py, line 1119 in 8ad3317.
Taking that as a reference, I would say the new conversion rewrite from Scan to Loop is much, much simpler. Most of it is boilerplate code for defining the right trace inputs and the new FunctionGraph.
Both Ops expect a FunctionGraph as input. This should probably be created by a user-facing helper that accepts a callable, like scan does now. ~~That was not done yet, as I first wanted to discuss the general design.~~ Done

Design issues
1. The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan, where outputs_info must explicitly be None in these cases.
2. Scan and Loop can now take random types as inputs (Scan can't return them as a sequence). This makes random seeding much more explicit compared to the old Scan, which was based on default updates of shared variables. However, it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a return_rng_update to __call__, so that it doesn't hide the next rng state output?
3. Do we want to be able to represent empty Loops / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?
4. What do we want to do in terms of inplacing optimizations?
TODO
If people are on board with the approach:
- Support the old Scan arguments (mode, truncate_gradient, reverse and so on)
- Replace trace[-1] by the first set of outputs (final state). That way we can keep the old API, while retaining the benefit of doing while Scans without tracing when it's not needed.