It might look something like this:

import jax

def cond_fun(carry):
  _, val1, val2 = carry
  return val1 > val2

def body_fun(carry):
  key, _, _ = carry
  key, subkey1, subkey2 = jax.random.split(key, 3)
  val1 = jax.random.uniform(subkey1)
  val2 = jax.random.uniform(subkey2)
  return key, val1, val2

key, val1, val2 = jax.lax.while_loop(cond_fun, body_fun, (jax.random.PRNGKey(0), 1.0, 0.0))
print(val1, val2)
# 0.7184181 0.80332327

This is equivalent to the following Python while loop:

key = jax.random.PRNGKey(0)
val1 = 1
val2 = 0
while val1 > val2:
  key, subkey1, subkey2 = jax.random.split(key, 3)
  val1 = jax.random.uniform(subkey1)
  val2 = jax.random.uniform(subkey2)
print(val1, val2)
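
If it helps, here is a rough sketch of the same loop wrapped in a function so it can be passed to jax.jit, with an iteration counter added to the carry. The name sample_until_ordered and the counter are my own additions for illustration, not something from the question, and the explicit float32/int32 initial values assume JAX's default 32-bit mode.

```python
import jax
import jax.numpy as jnp

def sample_until_ordered(key):
  # Hypothetical helper: redraw (val1, val2) until val1 <= val2.
  def cond_fun(carry):
    _, val1, val2, _ = carry
    return val1 > val2  # keep looping while val1 is larger

  def body_fun(carry):
    key, _, _, count = carry
    # Split off fresh subkeys each iteration so every draw is independent.
    key, subkey1, subkey2 = jax.random.split(key, 3)
    val1 = jax.random.uniform(subkey1)
    val2 = jax.random.uniform(subkey2)
    return key, val1, val2, count + 1

  # The initial carry has the same structure and dtypes as body_fun's output.
  init = (key, jnp.float32(1.0), jnp.float32(0.0), jnp.int32(0))
  _, val1, val2, count = jax.lax.while_loop(cond_fun, body_fun, init)
  return val1, val2, count

val1, val2, count = jax.jit(sample_until_ordered)(jax.random.PRNGKey(0))
print(val1, val2, count)
```

Because the initial values satisfy val1 > val2, the body always runs at least once, and the loop exits on the first draw where val1 <= val2.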

Answer selected by LindoNkambule