I'm trying to implement a computation, but I can't figure out how to do it effectively, so I'm hoping someone can offer some tips. I have two arrays:

```
[[0, 1, 2],
 [1, 1, 2],
 [1, 0, 1]]
```

```
[[1.0, 1.0, 3.0],
 [1.0, 2.0, 3.0],
 [0.0, 2.0, 3.0]]
```

I want to use the first array as an index array that reorders (gathers along the first axis of) the second array in a scan-like pattern along the last axis. At each scan step, the reordering applies to the second array's columns from the first one up to and including the current scan index, like so:

```
# First scan iter
# Index: [0, 1, 1]
# View of target (column 0, written as a row):
[[1.0, 1.0, 0.0]]
# Result:
[[1.0, 1.0, 1.0]]

# Second scan iter
# Index: [1, 1, 0]
# View of target:
[[1.0, 1.0],
 [1.0, 2.0],
 [1.0, 2.0]]
# Result:
[[1.0, 2.0],
 [1.0, 2.0],
 [1.0, 1.0]]

# Final scan iter
# Index: [2, 2, 1]
# View of target:
[[1.0, 2.0, 3.0],
 [1.0, 2.0, 3.0],
 [1.0, 1.0, 3.0]]
# Result:
[[1.0, 1.0, 3.0],
 [1.0, 1.0, 3.0],
 [1.0, 2.0, 3.0]]
```

So the final result shape is static, but the intermediate shapes are like dynamic slices. On top of that, the changes from earlier steps of the reordering propagate forward (which is why I called this "scan-like"). Does anyone have any ideas for this sort of thing?
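To pin down the semantics, here is a plain NumPy restatement of the steps above with the dynamic-width slice written out explicitly. It is a sketch based on one reading of the example: at step `i`, columns `0..i` of the running result are replaced by its own rows gathered with `indices[:, i]`. `reorder_scan_reference` is a hypothetical helper name. Because the slice width changes at every step, this form does not carry over directly to `lax.fori_loop` or `lax.scan` (where `i` is traced and shapes must be static), but it can serve as ground truth when testing a JAX version.

```python
import numpy as np


def reorder_scan_reference(indices, data):
    """Plain NumPy reference for the scan-like reordering described above.

    Assumption: at step i, columns 0..i of the running result are replaced
    by its rows gathered with indices[:, i]; later columns are untouched
    until their own step.
    """
    out = np.array(data, dtype=float)  # copy so the caller's array is not mutated
    for i in range(indices.shape[1]):
        ind = indices[:, i]
        # Dynamic-width view: only columns 0..i are reordered at step i.
        # Fancy indexing on the right-hand side makes a copy, so there is no
        # aliasing between the rows being read and the rows being written.
        out[:, : i + 1] = out[ind, : i + 1]
    return out


indices = np.array([[0, 1, 2],
                    [1, 1, 2],
                    [1, 0, 1]])
data = np.array([[1.0, 1.0, 3.0],
                 [1.0, 2.0, 3.0],
                 [0.0, 2.0, 3.0]])

result = reorder_scan_reference(indices, data)
print(result)
```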
If I'm understanding correctly, I think something like this should work:

```python
import jax.numpy as jnp
from jax import lax

indices = jnp.array([[0, 1, 2],
                     [1, 1, 2],
                     [1, 0, 1]])
data = jnp.array([[1.0, 1.0, 3.0],
                  [1.0, 2.0, 3.0],
                  [0.0, 2.0, 3.0]])

def loop(i, data):
    r = jnp.arange(indices.shape[1])
    ind = indices[:, i]
    # Gather rows of the full array, but keep the gathered values only in
    # columns 0..i; columns past the current scan index stay unchanged.
    # Masking with `where` keeps every iteration at the same static shape.
    return jnp.where(r > i, data, data[ind])

result = lax.fori_loop(0, indices.shape[1], loop, data)
print(result)
# [[1. 1. 3.]
#  [1. 1. 3.]
#  [1. 2. 3.]]
```
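Since the question describes the operation as scan-like, the same masking trick can also be written with `lax.scan`, which additionally returns every intermediate state (handy for checking against the step-by-step example in the question). This is a minimal sketch, assuming `indices` and `data` are 2-D with the same number of columns; `reorder_scan` is a hypothetical name, not a JAX API. Instead of closing over a global `indices`, the columns of `indices` are fed in as the scanned input.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def reorder_scan(indices, data):
    """Scan over the columns of `indices`; returns (final, per-step history)."""
    n = indices.shape[1]  # static under jit
    cols = jnp.arange(n)

    def step(carry, xs):
        i, ind = xs  # i: scalar step index, ind: indices[:, i]
        # The boolean mask stands in for the dynamic-width slice [:, :i+1]:
        # columns 0..i take the gathered rows, columns > i keep the carry.
        new = jnp.where(cols > i, carry, carry[ind])
        return new, new  # (next carry, value stacked into history)

    # lax.scan iterates over the leading axis of each array in the xs tuple.
    final, history = lax.scan(step, data, (cols, indices.T))
    return final, history


indices = jnp.array([[0, 1, 2],
                     [1, 1, 2],
                     [1, 0, 1]])
data = jnp.array([[1.0, 1.0, 3.0],
                  [1.0, 2.0, 3.0],
                  [0.0, 2.0, 3.0]])

final, history = reorder_scan(indices, data)
print(final)
```

The stacked `history` has shape `(n_cols, n_rows, n_cols)`; if only the final array is needed, returning `new, None` from `step` avoids storing it.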