Skip to content
Discussion options

You must be logged in to vote

Hmm, I think your code is mostly fine; see the line marked `# fix here` below:

def sample_fun(rng, params, num_samples, ymax):
         def rejection_sample(args):
            """Run one iteration of rejection sampling on [0, 1).

            Draws a candidate ``x ~ U[0, 1)`` and a height ``y ~ U[0, ymax)``,
            and accepts ``x`` when ``y < apply_fun(params, x)``. Presumably
            used as the ``body_fun`` of ``jax.lax.while_loop`` -- TODO confirm
            against the (truncated) call site.

            Args:
                args: loop carry ``(rng, all_x, i)``: a PRNG key, the buffer
                    that accumulates accepted samples, and the index of the
                    next slot to fill.

            Returns:
                The updated carry ``(rng, all_x, i)``. ``i`` advances by one
                only when the candidate is accepted.
            """
            rng, all_x, i= args
            # Split the key before each draw so x and y use independent keys.
            rng, split_rng = jax.random.split(rng)
            x = jax.random.uniform(split_rng, minval=0, maxval=1, shape=(1,))
            rng, split_rng = jax.random.split(rng)
            y = jax.random.uniform(split_rng, minval=0, maxval=ymax, shape=(1,))

            # With y ~ U[0, ymax), x is accepted with probability
            # apply_fun(params, x) / ymax. apply_fun, params and ymax come
            # from the enclosing scope.
            # NOTE(review): this is only a valid rejection sampler if
            # 0 <= apply_fun(params, x) <= ymax on [0, 1] -- confirm.
            passed = (y < apply_fun(params, x)).astype(bool)
            # Branch-free write: adds x on accept and 0 on reject. Because i is
            # not advanced on reject, slot i stays at its initial value (the
            # caller zero-fills all_x) until an accepted draw lands in it.
            all_x = all_x.at[i].add((passed * x)[0])
            i = i + passed[0] # fix here: advance the slot index only on acceptance
            return rng, all_x, i
         all_x = np.zeros(num_samples)
         _, all_x, _ = jax.lax.while_loop(lambda i: i[2] < 

Replies: 2 comments 4 replies

Comment options

You must be logged in to vote
0 replies
Answer selected by Binbose
Comment options

You must be logged in to vote
4 replies
@YouJiacheng
Comment options

@Binbose
Comment options

@YouJiacheng
Comment options

@Binbose
Comment options

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants