Hi all, I am working on a project using JAX and need to sum over slices of distinct sizes in an array. Here is a minimal example:
This works, and `slice` is jittable, assuming `shapes` is declared as a static argument (which it is). However, the inner for-loop is very inefficient. Is there a way to implement this more efficiently using something like `vmap`? I have tried this, and the problem I run into is that any smaller function that performs only one of the slices must either receive a single element of `shapes` to slice with, which is no longer static, or receive the full `shapes` tuple and index it dynamically with an index variable. Neither approach is jittable or works with `vmap`. Still, because `shapes` is static, it seems like it should be possible to parallelize this or make it more efficient. Since I am new to JAX, I am wondering whether a more experienced JAX user could tell me if this is possible and, if so, how to implement it. Thank you!
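For context, here is a hypothetical illustration of the pattern described above: a nested loop over a static `shapes` argument under `jax.jit`. It is not the original minimal example (which was not preserved). The function name `naive_slice_sums` and the layout of `shapes` (a tuple of levels, each a tuple of `(start, end)` pairs, where each level reduces the previous level's output) are assumptions, chosen to match the segment IDs in the reply.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in, NOT the poster's code. Assumption: `shapes` is a
# tuple of levels; each level is a tuple of (start, end) pairs, and each
# level is applied to the output of the previous one.
def naive_slice_sums(x, shapes):
    for shape in shapes:
        # Python loop over static (start, end) pairs: it is unrolled at trace
        # time, emitting one slice-and-sum per segment in the compiled program.
        x = jnp.stack([x[start:end].sum() for start, end in shape])
    return x

# A static argument must be hashable, so use nested tuples rather than lists.
naive_jit = jax.jit(naive_slice_sums, static_argnums=1)

# Hypothetical example input.
shapes = (((0, 2), (2, 5), (5, 6), (6, 10)), ((0, 1), (1, 4)))
x = jnp.arange(10.0)
print(naive_jit(x, shapes))
```

Because `shapes` is static, both loops are unrolled when the function is traced, so compile time and program size grow with the total number of segments. This is the cost the question is trying to avoid.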
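There's no way I can think of to do this sort of operation with `vmap`, because if you map over the shapes, it requires taking sums over dynamically-sized arrays (despite your inputs being static, any vmapped arguments end up being dynamic).

One way you could rewrite this to be more efficient is a trick involving a cumulative sum, which eliminates the inner loop. I believe this function is equivalent to yours:

```python
import jax.numpy as jnp

def slice(x, shapes):
  for shape in shapes:
    # Each shape is a sequence of (start, end) index pairs.
    ij = jnp.array(shape)
    # x_cuml[k] = x[0] + ... + x[k-1], with x_cuml[0] = 0.
    x_cuml = jnp.zeros(len(x) + 1, x.dtype).at[1:].set(x.cumsum())
    # sum(x[start:end]) == x_cuml[end] - x_cuml[start]
    x = x_cuml[ij[:, 1]] - x_cuml[ij[:, 0]]
  return x
```

However, if your indices describe non-overlapping segments (which they seem to in your example), the most efficient approach may be to re-express your indices in terms of segment IDs and use the built-in `segment_sum`:

```python
import jax.numpy as jnp
from jax.ops import segment_sum

segment_ids = [jnp.array([0, 0, 1, 1, 1, 2, 3, 3, 3, 3]),
               jnp.array([0, 1, 1, 1])]

def slice2(x, segment_ids):
  for segment_id in segment_ids:
    # Under jax.jit, also pass a static num_segments, since it sets the output size.
    x = segment_sum(x, segment_id)
  return x

slice2(x, segment_ids)  # x: the input array from the question
```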
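A quick way to check that both rewrites agree with a direct slice-by-slice loop is to run all three on the same input under `jax.jit`. This is a sketch under assumptions: the helper names `naive`, `cumsum_slices`, and `segment_slices` and the example `shapes`/`x` are hypothetical, and `segment_slices` assumes each level's `(start, end)` pairs tile its input exactly and in order. That assumption lets the segment IDs be built from the static `shapes` in plain Python, so `num_segments` can be passed as a static int, which `segment_sum` needs inside a jitted function.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

# Hypothetical data matching the segment IDs above: level 1 sums x[0:2],
# x[2:5], x[5:6], x[6:10]; level 2 sums those results over [0:1] and [1:4].
shapes = (((0, 2), (2, 5), (5, 6), (6, 10)), ((0, 1), (1, 4)))
x = jnp.arange(10.0)

def naive(x, shapes):
    # Reference: one slice-and-sum per (start, end) pair.
    for shape in shapes:
        x = jnp.stack([x[s:e].sum() for s, e in shape])
    return x

def cumsum_slices(x, shapes):
    # Cumulative-sum trick: sum(x[s:e]) == c[e] - c[s], with c[0] = 0.
    for shape in shapes:
        ij = jnp.array(shape)
        c = jnp.zeros(len(x) + 1, x.dtype).at[1:].set(x.cumsum())
        x = c[ij[:, 1]] - c[ij[:, 0]]
    return x

def segment_slices(x, shapes):
    # Assumes the pairs in each level tile the input exactly, in order.
    for shape in shapes:
        # Segment ID k for every index in the k-th (start, end) range,
        # computed in plain Python from the static `shapes`.
        ids = [k for k, (s, e) in enumerate(shape) for _ in range(s, e)]
        x = segment_sum(x, jnp.array(ids), num_segments=len(shape),
                        indices_are_sorted=True)
    return x

fast_cumsum = jax.jit(cumsum_slices, static_argnums=1)
fast_segment = jax.jit(segment_slices, static_argnums=1)

# Both jitted versions should match the naive loop.
for f in (fast_cumsum, fast_segment):
    assert jnp.allclose(f(x, shapes), naive(x, shapes))
print(fast_cumsum(x, shapes), fast_segment(x, shapes))
```

One design note: the cumulative-sum trick subtracts two running totals, so for long float32 inputs it can lose precision compared with summing each slice directly. `segment_sum` does not have that issue.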