For those interested, here is the solution from @pschuh:

```python
import jax
import jax.numpy as jnp

def bar(X, Z):
  leading, *rest = Z.shape  # assumes leading is divisible by 100
  return jax.lax.map(lambda t: jnp.sum(X@t.T>=0, axis=0), Z.reshape(100, leading // 100, *rest)).flatten()
```

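My understanding is that this combines `foo` and `foo_map` above: `jax.lax.map` runs `foo` sequentially over 100 sub-matrices of `Z`, each small enough that `X @ t.T` fits into memory, and `.flatten()` stitches the per-chunk counts back together into the final result. Note that the `reshape` only works when `Z.shape[0]` is divisible by 100.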
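If `Z.shape[0]` is not a multiple of 100, the same idea can be adapted by padding. Below is a minimal sketch, assuming `Z` is 2-D with shape `(m, d)` and `X` has shape `(n, d)`; the names `bar_direct`, `bar_padded`, and `chunk_size` are hypothetical, not from this thread. It zero-pads `Z` up to a multiple of `chunk_size`, applies the same `lax.map` pattern, and trims off the padded rows. Roughly, the direct version materializes an `n × m` matrix at once, while each `lax.map` step only needs an `n × chunk_size` block.

```python
import jax
import jax.numpy as jnp


def bar_direct(X, Z):
    # Reference version: builds the full (n, m) matrix X @ Z.T in one go.
    return jnp.sum(X @ Z.T >= 0, axis=0)


def bar_padded(X, Z, chunk_size=1000):
    # Hypothetical generalization of bar: processes rows of Z in chunks of
    # `chunk_size`, so only an (n, chunk_size) block is computed per step.
    m, d = Z.shape
    num_chunks = -(-m // chunk_size)  # ceiling division
    pad = num_chunks * chunk_size - m
    # Zero-pad Z so its row count divides evenly; padded rows are dropped below.
    Z_padded = jnp.pad(Z, ((0, pad), (0, 0)))
    chunks = Z_padded.reshape(num_chunks, chunk_size, d)
    counts = jax.lax.map(lambda t: jnp.sum(X @ t.T >= 0, axis=0), chunks)
    return counts.reshape(-1)[:m]


# Usage: integer-valued data keeps the dot products exact, so both versions
# should agree element for element.
key_x, key_z = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.randint(key_x, (50, 8), -3, 4).astype(jnp.float32)
Z = jax.random.randint(key_z, (1234, 8), -3, 4).astype(jnp.float32)
print(bool(jnp.array_equal(bar_padded(X, Z, chunk_size=100), bar_direct(X, Z))))
```

The padded rows are all zeros, so they would each count every row of `X`; slicing with `[:m]` discards them before returning.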

Answer selected by ryan112358
Category
Q&A