Implementing rejection sampling with vmap #11219
I am trying to implement rejection sampling in JAX for an arbitrary function ('apply_fun') that is parameterized by 'params', and then vmap it over different sets of parameters.
However, this fails either with a concretization error when I use a Python while loop, or with a different error when I use lax.while_loop.
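For context, the concretization error in the Python-loop case can be reproduced with a much smaller function. The sketch below is not the code from the question: `count_up` is a hypothetical stand-in whose loop condition depends on a vmapped argument, which is the same pattern as looping until enough samples have been accepted.

```python
import jax
import jax.numpy as jnp


def count_up(limit):
    # Hypothetical stand-in: the loop condition depends on a traced value.
    i = 0
    while i < limit:  # Python needs a concrete bool here, but `limit` is a tracer
        i += 1
    return i


try:
    jax.vmap(count_up)(jnp.array([1, 2, 3]))
    raised = False
except jax.errors.ConcretizationTypeError:
    # Tracer-to-bool conversion errors are reported as concretization errors.
    raised = True

print(raised)
```

Under vmap the value of `limit` is abstract during tracing, so Python cannot evaluate the `while` condition. This is why a data-dependent loop generally has to be written with `jax.lax.while_loop` instead.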
Hmmm, I think your code is fine...

def sample_fun(rng, params, num_samples, ymax):
    def rejection_sample(args):
        rng, all_x, i = args
        rng, split_rng = jax.random.split(rng)
        x = jax.random.uniform(split_rng, minval=0, maxval=1, shape=(1,))
        rng, split_rng = jax.random.split(rng)
        y = jax.random.uniform(split_rng, minval=0, maxval=ymax, shape=(1,))
        passed = (y < apply_fun(params, x)).astype(bool)
        all_x = all_x.at[i].add((passed * x)[0])
        i = i + passed[0]  # fix here
        return rng, all_x, i
    all_x = np.zeros(num_samples)
    _, all_x, _ = jax.lax.while_loop(lambda i: i[2] < num_samples, rejection_sample, (rng, all_x, 0))
    return all_x
Unfortunately, it doesn't work. I added a small dummy function to make a minimal, self-contained example that reproduces the error (based on your code, for readability).
This produces the aforementioned error.
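For comparison, here is a minimal sketch of one way to write the sampler so that it can be vmapped over keys and parameters. It is not the code from this thread and has not been checked against the error reported above. The density `apply_fun` is a hypothetical dummy, assumed to be bounded above by `ymax` on [0, 1]. The sketch draws scalar candidates, keeps the counter as an explicit `int32`, and writes accepted values with `jnp.where` so that the carry has the same shape and dtype on every iteration.

```python
import jax
import jax.numpy as jnp


def apply_fun(params, x):
    # Hypothetical dummy density on [0, 1]; its maximum is 1 for params >= 0.
    return jnp.exp(-params * x)


def sample_fun(rng, params, num_samples, ymax):
    def cond_fun(carry):
        _, _, count = carry
        return count < num_samples

    def body_fun(carry):
        rng, samples, count = carry
        rng, x_rng, y_rng = jax.random.split(rng, 3)
        x = jax.random.uniform(x_rng, minval=0.0, maxval=1.0)
        y = jax.random.uniform(y_rng, minval=0.0, maxval=ymax)
        accepted = y < apply_fun(params, x)
        # Store x in slot `count` only if the candidate was accepted.
        samples = jnp.where(accepted, samples.at[count].set(x), samples)
        count = count + accepted.astype(jnp.int32)
        return rng, samples, count

    init = (rng, jnp.zeros(num_samples), jnp.int32(0))
    _, samples, _ = jax.lax.while_loop(cond_fun, body_fun, init)
    return samples


# Usage: one key and one parameter value per batch element.
rngs = jax.random.split(jax.random.PRNGKey(0), 3)
params_batch = jnp.array([0.5, 1.0, 2.0])
samples = jax.vmap(lambda r, p: sample_fun(r, p, 8, 1.0))(rngs, params_batch)
print(samples.shape)
```

Under vmap, the loop keeps running until every batch element has collected `num_samples` values, so the slowest element sets the running time. If acceptance rates differ a lot between parameter sets, drawing a fixed-size block of candidates per iteration may be a better fit.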