Applying a scalar-input root-finding function to elements of an array - using vmap passes the entire array, rather than each scalar value.
#18975
-
I'm currently working with this JAX snippet, which is intended to map over a 1D array element by element: depending on whether each element satisfies a certain criterion function, it either returns the element unchanged or runs a root-finding function:

```python
rootFinder = lambda M: brentq(f, a=constant, b=M)  # this line is simplified
mapFn = lambda M: jax.lax.cond(f(M), lambda: M, lambda: rootFinder(M))
return jax.vmap(mapFn)(arr)
```

Both …
-
Hi,

According to the JAX docs, when `lax.cond` is transformed with `vmap` over a batch of predicates, it is converted to `select`, so both branches are evaluated for every element. As you mentioned, using a loop can be a solution. Here is a quick comparison between using `vmap` and `scan`:

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

# make a somewhat expensive function
def dummy_fn(m):
    return jnp.sum(jnp.ones((1000, )))

# dummy predicate
def f(m):
    return m > 0

map_fn = lambda m: jax.lax.cond(f(m), lambda: m, lambda: dummy_fn(m))
```

Let's say the 1D array is generated as:

```python
# since this is randomly generated, the predicate is True for roughly 50% of the elements
m = jrandom.normal(key=jrandom.PRNGKey(0), shape=(1000,))
```

Timing `vmap` with `%%timeit`:

```python
%%timeit
jax.vmap(map_fn)(m)
# 6.62 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

Now, using `scan`:

```python
# scan function for jax.lax.scan
def body_fn(current, x):
    ret = map_fn(x)
    return ret, ret
```

```python
%%timeit
jax.lax.scan(body_fn, m[0], m)
# 740 µs ± 252 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

You can check out a Google Colab here. In this case, …
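As a rough cross-check of the comparison above, here is a sketch that wraps each approach in `jax.jit` and calls `.block_until_ready()`; because JAX dispatches work asynchronously, un-jitted timings can mix in tracing and dispatch overhead. It also uses `jax.make_jaxpr` to look at what each version traces to: under `vmap` the batched `cond` should show up as `select_n` (both branches computed), while the `scan` body should keep a real `cond`. The function names `vmap_map`, `scan_map` and `lax_map` are hypothetical, and `jax.lax.map` (a sequential, scan-based map) is included only as a shorthand for the same loop. No timings are claimed here.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

def dummy_fn(m):
    # stand-in for an expensive computation, as in the reply above
    return jnp.sum(jnp.ones((1000,)))

def f(m):
    # dummy predicate
    return m > 0

map_fn = lambda m: jax.lax.cond(f(m), lambda: m, lambda: dummy_fn(m))

def vmap_map(xs):
    # the predicate is batched here, so cond is turned into a select
    return jax.vmap(map_fn)(xs)

def scan_map(xs):
    # one element per step: the predicate is a scalar, so only one branch runs
    def body_fn(carry, x):
        return carry, map_fn(x)
    _, out = jax.lax.scan(body_fn, None, xs)
    return out

def lax_map(xs):
    # jax.lax.map applies map_fn sequentially, like the scan above
    return jax.lax.map(map_fn, xs)

with_vmap = jax.jit(vmap_map)
with_scan = jax.jit(scan_map)
with_lax_map = jax.jit(lax_map)

m = jrandom.normal(key=jrandom.PRNGKey(0), shape=(1000,))

# block_until_ready() waits for the asynchronous computation to finish,
# which is what a timing loop should measure
out_vmap = with_vmap(m).block_until_ready()
out_scan = with_scan(m).block_until_ready()
out_map = with_lax_map(m).block_until_ready()

# Inspect the traced programs (tracing only, no execution)
vmap_jaxpr = str(jax.make_jaxpr(vmap_map)(m))
scan_jaxpr = str(jax.make_jaxpr(scan_map)(m))
print("select_n" in vmap_jaxpr, "cond" in scan_jaxpr)
```

In a notebook, `%timeit with_scan(m).block_until_ready()` would then time only the compiled call, since the first call above has already triggered compilation.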
-
Note that in …

```python
import jax

def f(x):
    print("shape =", x.shape)
    print(x)
    return x

x = jax.numpy.arange(10)
jax.vmap(f)(x)
# shape = ()
# Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
#   val = Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
#   batch_dim = 0
```

So I suspect that in your case you were actually seeing a scalar value, but the repr was confusing that.
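To check the per-element shape without relying on the tracer's repr, a small sketch can record `x.shape` while `vmap` traces the function. The list `seen_shapes` and the function `g` are hypothetical names used only for illustration; the function body runs once, at trace time, and sees a tracer for a single element.

```python
import jax
import jax.numpy as jnp

seen_shapes = []  # filled in while vmap traces g

def g(x):
    # Runs once at trace time; under vmap, x stands for one element
    seen_shapes.append(x.shape)
    return x * 2

x = jnp.arange(10)
out = jax.vmap(g)(x)

print(seen_shapes)  # [()]
print(out.shape)    # (10,)
```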
-
I managed to determine the cause of the issue: when JAX checks the return type of a callback, it tries to compare `shape`s, but a float doesn't have a `shape` attribute. So the fix was to apply `jnp.array` to my returned value inside a proxy call to `rootFinder`. After doing this, `scan` works just fine, and I can even wrap the entire call in a `while_loop`; the performance is almost identical to the list comprehension solution I was using.

Thanks for the assistance regarding `vmap` vs `scan`, @anh-tong!
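The fix described above can be sketched as follows, assuming the host-side root finder is called through `jax.pure_callback`. To keep the sketch self-contained, `host_bisect` is a hypothetical plain-Python bisection standing in for `brentq` (like `brentq`, it returns a plain float), and the equation `x**2 = M` with a pass-through branch for `M <= 1` is a made-up example, not the original problem. The important line converts the callback's float into an array with `np.asarray(..., dtype=np.float32)` so it has the `shape` and `dtype` declared via `jax.ShapeDtypeStruct`; the reply above used `jnp.array` for the same purpose.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_bisect(g, lo, hi, tol=1e-10, max_iter=200):
    # Hypothetical stand-in for a host-side solver such as brentq;
    # like brentq, it returns a plain Python float.
    g_lo = g(lo)
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        g_mid = g(mid)
        if (g_mid > 0) == (g_lo > 0):
            lo, g_lo = mid, g_mid  # sign unchanged: root lies in [mid, hi]
        else:
            hi = mid  # sign changed: root lies in [lo, mid]
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)

def _root_callback(M):
    # M arrives on the host as a NumPy array with shape ()
    M = float(M)
    root = host_bisect(lambda x: x * x - M, 0.0, M)
    # Wrap the float so it has the shape/dtype declared in root_finder
    return np.asarray(root, dtype=np.float32)

def root_finder(M):
    # Hypothetical "proxy" around the host solver
    result_type = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(_root_callback, result_type, M)

def map_fn(M):
    # Inside scan, M is a scalar, so cond runs only the selected branch
    return jax.lax.cond(M <= 1.0, lambda: M, lambda: root_finder(M))

def map_with_scan(arr):
    def body(carry, M):
        return carry, map_fn(M)
    _, out = jax.lax.scan(body, None, arr)
    return out

arr = jnp.array([0.5, 4.0, 9.0], dtype=jnp.float32)
result = map_with_scan(arr)
print(result)  # approximately [0.5, 2.0, 3.0]
```

Because the host solver only runs when `M > 1`, elements that take the pass-through branch never trigger a callback, which is the behavior the `vmap` version loses.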