Create a concrete value during tracing #7767
Hi, is there a way to create concrete arrays while inside a traced (`jit`) call?

```python
import jax.numpy as jnp
from jax import jit

@jit
def f():
    x = jnp.array([1.0, 2.0])
    # I want a concrete x because I only want to use it during tracing
```

Here is why I need that. Let's say there are two cases, depending on whether a function is symmetric or antisymmetric, and I have a function that checks this numerically:

```python
def parity(f):
    a = jnp.linspace(-1.0, 1.0, 21)
    if jnp.allclose(f(a), f(-a)):
        return 'symmetric'
    if jnp.allclose(f(a), -f(-a)):
        return 'antisymmetric'
    return 'undefined'
```

The problem is that if I use

How to force
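Here is a minimal reproduction of the failure, assuming `parity` is called from inside a `jax.jit`-compiled function; the caller `apply` is hypothetical. Eagerly, every value is concrete and `parity` works. Under `jit`, `jnp.linspace` returns a tracer, so the Python `if` cannot convert `jnp.allclose(...)` to a `bool`. JAX then raises a `ConcretizationTypeError`, and recent versions raise its subclass `TracerBoolConversionError`.

```python
import jax
import jax.numpy as jnp


def parity(f):
    a = jnp.linspace(-1.0, 1.0, 21)
    if jnp.allclose(f(a), f(-a)):
        return 'symmetric'
    if jnp.allclose(f(a), -f(-a)):
        return 'antisymmetric'
    return 'undefined'


@jax.jit
def apply(x):
    # Hypothetical caller: Python control flow on the parity of a fixed function.
    if parity(jnp.cos) == 'symmetric':
        return jnp.cos(x)
    return jnp.sin(x)


# Outside of jit every value is concrete, so this works.
eager_result = parity(jnp.cos)

# Under jit, `a` and everything computed from it is a tracer, so the `if` fails.
try:
    apply(0.0)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print(eager_result, failed)
```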
I dream of something like a context manager:

```python
def parity(f):
    with jax.concrete():  # hypothetical API: the wished-for context manager
        a = jnp.linspace(-1.0, 1.0, 21)
        if jnp.allclose(f(a), f(-a)):
            return 'symmetric'
        if jnp.allclose(f(a), -f(-a)):
            return 'antisymmetric'
        return 'undefined'
```
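Recent JAX releases provide a context manager close to this wish, `jax.ensure_compile_time_eval()`. Inside it, operations whose inputs are all concrete are evaluated eagerly, even while an enclosing `jax.jit` is tracing. The sketch below assumes an installed JAX version that has it, and `apply` is a hypothetical caller. It only helps when `f` does not depend on traced arguments. If `f` closed over a jitted argument, its output would still be a tracer.

```python
import jax
import jax.numpy as jnp


def parity(f):
    # Evaluate eagerly at trace time, as long as no input is a tracer.
    with jax.ensure_compile_time_eval():
        a = jnp.linspace(-1.0, 1.0, 21)
        if jnp.allclose(f(a), f(-a)):
            return 'symmetric'
        if jnp.allclose(f(a), -f(-a)):
            return 'antisymmetric'
        return 'undefined'


@jax.jit
def apply(x):
    # parity() yields a plain Python string during tracing,
    # so it can drive ordinary Python control flow.
    if parity(jnp.cos) == 'symmetric':
        return jnp.cos(x)
    return jnp.sin(x)


print(parity(jnp.sin), apply(0.0))
```

Because the branch is decided while tracing, it is resolved once per compilation rather than on every call.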
The easiest way to make a value (or an operation) concrete is to use the `np` version rather than the `jnp` version:

```python
import numpy as np

def parity(f):
    a = np.linspace(-1.0, 1.0, 21)
    if np.allclose(f(a), f(-a)):
        return 'symmetric'
    if np.allclose(f(a), -f(-a)):
        return 'antisymmetric'
    return 'undefined'
```
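Here is a sketch of that approach used inside a jitted function, with `apply` again a hypothetical caller. The catch, as I understand JAX's tracing, is that `f` must also be NumPy code. NumPy runs in Python at trace time and only ever sees concrete arrays. By contrast, calling `parity(jnp.cos)` inside `apply` would typically produce tracers and make `np.allclose` fail.

```python
import jax
import jax.numpy as jnp
import numpy as np


def parity(f):
    # Pure NumPy: runs in Python at trace time and is never staged into XLA.
    a = np.linspace(-1.0, 1.0, 21)
    if np.allclose(f(a), f(-a)):
        return 'symmetric'
    if np.allclose(f(a), -f(-a)):
        return 'antisymmetric'
    return 'undefined'


@jax.jit
def apply(x):
    # np.cos acts on concrete NumPy arrays, so parity() returns a plain str here.
    if parity(np.cos) == 'symmetric':
        return jnp.cos(x)
    return jnp.sin(x)


print(parity(np.sin), apply(0.0))
```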