Fori_loop array indexing issue #12432
-
Hi, I am trying to implement mini-batch gradient descent in JAX; however, I run into an indexing error inside the loop body function. The way I compute the batch indices is by relying on the loop counter:

```python
import jax
import jax.numpy as jnp
import optax
import numpy as np


def lin_model(params, x):
    return jnp.dot(params, x)


def loss_mse(params, x, y, model):
    yh = model(params, x)
    mse = jnp.mean((yh - y)**2, axis=1)
    return jnp.sqrt(jnp.sum(mse**2))


def fit_adam_sgd(params, x, y, model, loss_fn, epochs, lr, batch_size):
    def update(params, x, y, opt_state):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y, model)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return loss, updates, grads, opt_state

    def batch_step(i, carry):
        batches = carry[2]
        # slicing with the loop counter i is where the error occurs
        loss, updates, grads, opt_state = update(carry[0],
                                                 x[..., i * batches:(i + 1) * batches],
                                                 y[..., i * batches:(i + 1) * batches],
                                                 carry[1])
        carry[0] = carry[0] + updates
        carry[1] = opt_state
        return carry

    def epoch_step(i, carry):
        carry = jax.lax.fori_loop(0, carry[2], batch_step, carry)
        return carry

    optimizer = optax.adam(learning_rate=lr)
    opt_state = optimizer.init(params)
    batches = int(x.shape[-1] / batch_size)
    carry = [params, opt_state, batches]
    res = jax.lax.fori_loop(0, epochs, epoch_step, carry)
    return res[0]


np.random.seed(40)
n = 100
x = np.random.rand(3, n)
x[2, :] = 1.0
y = (2 + x[0, :] + x[1, :]).reshape(1, -1)
params = np.random.rand(y.shape[0], x.shape[0])

jit_fit_adam_sgd = jax.jit(fit_adam_sgd, static_argnums=(3, 4),
                           static_argnames=('epochs', 'batch_size'))
w = jit_fit_adam_sgd(params, x, y, lin_model, loss_mse,
                     epochs=10000, lr=0.04, batch_size=10)
print(loss_mse(w, x, y, lin_model))
```

In words, I have two loops for mini-batch GD: the outer loop runs over the number of epochs (considered static), the inner over the batches (here kept in a fixed order, although in general it is easy to shuffle the observations). Let us ignore that the current implementation leaves some observations out of training when `x.shape[-1] mod batch_size != 0`. I am trying to learn a simple linear regression.

Final note: if I replace the epoch step with the version below (thereby implementing batch GD), it works, which I take as a proof of correctness.

```python
def epoch_step2(i, carry):
    loss, updates, grads, opt_state = update(carry[0], x, y, carry[1])
    carry[0] = carry[0] + updates
    carry[1] = opt_state
    return carry
```

How can I get around the indexing error above? Thanks!
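For reference, here is a stripped-down sketch that appears to hit the same failure in isolation (the toy array and shapes are made up for illustration and are unrelated to the regression above):

```python
import jax
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(3, 4)

def body(i, acc):
    # NumPy-style slicing needs static start/stop values, but i is traced inside fori_loop
    return acc + jnp.sum(data[:, i * 2:(i + 1) * 2])

# This raises at trace time because the slice bounds depend on the traced counter i
total = jax.lax.fori_loop(0, 2, body, jnp.float32(0.0))
```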
-
The counter `i` is in fact traced: the way that `fori_loop` works is to trace/compile the body function once in order to determine its behavior for abstract values of `i`, and then run that compiled code in sequence for every value of `i`. I suspect you could do what you want by replacing the numpy-style slicing with a call to `lax.dynamic_slice`. Here's an example of how the two APIs compare:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(240).reshape(2, 3, 40)
batch_size = 4
i = 1

out_static = x[..., i * batch_size: (i + 1) * batch_size]  # all indices must be static
print(out_static)
# [[[  4   5   6   7]
#   [ 44  45  46  47]
#   [ 84  85  86  87]]
#  [[124 125 126 127]
#   [164 165 166 167]
#   [204 205 206 207]]]

out_dynamic = lax.dynamic_slice(
    x,
    start_indices=(0, 0, i * batch_size),     # Note: start_indices may be dynamic
    slice_sizes=(*x.shape[:-1], batch_size))  # slice sizes must be static
print(out_dynamic)
# [[[  4   5   6   7]
#   [ 44  45  46  47]
#   [ 84  85  86  87]]
#  [[124 125 126 127]
#   [164 165 166 167]
#   [204 205 206 207]]]
```
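Following up on that suggestion, here is a self-contained sketch (toy data, no optimizer; all names and shapes here are made up) of the same per-batch slicing done inside `fori_loop` with `lax.dynamic_slice`. The same pattern should drop into the `batch_step` from the question, provided the slice width comes from the static Python int `batch_size` in `fit_adam_sgd`'s scope rather than from the traced value stored in the carry:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy stand-ins for the x, y in the question: shapes (features, n_obs) and (1, n_obs)
x = jnp.arange(30.0).reshape(3, 10)
y = jnp.arange(10.0).reshape(1, 10)
batch_size = 5                          # static Python int, required for slice_sizes
n_batches = x.shape[-1] // batch_size   # 2

def batch_step(i, total):
    # The start index may depend on the traced counter i,
    # but slice_sizes must be static.
    xb = lax.dynamic_slice(x, start_indices=(0, i * batch_size),
                           slice_sizes=(x.shape[0], batch_size))
    yb = lax.dynamic_slice(y, start_indices=(0, i * batch_size),
                           slice_sizes=(y.shape[0], batch_size))
    return total + jnp.sum(xb) + jnp.sum(yb)  # placeholder for the optimizer update

total = jax.lax.fori_loop(0, n_batches, batch_step, jnp.float32(0.0))
print(total)  # sums over all of x and y: 435.0 + 45.0 = 480.0
```

One thing to keep in mind: unlike NumPy slicing, `dynamic_slice` clamps the start index so the slice stays in bounds, so a final partial batch would silently shift backwards rather than shrink.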