A silly question about JAX: when using JAX to generate an MCMC trajectory, one needs an RNG key for each step, right? How should that be handled?

```python
def mcmc(params, keys):
    for i in range(20):
        x = ..... + eps * random.normal(key_i, shape)
    return x
```

Do I need to pass the list of keys into this function?
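Passing a list of keys is not required. You can pass a single key and split off a fresh subkey at each step with `jax.random.split`; pre-splitting with `random.split(key, 20)` and indexing the result also works. The following is a minimal sketch under assumptions: the `.....` term from the question is unknown, so the sketch uses a plain random-walk update on the current state, and the names `x0`, `eps`, and `n_steps` are hypothetical.

```python
import jax.numpy as jnp
from jax import random


def mcmc(x0, key, eps=0.1, n_steps=20):
    """Random-walk trajectory driven by a single PRNG key.

    Hypothetical stand-in for the question's update: x <- x + eps * noise.
    The question's '.....' term is replaced by the current state x.
    """
    x = x0
    traj = [x]
    for _ in range(n_steps):
        # Split off a fresh subkey for this step; a key is never reused.
        key, subkey = random.split(key)
        x = x + eps * random.normal(subkey, x.shape)
        traj.append(x)
    # Shape: (n_steps + 1,) + x0.shape, including the starting point.
    return jnp.stack(traj)


traj = mcmc(jnp.zeros(3), random.PRNGKey(0))
```

Because every subkey is derived from the one key passed in, calling `mcmc` again with the same key reproduces the trajectory exactly.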
Answered by JiahaoYao, Jul 3, 2021
Okay, solved.
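The thread does not show what the solution was. One common pattern, sketched here under assumptions (a hypothetical standard-normal target and step size), is random-walk Metropolis inside `jax.lax.scan`: generate one key per step up front with `jax.random.split(key, n_steps)`, scan over those keys, and split each step key once more, because a Metropolis step needs two independent draws (the proposal and the accept/reject uniform).

```python
import jax
import jax.numpy as jnp
from jax import random


def log_prob(x):
    # Hypothetical target: standard normal, up to an additive constant.
    return -0.5 * jnp.sum(x ** 2)


def rw_metropolis(key, x0, eps=0.5, n_steps=1000):
    """Random-walk Metropolis; returns the chain and per-step accept flags."""

    def step(x, step_key):
        # Two independent draws per step, so split the step key again.
        prop_key, accept_key = random.split(step_key)
        proposal = x + eps * random.normal(prop_key, x.shape)
        log_ratio = log_prob(proposal) - log_prob(x)
        accept = jnp.log(random.uniform(accept_key)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, (x_new, accept)

    # One key per step, created up front and scanned over as the xs input.
    step_keys = random.split(key, n_steps)
    _, (chain, accepted) = jax.lax.scan(step, x0, step_keys)
    return chain, accepted


# n_steps sets array shapes, so it must be static under jit.
rw_metropolis_jit = jax.jit(rw_metropolis, static_argnames="n_steps")

key = random.PRNGKey(0)
chain, accepted = rw_metropolis_jit(key, jnp.zeros(2), n_steps=500)
```

Using `lax.scan` instead of a Python loop keeps the whole chain inside one compiled loop, and since all step keys come from a single seed, the same `key` always yields the same chain.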
Answer selected by JiahaoYao