Feel free to close this issue. It's more of a set of suggestions than an actual issue. With Clément's help and using fax as a base, I ended up reimplementing the two-phase fixed-point solver. I would prefer to use fax directly, but I can't unless it has a few features:
- The ability to work with stochastic iterated functions. After a lot of discussion with Clément, that just means having two functions: a stochastic one that is used to find the fixed point, and its expected value, which is used in the VJP.
- The ability to work with arbitrary PyTree-like objects for both the state and the parameters. That's an easy fix: it means using `tree_map` and `tree_reduce` in a few places. (For example, `dout = fp_vjp_fn(dout)[0] + dvalue` needs to map using `jnp.add`.)
- The object that defines the iteration (I think that's fax's `param_func`) should be pytree-like. If it's not, it at least needs to be hashable, so that when it's passed as a static argument to a jitted function, a change to it always induces recompilation.
- The ability to calculate a trajectory with the same inputs that are used to calculate a fixed point.
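To make the first, second, and fourth requests concrete, here is a minimal JAX sketch of a two-phase solver along these lines. It is not fax's API and not the linked reimplementation. It assumes a stochastic map `stochastic_f(params, x, key)` that is used only in the forward phase, and its expected value `expected_f(params, x)` that is used only in the backward (adjoint) phase, where the adjoint is found by iterating `w <- x_bar + (df/dx)^T w` with `tree_map(jnp.add, ...)` so that state and parameters can be arbitrary pytrees. The forward phase is a `lax.scan`, so the same inputs also yield the trajectory. All function names, the fixed step count, and the tolerances are hypothetical choices for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map, tree_reduce


def _tree_max_abs_diff(a, b):
    # Largest elementwise change between two pytrees with the same structure.
    diffs = tree_map(lambda x, y: jnp.max(jnp.abs(x - y)), a, b)
    return tree_reduce(jnp.maximum, diffs)


def _iterate(f, x0, tol=1e-6, max_iter=1000):
    # Deterministic iteration x <- f(x) until the update falls below tol.
    def cond(carry):
        i, x, x_prev = carry
        return (i < max_iter) & (_tree_max_abs_diff(x, x_prev) > tol)

    def body(carry):
        i, x, _ = carry
        return i + 1, f(x), x

    _, x, _ = lax.while_loop(cond, body, (0, f(x0), x0))
    return x


def stochastic_trajectory(stochastic_f, params, x0, key, num_steps):
    # Forward phase: run the *stochastic* map and keep every iterate.
    def step(x, k):
        x_next = stochastic_f(params, x, k)
        return x_next, x_next

    keys = jax.random.split(key, num_steps)
    return lax.scan(step, x0, keys)  # (final state, stacked trajectory)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _implicit_layer(expected_f, params, x_star):
    # Identity on the forward pass; only the backward rule matters.
    return x_star


def _implicit_fwd(expected_f, params, x_star):
    return x_star, (params, x_star)


def _implicit_bwd(expected_f, res, x_bar):
    params, x_star = res
    # Backward phase uses the *expected* map, linearized at the fixed point.
    _, vjp_fn = jax.vjp(expected_f, params, x_star)

    def adjoint_step(w):
        # w <- x_bar + (df/dx)^T w, added leaf by leaf so any pytree works.
        return tree_map(jnp.add, vjp_fn(w)[1], x_bar)

    w = _iterate(adjoint_step, x_bar)
    # x_star enters under stop_gradient, so it receives a zero cotangent.
    return vjp_fn(w)[0], tree_map(jnp.zeros_like, x_star)


_implicit_layer.defvjp(_implicit_fwd, _implicit_bwd)


def fixed_point(stochastic_f, expected_f, params, x0, key, num_steps=500):
    x_star, _ = stochastic_trajectory(stochastic_f, params, x0, key, num_steps)
    return _implicit_layer(expected_f, params, lax.stop_gradient(x_star))


# Toy problem with dict-valued state and dict-valued parameters.
def expected_f(params, x):
    return {"a": 0.5 * x["a"] + params["w"], "b": 0.25 * x["b"] + 2.0 * params["w"]}


def stochastic_f(params, x, key):
    ka, kb = jax.random.split(key)
    fx = expected_f(params, x)
    noise = {"a": jax.random.normal(ka, fx["a"].shape),
             "b": jax.random.normal(kb, fx["b"].shape)}
    return tree_map(lambda m, n: m + 0.01 * n, fx, noise)  # zero-mean noise


params = {"w": jnp.ones(3)}
x0 = {"a": jnp.zeros(3), "b": jnp.zeros(3)}
key = jax.random.PRNGKey(0)


def loss(params, key):
    x_star = fixed_point(stochastic_f, expected_f, params, x0, key)
    return jnp.sum(x_star["a"]) + jnp.sum(x_star["b"])


_, traj = stochastic_trajectory(stochastic_f, params, x0, key, 500)
grads = jax.grad(loss)(params, key)  # analytically 2 + 8/3 = 14/3 per entry
```

In this toy case the expected map's fixed point is `a* = 2w`, `b* = 8w/3`, so each entry of `grads["w"]` should come out close to 14/3. Passing `x_star` through `lax.stop_gradient` into an identity `custom_vjp` keeps JAX from differentiating through the stochastic scan. Evaluating the backward pass at a noisy `x_star` is harmless here only because the toy map is affine; for a nonlinear map, the linearization point inherits the forward-phase noise.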
Just as an example (you're welcome to use any code that's useful to you), I ended up implementing a version of your two-phase solver with the above features.
The user-facing code is here. The combinator is here. I borrowed your tests.
And here I show how it's used to write pretty code in an object-oriented setting.
I did use classes, but after all, fax's `param_func` is essentially an object. Unfortunately, many of the examples use closures rather than classes, so the closed-over values are practically inaccessible, which complicates debugging.
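The class-versus-closure point can be sketched as follows, assuming an iteration object registered with `jax.tree_util.register_pytree_node_class`. `DampedAffineMap` and `run` are hypothetical names for illustration, not part of fax. Because the object is a pytree, it can be passed to a jitted function as an ordinary argument, and its parameters stay inspectable as attributes.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class DampedAffineMap:
    """Hypothetical iteration object: x <- (1 - d) * x + d * (A @ x + b)."""

    def __init__(self, A, b, damping):
        self.A = A
        self.b = b
        self.damping = damping  # Python float, treated as static

    def __call__(self, x):
        return (1.0 - self.damping) * x + self.damping * (self.A @ x + self.b)

    def tree_flatten(self):
        # Arrays are children (traced under jit); damping is hashable aux data.
        return (self.A, self.b), self.damping

    @classmethod
    def tree_unflatten(cls, damping, children):
        return cls(*children, damping)


@jax.jit
def run(f, x0):
    # f is a regular (non-static) argument because it is a pytree.
    return jax.lax.fori_loop(0, 200, lambda _, x: f(x), x0)


f = DampedAffineMap(A=0.5 * jnp.eye(2), b=jnp.array([1.0, 2.0]), damping=0.5)
x = run(f, jnp.zeros(2))
print(f.A, f.b, f.damping)  # unlike closed-over values, these are easy to inspect
```

The fixed point of this map satisfies `x = A @ x + b`, i.e. `[2, 4]` here. Since `damping` lives in the aux data, a different value produces a different treedef and therefore a retrace, while new values of `A` or `b` with the same shapes reuse the compiled function.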