The counter i is in fact traced: fori_loop traces and compiles the body function once to determine its behavior for an abstract value of i, and then runs that compiled code in sequence for every value of i. Because i is traced, it can't be used where a static value is required, such as the bounds of a numpy-style slice. I suspect you could do what you want by replacing the numpy-style slicing with a call to lax.dynamic_slice, which accepts traced start indices as long as the slice sizes are static. Here's an example of how the two APIs compare:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(240).reshape(2, 3, 40)

batch_size = 4
i = 1

out_static = x[..., i * batch_size: (i + 1) * batch_size]  # all indices must be static
print(out_static)
# [[[  4   5   6   7]
#   [ 44  45  46  47]
#   [ 84  85  86  87]]

#  [[124 125 126 127]
#   [164…
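
For the fori_loop case specifically, here's a rough sketch of how lax.dynamic_slice might be used inside the body. The names body, num_batches, and total are made up for illustration, and summing each batch is just a stand-in for whatever your loop actually computes. The key point is that only the start index depends on the traced counter i, while the slice sizes stay static Python ints:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(240).reshape(2, 3, 40)
batch_size = 4  # static Python int: slice sizes must be known at trace time
num_batches = x.shape[-1] // batch_size  # hypothetical helper name

def body(i, total):
    # i is a traced value here, so it may only appear in the start indices.
    start = i * batch_size
    batch = lax.dynamic_slice(
        x,
        (0, 0, start),                          # start indices (may be traced)
        (x.shape[0], x.shape[1], batch_size),   # slice sizes (must be static)
    )
    # Placeholder computation: accumulate the sum of each batch.
    return total + batch.sum()

total = lax.fori_loop(0, num_batches, body, jnp.zeros((), dtype=x.dtype))
print(total)
```

Since the batches tile the last axis exactly, total should come out equal to x.sum(), which makes for an easy sanity check.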

Answer selected by guglielmogattiglio