Pulling an intermediate variable outside a function #16572
Unanswered
astanziola
asked this question in
Ideas
Replies: 1 comment 1 reply
-
I believe Oryx implements this idea (see here), but when I tried to install it with pip, it looked like it may not be compatible with the most recent JAX version.
-
One feature that would be quite nice to have is the ability to retrieve an intermediate variable from a jitted function, along with its results, without needing to modify the function signature. This could be especially useful for nested functions.
My thoughts are along these lines: take, for instance, the functions
I would like to have some method, perhaps called `pull_variable`, that could extract a `named_variable` like so:
Now, I would be able to execute:
where `variables = {"qux": y}`.
While I don't see a reason why this shouldn't be feasible, I'm curious if there's a way to achieve this in JAX currently?
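A minimal sketch of the behaviour described above, for concreteness. Everything here is an assumption made for illustration: `named_variable` and `pull_variable` are hypothetical helpers (not part of JAX), and `foo` and `bar` are placeholder functions. The sketch collects tagged values in a Python-side dict while the function is traced, so it only works when `jax.jit` is applied outside the wrapper; a tag placed inside an inner `jit`, `lax.scan`, or `vmap` would leak a tracer.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers, not part of JAX: a stack of dicts that collect
# tagged values while a wrapped function is being traced.
_collectors = []


def named_variable(name, value):
    """Tag `value` under `name`; acts as the identity on the value itself."""
    if _collectors:
        _collectors[-1][name] = value
    return value


def pull_variable(fn):
    """Wrap `fn` so that it returns (result, dict of tagged values)."""
    def wrapped(*args, **kwargs):
        collected = {}
        _collectors.append(collected)
        try:
            result = fn(*args, **kwargs)
        finally:
            _collectors.pop()
        return result, collected
    return wrapped


def foo(x):
    y = named_variable("qux", jnp.sin(x))  # intermediate value to expose
    return 2.0 * y


def bar(x):
    return foo(x) + 1.0  # nested call; signatures of foo and bar are unchanged


# jit is applied outside the wrapper, so the tagged tracers become outputs.
result, variables = jax.jit(pull_variable(bar))(1.0)
```

Here `variables` is a dict with the single key `"qux"`, holding the value of `y` computed inside `foo`, and `result` is the ordinary output of `bar`.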
If not, this could indeed be a great feature, and I can foresee quite a few potential use cases for it, such as:
On a tangential note, there may be some complementary arguments for having a `push_variable` decorator that modifies the input signature of a function and allows injecting a value in place of a variable in a computational graph, but that's partially beyond the scope of this question 😄
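A separate, self-contained sketch of what such an injection wrapper might look like. Again, `push_variable` and `named_variable` are hypothetical names rather than JAX APIs, `foo` is a placeholder, and the approach assumes the override is looked up on the Python side at trace time, with `jax.jit` applied outside the wrapper.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers, not part of JAX: a stack of {name: value} overrides
# that is consulted while the wrapped function is being traced.
_overrides = []


def named_variable(name, value):
    """Return the injected value for `name` if one is active, else `value`."""
    if _overrides and name in _overrides[-1]:
        return _overrides[-1][name]
    return value


def push_variable(fn, name):
    """Wrap `fn` so its first argument replaces the variable tagged `name`."""
    def wrapped(injected, *args, **kwargs):
        _overrides.append({name: injected})
        try:
            return fn(*args, **kwargs)
        finally:
            _overrides.pop()
    return wrapped


def foo(x):
    y = named_variable("qux", jnp.sin(x))  # candidate for injection
    return 2.0 * y


# The wrapped function takes the injected value first, then foo's own inputs.
# With 0.5 injected for "qux", the output is 2.0 * 0.5 regardless of x.
result = jax.jit(push_variable(foo, "qux"))(0.5, 1.0)
```

The wrapper prepends one argument to the signature instead of rewriting the function body, which keeps `foo` itself untouched.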