The main downside here is that it would break some of the core assumptions of the APIs that use the pytree abstraction within JAX. Namely, the assumption that all relevant dynamic content is contained in the children, and that the …

The other thing I'd be worried about here would be leaked tracers. The approach is probably fine so long as pytree flattening/unflattening always happens in a paired fashion at a …
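To make the leaked-tracer concern concrete, here is a minimal sketch. It assumes a hypothetical `Param` class whose flatten function passes the object itself as aux data and whose unflatten function writes the new leaf back into that same object. `jax.jit` unflattens its arguments with tracers while tracing, and nothing unflattens them again afterwards, so the tracer stays stored on the caller's object:

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical stateful parameter (illustrative, not a JAX API)."""

    def __init__(self, value):
        self.value = value


def _flatten(p):
    # Single child: the array. aux_data: the Param object itself.
    return (p.value,), p


def _unflatten(p, children):
    # Mutate and return the original object instead of building a copy.
    (p.value,) = children
    return p


jax.tree_util.register_pytree_node(Param, _flatten, _unflatten)


@jax.jit
def double(p):
    # While tracing, jit has unflattened the argument with a tracer,
    # so the caller's object now holds that tracer in `p.value`.
    return p.value * 2


a = Param(jnp.array(1.0))
out = double(a)
print(out)                     # 2.0
print(type(a.value).__name__)  # a tracer class, not a concrete array
```

If `double` returned `p` as well, the output unflatten would write a concrete array back into `a.value`, pairing the two operations.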
Hello everyone,
In the past few years I've been writing a lot of JAX code, and I'm currently working on developing a more object-oriented interface to JAX (basically an abstraction layer that keeps track of all the tensors through each JAX transformation). Among the many approaches I could think of, the following seems the easiest. However, it goes against the core JAX idea that flattening/unflattening creates a copy of an object (i.e., a = unflatten(flatten(b)) => a is not b). So my question is: what problems may arise with such an approach?
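For reference, a conventional custom pytree registration behaves as described: unflattening builds a fresh object, while the leaf arrays themselves are passed through unchanged. A minimal sketch with a hypothetical `Param` class:

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical parameter registered as an ordinary pytree node."""

    def __init__(self, value):
        self.value = value


jax.tree_util.register_pytree_node(
    Param,
    lambda p: ((p.value,), None),            # (children, aux_data)
    lambda aux, children: Param(*children),  # builds a new instance every time
)

b = Param(jnp.array(1.0))
leaves, treedef = jax.tree_util.tree_flatten(b)
a = jax.tree_util.tree_unflatten(treedef, leaves)
print(a is b)              # False: a new Param was constructed
print(a.value is b.value)  # True: the leaf array object is shared
```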
In particular, you can look at the following (rather simplified) example:
Note how a.value is changed without having to reassign it.
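A minimal sketch of this kind of pattern, assuming a hypothetical `Param` class (names are illustrative): flattening passes the object itself as aux data (plain objects hash and compare by identity, so this is valid aux data), and unflattening writes the new leaf back into it. As a result, `jax.tree_util.tree_map` returns the very same object with an updated `value`:

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical stateful parameter whose unflatten reuses the same object."""

    def __init__(self, value):
        self.value = value


def _flatten(p):
    # The array is the only child; the object itself travels as aux_data.
    return (p.value,), p


def _unflatten(p, children):
    # Write the new leaf back into the original object instead of copying it.
    (p.value,) = children
    return p


jax.tree_util.register_pytree_node(Param, _flatten, _unflatten)

a = Param(jnp.array(1.0))
b = jax.tree_util.tree_map(lambda x: x * 2, a)
print(b is a)   # True
print(a.value)  # 2.0, updated without reassigning `a`
```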
I believe that this approach, besides making everything stateful and object-oriented, also reduces the overhead of jitting: flattening/unflattening the function's arguments and return values doesn't need to traverse a complex model, only the list of its parameters (since the model itself can simply be captured in the closure of the jitted function).
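A sketch of the closure idea, assuming a hypothetical in-place `Param` and a hypothetical `Linear` model that is not itself a pytree. The model is captured by the jitted function's closure, only the list of `Param`s is flattened, and returning that list makes the output unflatten write concrete arrays back into the same objects after the call:

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical in-place Param: unflatten reuses the original object."""

    def __init__(self, value):
        self.value = value


def _flatten(p):
    return (p.value,), p


def _unflatten(p, children):
    (p.value,) = children
    return p


jax.tree_util.register_pytree_node(Param, _flatten, _unflatten)


class Linear:
    """Hypothetical model; only its Params are pytrees."""

    def __init__(self):
        self.w = Param(jnp.array(2.0))
        self.b = Param(jnp.array(0.5))

    def params(self):
        return [self.w, self.b]

    def __call__(self, x):
        return self.w.value * x + self.b.value


model = Linear()


@jax.jit
def step(params, x):
    # `model` is reached through the closure; only `params` is flattened.
    y = model(x)
    for p in params:
        p.value = p.value * 0.5  # illustrative in-place update
    # Returning `params` pairs the input unflatten with an output unflatten
    # that stores concrete arrays in the same Param objects.
    return params, y


_, y = step(model.params(), jnp.array(3.0))
print(y)              # 6.5
print(model.w.value)  # 1.0
```

Because the aux data is the object itself, treedefs compare by object identity, so passing a different `Param` instance should change jit's cache key and trigger a retrace.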
Of course, if someone using this approach actually wants to create a copy of a Param via flattening/unflattening, I could simply provide a function for it, so that's not a problem. I'm mostly interested in clashes with the internal workings of JAX and in possible jitting optimizations that I may be missing.
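One possible shape for such a copy helper, assuming a hypothetical in-place `Param` (redefined here so the block stands alone) and a hypothetical helper name `clone`. Each `Param` is treated as a leaf, so the surrounding containers are rebuilt by the ordinary unflatten, and every `Param` is replaced with a fresh instance. Sharing the underlying array is fine because JAX arrays are immutable:

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical in-place Param: unflatten reuses the original object."""

    def __init__(self, value):
        self.value = value


def _flatten(p):
    return (p.value,), p


def _unflatten(p, children):
    (p.value,) = children
    return p


jax.tree_util.register_pytree_node(Param, _flatten, _unflatten)


def clone(tree):
    """Hypothetical helper: copy a pytree, giving every Param a new identity."""
    is_param = lambda x: isinstance(x, Param)
    # Stop flattening at Params so they appear as leaves themselves.
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf=is_param)
    fresh = [Param(x.value) if is_param(x) else x for x in leaves]
    # Containers (lists, dicts, ...) are rebuilt by the ordinary unflatten.
    return jax.tree_util.tree_unflatten(treedef, fresh)


params = {"w": Param(jnp.array(2.0)), "b": Param(jnp.array(0.5))}
copied = clone(params)
copied["w"].value = jnp.array(10.0)
print(copied["w"] is params["w"])  # False
print(params["w"].value)           # 2.0, the original is untouched
```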
Thank you for your time. I'm happy to have a deep discussion about this :)