```python
import jax
import jax.numpy as jnp

# 1. Python's bitwise `&`
@jax.jit
def compare_in_range(x,y):
    return (x>=y[0]) & (x<y[1])
print('python bitwise and ',compare_in_range(1,(1,3)))

# 2. jnp.bitwise_and
@jax.jit
def compare_in_range(x,y):
    return jnp.bitwise_and( (x>=y[0]) , (x<y[1]))
print('jnp bitwise and ',compare_in_range(1,(1,3)))

# 3. jnp.logical_and
@jax.jit
def compare_in_range(x,y):
    return jnp.logical_and( (x>=y[0]) , (x<y[1]))
print('jnp logical and ',compare_in_range(1,(1,3)))

# 4. Python's logical `and` (this is the one that fails)
@jax.jit
def compare_in_range(x,y):
    return (x>=y[0]) and (x<y[1])
print('python logical and' , compare_in_range(1,(1,3)))
```

Returns:

```
python bitwise and True
jnp bitwise and True
jnp logical and True
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-15-7383f0047c04> in compare_in_range(x, y)
     24 @jax.jit
     25 def compare_in_range(x,y):
---> 26     return (x>=y[0]) and (x<y[1])
     27 
     28 print('python logical and' , compare_in_range(1,(1,3)))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function compare_in_range at <ipython-input-15-7383f0047c04>:24 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'x' and 'y'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

My question is: is this intended behavior?
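The traceback names `bool` as the culprit. Below is a minimal sketch that isolates that single step, assuming a JAX version that exports `jax.errors.ConcretizationTypeError` (the helper names `bool_of_tracer_fails` and `in_range` are hypothetical). Calling `bool()` on a traced comparison raises, while the `&` form traces cleanly and also works element-wise on arrays.

```python
import jax
import jax.numpy as jnp


def bool_of_tracer_fails() -> bool:
    """Return True if calling bool() on a traced comparison raises under jit."""

    @jax.jit
    def first_step_of_and(x):
        # `(x >= 1) and ...` begins by evaluating bool(x >= 1); do only that.
        return bool(x >= 1)

    try:
        first_step_of_and(1)
    except jax.errors.ConcretizationTypeError:
        # Newer JAX releases raise TracerBoolConversionError, a subclass of
        # ConcretizationTypeError, so this clause covers both.
        return True
    return False


@jax.jit
def in_range(x, lo, hi):
    # `&` goes through the array's __and__, so it traces fine and works
    # element-wise on arrays as well as on scalars.
    return (x >= lo) & (x < hi)


print(bool_of_tracer_fails())  # True
print(in_range(jnp.arange(5), 1, 3))
```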
Hi, thanks for the question! This is expected behavior. Python's `and` operator needs a concrete truth value: evaluating `a and b` first calls `bool(a)` to decide whether to short-circuit, which is the `bool` call named in your traceback. It is therefore incompatible with traced values inside JAX transforms such as `jax.jit`, where the inputs are abstract rather than concrete. Moreover, unlike the `&` operator, which can be overloaded via the `__and__` method, Python offers no way to directly overload the behavior of its `and` operator. For these reasons, there is no way for JAX (or other array libraries like `numpy`) to dispatch element-wise array operations using Python's `and`. The alternative, which you've already discovered, is to use `&` (or `jnp.logical_and` / `jnp.bitwise_and`).

For more information on JAX transforms, tracers, and concrete values, see https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
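To make the dispatch difference concrete, here is a plain-Python sketch; the `Recorder` class is a hypothetical toy, not part of JAX or NumPy. `a & b` reaches `a.__and__(b)`, so a library can return whatever it likes (for example an element-wise array), while `a and b` only calls `bool(a)` and then returns one of the operands unchanged. The NumPy lines at the end show the same limitation in a real array library.

```python
import numpy as np


class Recorder:
    """Toy operand that records which special method Python invokes on it."""

    def __init__(self, name, calls):
        self.name = name
        self.calls = calls

    def __and__(self, other):
        # `a & b` is routed to a.__and__(b); an array library can return
        # an element-wise result from here.
        self.calls.append(f"{self.name}.__and__")
        return "element-wise result"

    def __bool__(self):
        # `a and b` can only ask `a` for a single truth value; there is no
        # special method that lets `a` customize `and` itself.
        self.calls.append(f"{self.name}.__bool__")
        return True


amp_calls, and_calls = [], []
amp_result = Recorder("a", amp_calls) & Recorder("b", amp_calls)
b = Recorder("b", and_calls)
and_result = Recorder("a", and_calls) and b

print(amp_result, amp_calls)       # element-wise result ['a.__and__']
print(and_result is b, and_calls)  # True ['a.__bool__']

# NumPy hits the same wall: `&` is element-wise, `and` needs one truth value.
x = np.array([0, 1, 2, 3])
print((x >= 1) & (x < 3))  # [False  True  True False]
try:
    (x >= 1) and (x < 3)
except ValueError as err:
    print("ValueError:", err)
```

Under `jax.jit` the `bool(a)` step cannot succeed at all, because `a` is an abstract tracer; in plain NumPy it fails only for arrays with more than one element.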