Why aren't jitted functions JAX objects? #6664
AdrienCorenflos asked this question in Ideas
Does this do what you want?

```python
from functools import partial

from jax import jit


@jit
def f(x):
    return x


# static_argnums=0 makes `h` a compile-time constant (it must be hashable)
@partial(jit, static_argnums=0)
def g(h, y):
    return h(y)


g(f, 5.)
```
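A minimal sketch expanding on the `static_argnums` approach above, assuming a recent JAX release. `double`, `triple`, `g_traced` and `trace_count` are hypothetical names for illustration only. It shows two things: passing a jitted function as an ordinary (traced) argument is rejected with a `TypeError`, and because a static argument is hashed into the compilation cache key, each distinct function passed as `h` triggers a fresh trace and compile, while repeated calls with the same function are cache hits.

```python
from functools import partial

import jax


# Hypothetical example functions (not from the thread), both jitted.
@jax.jit
def double(x):
    return x * 2.0


@jax.jit
def triple(x):
    return x * 3.0


# One entry per *trace* of `g`: Python side effects like this run only
# while JAX traces the function, not on cached compiled calls.
trace_count = []


@partial(jax.jit, static_argnums=0)
def g(h, y):
    trace_count.append(h)
    return h(y)


# Same body, but without static_argnums: JAX tries to treat `h` as an
# array input and rejects it.
@jax.jit
def g_traced(h, y):
    return h(y)


out_a = g(double, 5.0)  # first call with `double`: traced and compiled
out_b = g(double, 6.0)  # same static `h`, same input shape/dtype: cache hit
out_c = g(triple, 5.0)  # new static `h`: traced and compiled again

try:
    g_traced(double, 5.0)
    traced_error = None
except TypeError as err:  # a function is not a valid JAX input type
    traced_error = err
```

The trade-off is that every new callable (for example a freshly created lambda) is a new cache key, so passing many distinct functions this way can cause repeated recompilation.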
Hi,
There are a number of cases where I have wanted to pass a jitted function as an argument to another function, but this is not possible (yet), contrary to, say, TensorFlow. Does this have something to do with possible closure problems or gradient tracing?
Could this be considered as a future improvement, or is it a no-go?
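On the closure/gradient-tracing question, one alternative to `static_argnums` is `jax.tree_util.Partial`, a `functools.partial` subclass registered as a pytree, so it can be passed to a jitted function as a regular argument. The wrapped callable becomes static tree metadata, while the bound arguments remain traced array leaves, which is what lets `jax.grad` differentiate through them. A minimal sketch, assuming a recent JAX release; `scaled_sin`, `apply` and `loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


# Hypothetical jitted function; `scale` is bound later through Partial.
@jax.jit
def scaled_sin(scale, x):
    return scale * jnp.sin(x)


@jax.jit
def apply(fn, x):
    # `fn` is a pytree: the wrapped callable is static tree metadata,
    # while the bound arguments are traced array leaves.
    return fn(x)


# A Partial can be passed as an ordinary (non-static) argument.
y = apply(Partial(scaled_sin, 2.0), 0.0)


def loss(scale):
    # `scale` is a traced value here; it rides inside the Partial as a
    # leaf, so jax.grad can differentiate through the jitted `apply`.
    return apply(Partial(scaled_sin, scale), jnp.pi / 2)


# d/dscale [scale * sin(pi/2)] = sin(pi/2), i.e. approximately 1
dscale = jax.grad(loss)(2.0)
```

Unlike the `static_argnums` route, changing only the bound arguments does not change the cache key; only a different wrapped callable does.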
To fix ideas, I'm talking about doing something like this: