Hi, it's difficult to guess what's happening here without a runnable example. Could you create a complete, minimal reproduction of the issue, with simplified versions of the functions you're calling?
I'm new to JAX, and I foolishly wrote a statistical function that uses a Python for loop containing a nested fori_loop, and it consumes a lot of memory on each iteration. I'm rather stuck trying to find a way to avoid running out of memory; hopefully someone can steer me in the right direction. The function is roughly as follows (note that probs is a 2D array):
Each time the fori_loop executes, my available memory drops by about 220 MB. The initialValPp array is only about 88 KB, though the fori_loop executes thousands of times. I'm guessing that each execution somehow adds to the memory being consumed.
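One common source of memory that grows on every outer iteration (an assumption here, since the original function isn't shown) is JAX retracing and recompiling work on each pass, for example when the loop body is defined as a fresh function inside the Python loop, or when full intermediate arrays are kept alive in a list. A minimal sketch of the alternative: define the body once, jit the whole computation once, and return only the element you need. The names `step`, `run`, `TRACE_COUNT` and the array shape are hypothetical.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # hypothetical counter: incremented only when `run` is traced


def step(i, carry):
    # Hypothetical loop body, defined ONCE at module level so JAX can reuse it.
    return carry * 0.5 + 1.0


@jax.jit
def run(init):
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effects execute at trace time only
    out = jax.lax.fori_loop(0, 1000, step, init)
    # Return just the element that is needed, not the whole `out` array.
    return out[0, 0]


init = jnp.ones((100, 220))  # 22,000 float32 values, roughly 88 KB (assumed shape)
results = []
for k in range(5):
    # Same shape and dtype on every call, so the compiled `run` is reused.
    results.append(run(init + k))

print("times traced:", TRACE_COUNT)
```

If the trace count keeps climbing across outer iterations in your real code, that is a sign something (a new function object, a changing shape or dtype, or a static argument) is forcing recompilation.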
I only need one value from 'out', so I tried extracting it as a Python float. That solves the memory problem when the function isn't jitted, but the function needs to be jitted later, and then the float extraction raises a ConcretizationTypeError.
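Under `jax.jit`, values inside the function are tracers, so calling `float()` on them asks for a concrete value that doesn't exist yet at trace time, which is what triggers the ConcretizationTypeError. A sketch, assuming a hypothetical loop body (`body`, `broken` and `fixed` are illustrative names): keep the selected element as a 0-d array inside the jitted function and convert it to a Python float only after the jitted call returns.

```python
import jax
import jax.numpy as jnp


def body(i, acc):
    # Hypothetical stand-in for the real loop body.
    return acc + 1.0


@jax.jit
def broken(probs):
    out = jax.lax.fori_loop(0, 10, body, probs)
    # float() needs a concrete value, but under jit `out` is a tracer.
    return float(out[0, 0])


@jax.jit
def fixed(probs):
    out = jax.lax.fori_loop(0, 10, body, probs)
    # Plain indexing keeps the result as a traced 0-d array; no concretization.
    return out[0, 0]


probs = jnp.zeros((3, 4))
try:
    broken(probs)
except jax.errors.ConcretizationTypeError as err:
    print("broken() failed with", type(err).__name__)

value = fixed(probs)  # 0-d jax.Array
print(float(value))   # convert to a Python float outside the jitted function
```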
I also tried a mock-up version of the function that replaces the Python for loop with a fori_loop. That didn't help with the memory problem.
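If the outer loop is moved into a fori_loop as well, it may help to keep the loop carry as small as possible, for instance a scalar accumulator, instead of carrying or returning full arrays. A minimal sketch, assuming a hypothetical statistic (the sum of log-probabilities over a 2D `probs`), with both loops inside one jitted function; `total_log_prob` is an illustrative name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def total_log_prob(probs):
    n_rows, n_cols = probs.shape  # shapes are static Python ints under jit

    def outer_body(i, acc):
        def inner_body(j, row_acc):
            # Dynamic indexing with traced i, j is fine inside jit.
            return row_acc + jnp.log(probs[i, j])

        row_init = jnp.zeros((), dtype=probs.dtype)
        row_total = jax.lax.fori_loop(0, n_cols, inner_body, row_init)
        return acc + row_total

    # The carry is a single scalar, so no full-size arrays accumulate.
    init = jnp.zeros((), dtype=probs.dtype)
    return jax.lax.fori_loop(0, n_rows, outer_body, init)


probs = jnp.full((2, 3), 0.5)
print(float(total_log_prob(probs)))
```

For this particular statistic, `jnp.sum(jnp.log(probs))` gives the same result without any explicit loop, which is usually the more idiomatic choice in JAX when the computation can be vectorized.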
Any suggestions as to what I can try are welcome!