How to use lax.while_loop for array of logicals. #13868
I have an array `a` of ones and the following `while_loop`, which adds 1.0 to the 3rd element (index 2) on every iteration until it equals 5.0:

```python
from jax import jit, lax
import jax.numpy as jnp

a = jnp.ones(5)

@jit
def cond1(state):
    index, arr = state
    return arr[index] < up_bound

@jit
def body1(state):
    index, arr = state
    temp = jnp.add(arr[index], 1.0)
    arr2 = arr.at[index].set(temp)
    return (index, arr2)

index = 2
up_bound = 5.0
state = tuple([index, a])
index, arr = lax.while_loop(cond_fun=cond1, body_fun=body1, init_val=state)
print(arr)
```

The answer is `[1. 1. 5. 1. 1.]`.
But I need the operation to be done on a stack of such arrays:

```python
>>> b = jnp.stack((a, a, a))
>>> b
DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32)
```

During the while loop, the 3rd element of row 1, the 4th element of row 2 and the 5th element of row 3 should each be incremented by 1 until they reach different upper bounds. That is, the inputs should be:

```python
index = [2, 3, 4]
up_bound = [5.0, 6.0, 7.0]
```

And ultimately my desired result should be:

```
[[1. 1. 5. 1. 1.]
 [1. 1. 1. 6. 1.]
 [1. 1. 1. 1. 7.]]
```

How should I do that?
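As a mental model for the single-row loop in the question, `lax.while_loop` has the semantics of the plain-Python loop below (the JAX documentation describes it this way), although the real primitive is traced and compiled rather than executed eagerly. This sketch uses no JAX at all; `while_loop_reference`, `cond_ref` and `body_ref` are hypothetical names, and `up_bound` is carried in the state here rather than read from a global.

```python
# Plain-Python model of lax.while_loop semantics (hypothetical reference sketch;
# the real lax.while_loop is traced and compiled, not run like this).
def while_loop_reference(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

def cond_ref(state):
    index, arr, up_bound = state
    return arr[index] < up_bound

def body_ref(state):
    index, arr, up_bound = state
    arr = list(arr)      # copy instead of mutating, mirroring JAX's immutable arrays
    arr[index] += 1.0    # same update as arr.at[index].add(1.0)
    return (index, arr, up_bound)

index, arr, _ = while_loop_reference(cond_ref, body_ref, (2, [1.0] * 5, 5.0))
print(arr)  # [1.0, 1.0, 5.0, 1.0, 1.0]
```

The carry is passed through unchanged in structure on every iteration, which is why `lax.while_loop` requires the body to return a value with the same shapes and dtypes as `init_val`.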
First of all, you can simplify your body function like this if you wish:

```python
def body1(state):
    index, arr = state
    # .at[index].add(...) replaces the separate read, add and .set(...) steps
    return (index, arr.at[index].add(1.0))
```

Second, it sounds like what you have in mind is a `vmap` operation, which you can achieve by wrapping your entire while loop in an outer function. It might look something like this:

```python
from jax import jit, lax, vmap
import jax.numpy as jnp

def f(arr, index, up_bound):
    # cond and body close over this row's index and up_bound,
    # so the loop carry is just the 1-D array itself.
    def cond(arr):
        return arr[index] < up_bound

    def body(arr):
        return arr.at[index].add(1)

    return lax.while_loop(cond_fun=cond, body_fun=body, init_val=arr)

a = jnp.ones(5)
index = 2
up_bound = 5.0
print(f(a, index, up_bound))
# [1. 1. 5. 1. 1.]

# vmap maps f over the leading axis of all three arguments (one row per call).
b = jnp.stack((a, a, a))
index = jnp.array([2, 3, 4])
up_bound = jnp.array([5.0, 6.0, 7.0])
print(vmap(f)(b, index, up_bound))
# [[1. 1. 5. 1. 1.]
#  [1. 1. 1. 6. 1.]
#  [1. 1. 1. 1. 7.]]
```
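For comparison, here is an alternative sketch (not from the thread) that uses a single `lax.while_loop` over the whole stack, driven by a boolean mask, which ties in with the "array of logicals" in the title. It is conceptually similar to what `vmap` does when the loop condition is batched: the loop keeps running while any row is unfinished, and finished rows are left unchanged. `masked_increment` is a hypothetical name, and the sketch assumes every row starts below its bound and that the bounds are reachable in steps of 1.0.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def masked_increment(b, index, up_bound):
    rows = jnp.arange(b.shape[0])

    def active(arr):
        # Boolean mask: True for each row whose target element is still below its bound.
        return arr[rows, index] < up_bound

    def cond(arr):
        # Keep looping while at least one row still needs work.
        return jnp.any(active(arr))

    def body(arr):
        # Add 1.0 only where the mask is True; finished rows get +0.0.
        step = jnp.where(active(arr), 1.0, 0.0)
        return arr.at[rows, index].add(step)

    return lax.while_loop(cond, body, b)

b = jnp.ones((3, 5))
index = jnp.array([2, 3, 4])
up_bound = jnp.array([5.0, 6.0, 7.0])
print(masked_increment(b, index, up_bound))
# [[1. 1. 5. 1. 1.]
#  [1. 1. 1. 6. 1.]
#  [1. 1. 1. 1. 7.]]
```

The trade-off is that every iteration touches all rows until the slowest one finishes; the `vmap` version above computes the same result and keeps the per-row logic simpler to read.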