
Support stochastic iterable functions (and other minor suggestions) #15

@NeilGirdhar


Feel free to close this issue; it's more a set of suggestions than an actual issue. With Clément's help, and using fax as a base, I ended up reimplementing the two-phase fixed-point solver. I would prefer to use fax instead, but I can't do that unless fax has a few features:

  • The ability to work with stochastic iterable functions. After a lot of discussion with Clément, we concluded that this just means having two functions: one that is stochastic and is used to find the fixed point, and another that is the expected value of the first and is used in the vjp.
  • The ability to work with arbitrary PyTree-like objects for both the state and the parameters. That's an easy fix: it means using tree_map and tree_reduce in a few places (e.g., dout = fp_vjp_fn(dout)[0] + dvalue needs to map using jnp.add).
  • The object that defines the iteration (I think that's fax's params_func) should be pytree-like. If it's not, then it at least needs to be hashable, so that it can be passed as a static argument to a jitted function and any change to it induces recompilation.
  • The ability to calculate a trajectory with the same inputs that are used to calculate a fixed point.
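For the second bullet, the accumulation step dout = fp_vjp_fn(dout)[0] + dvalue can be made PyTree-aware with jax.tree_util.tree_map, and a convergence check can be built with jax.tree_util.tree_reduce. A minimal sketch; tree_add and tree_max_abs_diff are hypothetical helper names, not part of fax:

```python
import jax
import jax.numpy as jnp


def tree_add(a, b):
    """Leafwise sum of two PyTrees with identical structure."""
    return jax.tree_util.tree_map(jnp.add, a, b)


def tree_max_abs_diff(a, b):
    """Largest absolute elementwise difference over all leaves (a convergence check)."""
    diffs = jax.tree_util.tree_map(lambda x, y: jnp.max(jnp.abs(x - y)), a, b)
    return jax.tree_util.tree_reduce(jnp.maximum, diffs)


# Hypothetical cotangents with a dict structure instead of a single array.
dout = {"w": jnp.ones(3), "b": jnp.array(2.0)}
dvalue = {"w": jnp.arange(3.0), "b": jnp.array(1.0)}

# PyTree-aware version of `dout = fp_vjp_fn(dout)[0] + dvalue`;
# here `dout` stands in for the result of fp_vjp_fn(dout)[0].
total = tree_add(dout, dvalue)
gap = tree_max_abs_diff(dout, dvalue)
```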
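The first bullet can be sketched with jax.custom_vjp: the forward pass iterates the stochastic map with a fresh PRNG key at each step, and the backward pass applies the implicit function theorem to the expected map only. This is a minimal sketch under several assumptions (a fixed number of forward and adjoint iterations instead of convergence tests, and the signatures stochastic_fn(params, x, key) and expected_fn(params, x)); implicit_grad and stochastic_fixed_point are hypothetical names, and this is not fax's API or fax's two-phase solver:

```python
from functools import partial

import jax
import jax.numpy as jnp

NUM_ADJOINT_STEPS = 100  # assumption: fixed iteration count for the adjoint solve


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def implicit_grad(expected_fn, params, x_star):
    # Identity on the forward pass; only the backward rule matters.
    return x_star


def _implicit_grad_fwd(expected_fn, params, x_star):
    return x_star, (params, x_star)


def _implicit_grad_bwd(expected_fn, res, x_bar):
    params, x_star = res
    # Linearize the *expected* map at the point found by the stochastic map.
    _, vjp_fn = jax.vjp(expected_fn, params, x_star)

    # Solve the adjoint fixed point u = x_bar + (df/dx)^T u by plain iteration.
    def body(_, u):
        return jax.tree_util.tree_map(jnp.add, x_bar, vjp_fn(u)[1])

    u = jax.lax.fori_loop(0, NUM_ADJOINT_STEPS, body, x_bar)
    params_bar = vjp_fn(u)[0]
    # x_star comes from stop_gradient in the caller, so its cotangent is unused.
    return params_bar, jax.tree_util.tree_map(jnp.zeros_like, x_star)


implicit_grad.defvjp(_implicit_grad_fwd, _implicit_grad_bwd)


def stochastic_fixed_point(stochastic_fn, expected_fn, params, x0, key, num_steps=200):
    # Phase 1: iterate the stochastic map; no gradient flows through these steps.
    p = jax.lax.stop_gradient(params)

    def step(x, k):
        return stochastic_fn(p, x, k), None

    x_star, _ = jax.lax.scan(step, x0, jax.random.split(key, num_steps))
    x_star = jax.lax.stop_gradient(x_star)
    # Phase 2: gradients w.r.t. params come from the expected map only.
    return implicit_grad(expected_fn, params, x_star)


# Toy example: the expected map x -> 0.5 * x + p has fixed point x* = 2p, so dx*/dp = 2.
def expected_fn(p, x):
    return 0.5 * x + p


def stochastic_fn(p, x, key):
    return expected_fn(p, x) + 0.01 * jax.random.normal(key, jnp.shape(x))


key = jax.random.PRNGKey(0)
solve = lambda p: stochastic_fixed_point(stochastic_fn, expected_fn, p, jnp.zeros(()), key)
x_star = solve(1.0)
dx_dp = jax.grad(solve)(1.0)
```

Because the forward result is the last stochastic iterate, x_star fluctuates around the expected map's fixed point, while the gradient is computed from the noise-free expected map.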
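For the last bullet, one way to guarantee that a trajectory and a fixed point are computed from the same inputs is to drive both from one lax.scan over the same pre-split PRNG keys, so the final state of the trajectory is exactly the fixed-point iterate. A minimal sketch; iterate, fixed_point, and trajectory are hypothetical names, and a fixed iteration count stands in for a real convergence test:

```python
import jax
import jax.numpy as jnp


def iterate(stochastic_fn, params, x0, key, num_steps):
    """Run num_steps iterations of stochastic_fn(params, x, key), recording every state."""
    keys = jax.random.split(key, num_steps)

    def step(x, k):
        x_next = stochastic_fn(params, x, k)
        return x_next, x_next  # (carry, per-step output)

    return jax.lax.scan(step, x0, keys)


def fixed_point(stochastic_fn, params, x0, key, num_steps):
    x_last, _ = iterate(stochastic_fn, params, x0, key, num_steps)
    return x_last


def trajectory(stochastic_fn, params, x0, key, num_steps):
    _, xs = iterate(stochastic_fn, params, x0, key, num_steps)
    return xs


def stochastic_fn(p, x, key):
    # Toy map: noisy contraction toward 2 * p.
    return 0.5 * x + p + 0.01 * jax.random.normal(key, jnp.shape(x))


key = jax.random.PRNGKey(0)
xs = trajectory(stochastic_fn, 1.0, jnp.zeros(()), key, 50)
x_fp = fixed_point(stochastic_fn, 1.0, jnp.zeros(()), key, 50)
```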

Just as an example (you're welcome to use any code that's useful to you), I ended up implementing a version of your two-phase solver with the above features.

The user-facing code is here. The combinator is here. I borrowed your tests.
And here I show how it's used to write pretty code in an object-oriented setting.

I did use classes, but after all, the param_func in fax is essentially an object. Unfortunately, many of the examples use closures rather than classes, which makes the closed-over values practically inaccessible and complicates debugging.
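The point about classes versus closures can be illustrated with a small iteration object registered as a PyTree: its array fields are leaves (so it can be passed to a jitted function as an ordinary argument and differentiated), configuration lives in static auxiliary data, and every field stays inspectable while debugging. AffineIteration and solve are hypothetical names for this sketch, which uses the affine map x -> a * x + b:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class AffineIteration:
    a: jax.Array  # leaf
    b: jax.Array  # leaf
    num_steps: int = 100  # static: stored in aux data, so changing it triggers recompilation

    def tree_flatten(self):
        return (self.a, self.b), (self.num_steps,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    def __call__(self, x):
        return self.a * x + self.b


@jax.jit
def solve(iteration, x0):
    # num_steps is a concrete Python int here because it comes from the aux data.
    return jax.lax.fori_loop(0, iteration.num_steps, lambda _, x: iteration(x), x0)


it = AffineIteration(a=jnp.array(0.5), b=jnp.array(1.0))
x_star = solve(it, jnp.array(0.0))  # approximately b / (1 - a) = 2.0
grads = jax.grad(lambda i: solve(i, jnp.array(0.0)))(it)  # grads.b approximately 1 / (1 - a) = 2.0
```

Unlike values captured in a closure, it.a and it.b can be inspected directly, and grads comes back as an AffineIteration with the same static num_steps.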
