
I can't think of a way to do this sort of operation with vmap. If you map over the shapes, each mapped call has to sum over a dynamically sized slice, and even though your inputs are static, any argument you pass through vmap becomes a traced (dynamic) value.
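To show why vmap gets in the way, here is a minimal sketch. The helper segment_sum and the example data are hypothetical and not taken from the original question. Summing x[start:end] works when start and end are Python ints, but under jax.vmap they become traced values, so the slice no longer has a length that is known at trace time.

```python
import jax
import jax.numpy as jnp

def segment_sum(x, start, end):
    # Hypothetical helper: sum of the half-open segment x[start:end].
    # The slice has length end - start, so JAX needs concrete (static) bounds.
    return x[start:end].sum()

x = jnp.arange(10.0)

# With plain Python ints the bounds are static, so this works: 2 + 3 + 4 = 9.0
static_result = segment_sum(x, 2, 5)

# Under vmap, start and end are tracers, so the slice length is unknown at trace time.
starts = jnp.array([0, 2])
ends = jnp.array([3, 5])
try:
    jax.vmap(segment_sum, in_axes=(None, 0, 0))(x, starts, ends)
    vmap_failed = False
except Exception as err:  # the exact exception type can differ between JAX versions
    vmap_failed = True
    print(type(err).__name__)
```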

One way to make this more efficient is to use a cumulative-sum trick that eliminates the inner loop. I believe this function is equivalent to yours:

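import jax.numpy as jnp

def slice(x, shapes):
  # Each `shape` is a list of (start, end) pairs; each pair becomes sum(x[start:end]).
  for shape in shapes:
    ij = jnp.array(shape)
    # Prefix sums with a leading zero, so that x_cuml[k] == x[:k].sum()
    x_cuml = jnp.zeros(len(x) + 1, x.dtype).at[1:].set(x.cumsum())
    # Segment sums as differences of prefix sums
    x = x_cuml[ij[:, 1]] - x_cuml[ij[:, 0]]
  return x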
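A quick way to check the claimed equivalence is to compare against a naive loop. This is a sketch under assumptions: slice_naive is a hypothetical stand-in for the original function (it assumes each shape is a list of half-open (start, end) pairs whose segment sums form the next level), and the input data are made up. The trick relies on the identity sum(x[s:e]) == C[e] - C[s], where C is the cumulative sum of x padded with a leading zero.

```python
import jax.numpy as jnp

def slice_naive(x, shapes):
    # Hypothetical reference: explicitly sum each half-open segment x[start:end].
    for shape in shapes:
        x = jnp.array([x[start:end].sum() for start, end in shape])
    return x

def slice_cumsum(x, shapes):
    # The prefix-sum version from the answer above.
    for shape in shapes:
        ij = jnp.array(shape)
        x_cuml = jnp.zeros(len(x) + 1, x.dtype).at[1:].set(x.cumsum())
        x = x_cuml[ij[:, 1]] - x_cuml[ij[:, 0]]
    return x

x = jnp.arange(8, dtype=jnp.float32)
# Two levels; (0, 3) and (2, 5) overlap on purpose.
shapes = [[(0, 3), (2, 5), (5, 8)], [(0, 2), (2, 3)]]

naive = slice_naive(x, shapes)
fast = slice_cumsum(x, shapes)
print(fast)
```

The prefix-sum version replaces one reduction per segment with a single cumsum and two gathers per level, and it does not need the segments to be disjoint. With floating-point inputs, differences of large prefix sums can lose some precision compared with summing each segment directly.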

However, if your indices reflect non-overlapping segments (which they seem t…

Answer selected by druidowm
Category
Q&A