
Hi, I don't think there's a simple plug-in method that will reduce the memory usage of the operation you've written.
The function you've written returns a 10^6 x 10^6 matrix, i.e. one with 10^12 (a trillion) entries. So just storing the output requires on the order of terabytes of memory: about 4 TB at 4 bytes per float32 entry, or 8 TB in float64.
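A quick back-of-the-envelope check of that figure, assuming a dense array of float32 (JAX's default unless 64-bit mode is enabled) or float64 entries. The helper name `dense_matrix_bytes` is just for illustration:

```python
import numpy as np


def dense_matrix_bytes(n_rows, n_cols, dtype):
    # Bytes needed to hold a dense n_rows x n_cols array of the given dtype.
    return n_rows * n_cols * np.dtype(dtype).itemsize


n = 10**6
for dtype in (np.float32, np.float64):
    tb = dense_matrix_bytes(n, n, dtype) / 1e12
    print(f"{np.dtype(dtype).name}: {tb:.0f} TB")
# float32: 4 TB
# float64: 8 TB
```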

If you're able to refactor your computation so that it doesn't need the entire matrix at once, that would reduce the memory usage. For example, if you can deal with the rows one by one, you could loop over chunks of rows with jax.lax.scan. Inside the scanned function you could vmap over however many rows you can fit in memory at once.
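A minimal sketch of the scan-over-chunks, vmap-within-a-chunk pattern. The original function isn't shown, so this assumes (hypothetically) that each matrix entry is a pairwise function of a row point `x[i]` and a column point `y[j]` (here a Gaussian kernel, `pairwise`), and that what you actually need downstream is a per-row reduction (here row sums, `row_sums_chunked`). Both names are illustrative, not part of any library, and the sketch assumes `chunk_size` divides the number of rows:

```python
import jax
import jax.numpy as jnp


def pairwise(x_i, y_j):
    # HYPOTHETICAL per-entry function standing in for the original one:
    # a Gaussian kernel between two points.
    return jnp.exp(-jnp.sum((x_i - y_j) ** 2))


def row(x_i, y):
    # One row of the (never materialized) n x m matrix, shape (m,).
    return jax.vmap(pairwise, in_axes=(None, 0))(x_i, y)


def row_sums_chunked(x, y, chunk_size):
    # HYPOTHETICAL downstream use: per-row sums of the n x m matrix,
    # computed chunk_size rows at a time so only a (chunk_size, m) block is live.
    n = x.shape[0]
    if n % chunk_size != 0:
        raise ValueError("this sketch assumes chunk_size divides n")
    x_chunks = x.reshape((n // chunk_size, chunk_size) + x.shape[1:])

    def body(carry, x_chunk):
        # vmap over the rows in this chunk: block has shape (chunk_size, m).
        block = jax.vmap(row, in_axes=(0, None))(x_chunk, y)
        return carry, block.sum(axis=1)  # reduce each row to shape (chunk_size,)

    # scan loops over the chunks sequentially; the carry is unused (None).
    _, sums = jax.lax.scan(body, None, x_chunks)
    return sums.reshape(n)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1000, 3))
y = jax.random.normal(key_y, (2000, 3))
sums = row_sums_chunked(x, y, chunk_size=100)  # shape (1000,)
```

The peak memory is roughly one `chunk_size x m` block rather than the full `n x m` matrix, so `chunk_size` is the knob to tune against available memory. If you wrap it in `jax.jit`, mark `chunk_size` as static (e.g. `static_argnames="chunk_size"`), since it determines array shapes.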

Answer selected by jakevdp
Category: Q&A