Thanks for the question. The issue you're running into is that you cannot use dynamically-shaped arrays within JIT-compiled code such as `jax.jit` or `jax.lax.scan`. I suspect your best approach here will be to make the index arrays statically shaped by padding each multiplet with an out-of-bound index, and then gather with `mode='fill'` so that the padded columns come back as NaN. For example:

```python
import jax
import jax.numpy as jnp

# NxM data matrix (this is too large to pass in batches)
data = jnp.array([jnp.arange(32.0) for i in range(10)])

padding = data.shape[1]  # out-of-bound column index used as padding

multiplets = jnp.array([
    [0, 1, 2, padding, padding],
    [5, 6, padding, padding, padding],
    [9, 10, 11, 32, 24]
])

def process_data(x):
    # This is a more complicated function that will operate on a subset
    # of the data on the columns M
    return jnp.nansum(x)

@jax.jit
def subset(carry, index):
    data, multiplets = carry
    # Out-of-bound (padded) indices are filled with NaN instead of raising an error
    subset_data = data.at[:, multiplets[index]].get(mode='fill', fill_value=jnp.nan)
    results = process_data(subset_data)
    return carry, results

final, result = jax.lax.scan(subset, (data, multiplets), jnp.arange(len(multiplets)))
print(result)
# [ 30. 110. 540.]
```

Is that something like what you have in mind?
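If the multiplets start out as ragged Python lists, the padded index array can be built once on the host, before anything is traced. The sketch below is a minimal illustration under a few assumptions: `pad_multiplets` and `process_all` are hypothetical names, `jnp.nansum` stands in for the real `process_data`, and `jax.lax.map` is swapped in for the `scan` above because the carry is never updated (it still processes one multiplet at a time). The pad value should be a non-negative out-of-bound index such as `data.shape[1]`; as far as I know, negative indices are wrapped NumPy-style rather than treated as out of bounds, so `-1` would not work as padding.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_multiplets(multiplets, pad_value):
    """Pack ragged index lists into a rectangular int32 array.

    Hypothetical helper: each row is right-padded with `pad_value`,
    which is assumed to be an out-of-bound column index (e.g. data.shape[1]).
    """
    width = max(len(m) for m in multiplets)
    out = np.full((len(multiplets), width), pad_value, dtype=np.int32)
    for i, m in enumerate(multiplets):
        out[i, :len(m)] = m
    return jnp.asarray(out)


@jax.jit
def process_all(data, multiplets):
    def one(idx):
        # Padded (out-of-bound) columns come back as NaN ...
        subset = data.at[:, idx].get(mode='fill', fill_value=jnp.nan)
        # ... so a NaN-aware reduction ignores them (stand-in for process_data).
        return jnp.nansum(subset)

    # lax.map applies `one` to each row of `multiplets` sequentially.
    return jax.lax.map(one, multiplets)


data = jnp.array([jnp.arange(32.0) for _ in range(10)])
multiplets = pad_multiplets([[0, 1, 2], [5, 6], [9, 10, 11, 24]], pad_value=data.shape[1])
result = process_all(data, multiplets)
print(result)  # same values as the scan version: 30, 110, 540
```

Using `jax.vmap` instead of `lax.map` would evaluate all multiplets at once, which is likely faster but materializes every gathered subset simultaneously, so the sequential version is probably the safer choice when the data is too large to batch.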
Hi,
I have a large matrix of data of size MxN and a list of variable-size indices (multiplets). The goal is to run a function over a subset of the data by slicing the matrix into smaller batches of size MxK, where K is the length of the indices and K <= N. To illustrate this, I have made a toy example using the following code:
My approach is to:
Does anyone have any advice on how to approach this? Is there perhaps a better way? I found a similar discussion using padding, but to the best of my knowledge I have not seen anyone use padded indices to select items from a matrix. Thanks!