
First of all, you can simplify your body function like this if you wish:

def body1(state):
    # state is an (index, arr) tuple; return a tuple of the same structure with arr[index] incremented
    index, arr = state
    return (index, arr.at[index].add(1.0))
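Here is a minimal sketch of how a body like `body1` fits into `lax.while_loop`. It assumes the loop state is an `(index, arr)` tuple, as in the question. The `cond1` function and its threshold of `3.0` are hypothetical stand-ins for your actual stopping condition. The only requirement is that the body returns a carry with the same structure and dtypes it received.

```python
import jax.numpy as jnp
from jax import lax

def body1(state):
    # Unpack the (index, arr) carry and return a new tuple with the same structure.
    index, arr = state
    return (index, arr.at[index].add(1.0))

def cond1(state):
    # Hypothetical stopping rule: keep looping while the selected element is below 3.0.
    index, arr = state
    return arr[index] < 3.0

init = (0, jnp.zeros(4))
index, arr = lax.while_loop(cond1, body1, init)
print(arr)  # [3. 0. 0. 0.]
```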

Second, it sounds like what you have in mind is a vmap operation. You can achieve this by wrapping your entire while loop in an outer function and then applying vmap to that function. It might look something like this:

from jax import jit, lax, vmap
import jax.numpy as jnp

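def f(arr, index, up_bound):
  # Keep looping while the selected element is still below the bound.
  def cond(arr):
    return arr[index] < up_bound
  # Each iteration increments only arr[index]; all other entries are unchanged.
  def body(arr):
    return arr.at[index].add(1)
  return lax.while_loop(cond_fun=cond, body_fun=body, init_val=arr)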

a = jnp.ones(5)
index = 2
up_bound = 5.0
print(f(a, index, up_bound))
# [1. 1. 5. 1. 1.]

b = jnp.stack((a,a,a
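The batching step that the snippet above begins can be sketched as follows. This sketch assumes you want to run the same loop independently on each row of a batch, with `index` and `up_bound` shared across rows. That is why it uses `in_axes=(0, None, None)`, which is an assumption here and not necessarily what the original (truncated) code does. The batch `b` uses made-up starting values so the rows differ. When the loop predicate is batched, JAX keeps iterating until every row's condition is false and leaves rows that have already finished unchanged. Each row therefore ends with a 5 at position 2.

```python
import jax.numpy as jnp
from jax import lax, vmap

def f(arr, index, up_bound):
  # Loop until arr[index] reaches up_bound, incrementing only that entry.
  def cond(arr):
    return arr[index] < up_bound
  def body(arr):
    return arr.at[index].add(1)
  return lax.while_loop(cond_fun=cond, body_fun=body, init_val=arr)

# Illustrative batch: three rows with different starting values.
b = jnp.stack((jnp.ones(5), 2 * jnp.ones(5), 3 * jnp.ones(5)))

# Map over axis 0 of arr; index and up_bound are broadcast to every row.
batched_f = vmap(f, in_axes=(0, None, None))
out = batched_f(b, 2, 5.0)
print(out)
# [[1. 1. 5. 1. 1.]
#  [2. 2. 5. 2. 2.]
#  [3. 3. 5. 3. 3.]]
```

If each row needs its own index, you could batch that argument too, for example `vmap(f, in_axes=(0, 0, None))(b, jnp.array([0, 2, 4]), 5.0)`.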

Answer selected by abhiroop513