Random number generation inside a while loop #18729
Answered by jakevdp
LindoNkambule asked this question in Q&A
Hi all, I am trying to generate random numbers inside
Any help or pointers to relevant documentation would be appreciated. Thanks!
Answer from jakevdp, Nov 29, 2023:
It might look something like this:

```python
import jax

def cond_fun(carry):
  # Keep looping while the first draw is larger than the second.
  _, val1, val2 = carry
  return val1 > val2

def body_fun(carry):
  # Split the carried key, use the subkeys for this iteration's draws,
  # and return the new key so the next iteration gets fresh randomness.
  key, _, _ = carry
  key, subkey1, subkey2 = jax.random.split(key, 3)
  val1 = jax.random.uniform(subkey1)
  val2 = jax.random.uniform(subkey2)
  return key, val1, val2

# Float initial values match the dtype of the values body_fun returns.
key, val1, val2 = jax.lax.while_loop(cond_fun, body_fun, (jax.random.PRNGKey(0), 1.0, 0.0))
print(val1, val2)
# 0.7184181 0.80332327
```

This is the equivalent of the following Python `while` loop:

```python
key = jax.random.PRNGKey(0)
val1 = 1.0
val2 = 0.0
while val1 > val2:
  key, subkey1, subkey2 = jax.random.split(key, 3)
  val1 = jax.random.uniform(subkey1)
  val2 = jax.random.uniform(subkey2)
print(val1, val2)
```
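One practical reason to prefer `jax.lax.while_loop` over the plain Python loop is `jax.jit`. The sketch below wraps the `lax.while_loop` version in a jitted function (the name `sample_until_ordered` is hypothetical) and shows that tracing the Python version fails: `while val1 > val2` needs a concrete boolean, but under `jit` the values are abstract tracers. It assumes a recent JAX, where that failure is a `TracerBoolConversionError`; the sketch catches it through its base class `jax.errors.ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp

def cond_fun(carry):
  # Keep looping while the first draw is larger than the second.
  _, val1, val2 = carry
  return val1 > val2

def body_fun(carry):
  # Split off fresh subkeys each iteration and carry the new key forward.
  key, _, _ = carry
  key, subkey1, subkey2 = jax.random.split(key, 3)
  val1 = jax.random.uniform(subkey1)
  val2 = jax.random.uniform(subkey2)
  return key, val1, val2

@jax.jit
def sample_until_ordered(key):
  # Hypothetical wrapper: the key is an argument, so one compiled
  # function can be reused with different keys.
  init = (key, jnp.float32(1.0), jnp.float32(0.0))
  _, val1, val2 = jax.lax.while_loop(cond_fun, body_fun, init)
  return val1, val2

def python_loop(key):
  # Plain Python `while`: fine when run eagerly, but the loop condition
  # must be a concrete Python bool on every iteration.
  val1, val2 = 1.0, 0.0
  while val1 > val2:
    key, subkey1, subkey2 = jax.random.split(key, 3)
    val1 = jax.random.uniform(subkey1)
    val2 = jax.random.uniform(subkey2)
  return val1, val2

key = jax.random.PRNGKey(0)
print(sample_until_ordered(key))

try:
  jax.jit(python_loop)(key)
except jax.errors.ConcretizationTypeError as err:
  # Tracing reaches `while <tracer>` on the second check and raises.
  print("Python while loop cannot be traced:", type(err).__name__)
```

The eager `python_loop(key)` call still works; only tracing it under `jit` fails.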
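A common alternative to splitting the carried key on every iteration is to keep one base key fixed and derive each iteration's key with `jax.random.fold_in` and a loop counter. This is a different technique from the one in the answer, sketched here with a hypothetical `sample_with_counter` function. It draws from a different random stream, so its values will not match the output above; as a side benefit, the counter also reports how many iterations ran.

```python
import jax
import jax.numpy as jnp

def sample_with_counter(key):
  # Hypothetical variant: the base `key` is never updated; each iteration
  # derives its own key by folding in the iteration counter `i`.
  def cond_fun(carry):
    _, val1, val2 = carry
    return val1 > val2

  def body_fun(carry):
    i, _, _ = carry
    step_key = jax.random.fold_in(key, i)
    subkey1, subkey2 = jax.random.split(step_key)
    val1 = jax.random.uniform(subkey1)
    val2 = jax.random.uniform(subkey2)
    return i + 1, val1, val2

  # Carry is (iteration count, val1, val2) with fixed dtypes.
  init = (jnp.int32(0), jnp.float32(1.0), jnp.float32(0.0))
  num_iters, val1, val2 = jax.lax.while_loop(cond_fun, body_fun, init)
  return num_iters, val1, val2

num_iters, val1, val2 = jax.jit(sample_with_counter)(jax.random.PRNGKey(0))
print(int(num_iters), val1, val2)
```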
Answer selected by LindoNkambule.