Concretize the value of an array within a JIT context #7986
Replies: 1 comment 1 reply
-
Great question! There is a way, though it's currently an internal API and we need to make it public (in the sense of: currently you have to import it from …).

Well, there's almost a way. Your exact code as written doesn't work because by the time you call …

```python
from functools import partial
import jax
import jax.numpy as jnp

@jax.named_call  # add this so it shows up in the jaxpr
def doublify(a):
  result = 2 * a
  return result

def get_array_of_twos():
  a = jnp.ones(10)
  return doublify(a)

print(jax.make_jaxpr(get_array_of_twos)())
```
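To make the failure concrete, here is a minimal sketch, assuming a recent JAX release. The helper `needs_concrete_value` is hypothetical: it builds an array with `jnp.ones` and then asks for one of its elements as a Python `float`. Called eagerly this works, but under `jax.jit` the `jnp.ones` call is staged out, so `a` is a tracer and `float()` raises `jax.errors.ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def needs_concrete_value():
  # Hypothetical helper for illustration.
  a = jnp.ones(10)        # under jit this is staged out, so `a` is a tracer
  return 2 * float(a[0])  # float() needs a concrete value


# Eagerly, `a` is a real array and float() succeeds.
eager_value = needs_concrete_value()


@jax.jit
def jitted():
  return needs_concrete_value()


try:
  jitted()
  concretized = True
except jax.errors.ConcretizationTypeError:
  # Raised while tracing: the tracer has no concrete value to convert.
  concretized = False

print("eager value:", eager_value)
print("concrete value available under jit:", concretized)
```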
The way to tell JAX, "make sure to evaluate this stuff at trace/compile time, or error if you can't!" is to use the `jax.core.eval_context()` context manager:

```python
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

@jax.named_call  # add this so it shows up in the jaxpr
def doublify(a):
  with jax.core.eval_context():
    result = 2 * a
    print(f'result is {result}')
  return result

def get_array_of_twos():
  with jax.core.eval_context():
    a = jnp.ones(10)
  return doublify(a)

get_array_of_twos()
```
Notice I had to use it twice, once in each function.

Maybe you can try using …
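For reference, later JAX releases document a public context manager for exactly this, `jax.ensure_compile_time_eval()`; the sketch below assumes it behaves like the internal `jax.core.eval_context()` used above. It applies the same two-step pattern under `jax.jit`. The `seen` list is a hypothetical addition that records the value observed at trace time. Note that the context manager is needed in `get_array_of_twos` too: if `a` were already a tracer, evaluating `2 * a` "eagerly" inside `doublify` could not produce a concrete value.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical: values observed at trace time, for illustration


def doublify(a):
  # Evaluate eagerly during tracing. This yields a concrete array only
  # because `a` is itself concrete (see get_array_of_twos below).
  with jax.ensure_compile_time_eval():
    result = 2 * a
    seen.append(float(result[0]))  # works: `result` is not a tracer
    print(f"result is {result}")
  return result


@jax.jit
def get_array_of_twos():
  with jax.ensure_compile_time_eval():
    a = jnp.ones(10)  # concrete array instead of a staged-out tracer
  return doublify(a)


out = get_array_of_twos()
print(out)
```

The `print` inside `doublify` runs once, at trace time, and shows the actual values rather than an abstract tracer; later calls with the same (empty) signature reuse the compiled function and do not print again.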
-
Is there a way to evaluate an expression within a JIT context?
For example, could the variable `result` inside `doublify` be computed eagerly to a `DeviceArray`?
Thanks in advance!