How to convert a ShapedArray to scalar when it is known that the array is 0-dim? #7339
Consider the function below:

```python
from jax import jit
import jax.numpy as jnp

def shrink(x, threshold=1.):
    z = jnp.zeros_like(x)
    return jnp.maximum(z, x - threshold)

shrink_jit = jit(shrink, static_argnums=(1,))
```

It has one static argument. It works great with simple usage:

```python
x = jnp.array([1., 0.5, 0.4, 3.6, -3, 1.5])
print(shrink(x))
print(shrink_jit(x))
```

I wish to use it in a function where the threshold is computed dynamically through other JAX array operations. The following works:

```python
def f1():
    threshold = jnp.sqrt(jnp.vdot(x, x)) / 5
    return shrink(x, threshold)

print(f1())
print(jit(f1)())
```

However, the following doesn't work:

```python
def f2():
    threshold = jnp.sqrt(jnp.vdot(x, x))
    return shrink_jit(x, threshold)

# TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
print(f2())
```

I could make this work by using the `item()` method:

```python
def f3():
    threshold = jnp.sqrt(jnp.vdot(x, x))
    return shrink_jit(x, threshold.item())

print(f3())
```

But this creates a different problem if I try to jit compile `f3`:

```python
# AttributeError: 'ShapedArray' object has no attribute 'item'
print(jit(f3)())
```

It looks like …
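Some context on the `TypeError` in `f2`: `jit` hashes static arguments to look up cached compilations. JAX arrays, like NumPy arrays, are not hashable. Outside of `jit`, a concrete 0-d array can be converted to a hashable Python scalar, which is why `f3` works when it is called directly.

The check below is a minimal sketch and assumes a recent JAX version. The error that `jit` itself raises for unhashable static arguments differs in type and wording between versions, so the sketch calls `hash` directly instead. The small array `t` is a stand-in chosen for illustration:

```python
import jax.numpy as jnp

# A concrete 0-d JAX array: sqrt(vdot(ones, ones)) == sqrt(4) == 2.0,
# computed eagerly (outside of jit), so it holds a real value.
t = jnp.sqrt(jnp.vdot(jnp.ones(4), jnp.ones(4)))

# JAX arrays are unhashable, so they cannot serve as static arguments.
try:
    hash(t)
    array_is_hashable = True
except TypeError:
    array_is_hashable = False

# Because the value is concrete here, it can be converted to a Python
# scalar, which is hashable and therefore usable as a static argument.
t_item = t.item()
t_float = float(t)

print(array_is_hashable, t_item, t_float)
```

Inside `jit` the same conversion is impossible, because `t` would be an abstract tracer with no value attached.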
During tracing, JAX evaluates the function without using concrete values for the arguments (unless they are marked as static). JAX operations inside the function also do not return concrete values. They instead return tracer objects of types like `DynamicJaxprTracer`, which carry the information needed for tracing. Even if an `item`-equivalent method were implemented, it would still return a tracer object with no concrete value associated with it.

Without a concrete value, `threshold` is not very useful as a static argument. Static arguments are supposed to have concrete values so that they can be used for additional optimizations. If the value of a static argument is not known then th…

Fundamentally this can be resolved only if …

In conclusion, when it comes to jit interacting with the static arguments of called functions, I see two cases:

- **The potential static argument depends on the function's input values.** Here a static argument is not very useful, and it should probably be made non-static.
- **The potential static argument is independent of the inputs.** Here you can evaluate the argument in a separate function (which could itself be jitted) and then pass the result as a static argument to the other function.
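The two cases above can be sketched as follows. This is a minimal sketch that assumes a recent JAX version. The names `shrink_traced`, `f2_traced`, `compute_threshold`, `shrink_static`, and `result_static` are hypothetical and were introduced only for illustration. The threshold uses the `/ 5` scaling from `f1` so that the output is not all zeros.

```python
import jax.numpy as jnp
from jax import jit


def shrink(x, threshold=1.0):
    # Elementwise max(0, x - threshold), as in the question.
    z = jnp.zeros_like(x)
    return jnp.maximum(z, x - threshold)


x = jnp.array([1.0, 0.5, 0.4, 3.6, -3.0, 1.5])

# Case 1 (hypothetical names): the threshold depends on the traced input,
# so keep it as an ordinary (non-static) argument that jit traces.
shrink_traced = jit(shrink)


@jit
def f2_traced(x):
    threshold = jnp.sqrt(jnp.vdot(x, x)) / 5  # a tracer inside jit
    return shrink_traced(x, threshold)         # fine: threshold is not static


# Case 2 (hypothetical names): the threshold does not depend on anything
# being traced, so compute it separately, turn it into a concrete Python
# float (hashable), and only then pass it as a static argument.
shrink_static = jit(shrink, static_argnums=(1,))


@jit
def compute_threshold(v):
    return jnp.sqrt(jnp.vdot(v, v)) / 5


threshold = float(compute_threshold(x))  # concrete value, outside any trace
result_static = shrink_static(x, threshold)

print(f2_traced(x))
print(result_static)
```

Static arguments are part of the compilation cache key. In the second case, every new value of `threshold` therefore triggers a fresh compilation of `shrink_static`. When the threshold changes often, the first (traced) case is usually the better fit.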