Unexpected memory leak while sharding. #20585
Unanswered · aayushdesai asked this question in Q&A
Hello,
I am using sharding to parallelise the computation of my function over 8 CPU cores.
Here is the code:
A basic run-through of the code:
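The original code block was not captured in this copy of the post. As a stand-in, here is a minimal sketch of the setup described below, with toy sizes: the periodogram-style kernel, the bootstrap resampling, and the `XLA_FLAGS` trick for getting 8 logical CPU devices on a laptop are all assumptions, not the poster's actual code — only the names `omega`, `time`, `observable`, and `bootstrap_len` come from the description.

```python
import os

# Assumption: force 8 logical CPU devices so the sharding runs on a laptop;
# the original post may have configured devices differently.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("cores",))

# Toy sizes; the post uses 2,500 time points and 400,000 frequencies.
n_time, n_freq, bootstrap_len = 250, 4_000, 10

time = jnp.linspace(0.0, 1.0, n_time)
observable = jnp.sin(2.0 * jnp.pi * 30.0 * time)
omega = jnp.linspace(1.0, 100.0, n_freq)

# Shard the frequency grid across the cores; replicate the time series.
omega = jax.device_put(omega, NamedSharding(mesh, P("cores")))
time = jax.device_put(time, NamedSharding(mesh, P()))
observable = jax.device_put(observable, NamedSharding(mesh, P()))

def power_one(key):
    # Hypothetical periodogram-style kernel: one bootstrap resample of the
    # time series, evaluated on the full frequency grid.
    idx = jax.random.randint(key, (n_time,), 0, n_time)
    t, obs = time[idx], observable[idx]
    phase = omega[:, None] * t[None, :]  # (n_freq, n_time) intermediate
    c = jnp.sum(obs * jnp.cos(phase), axis=1)
    s = jnp.sum(obs * jnp.sin(phase), axis=1)
    return c**2 + s**2

keys = jax.random.split(jax.random.PRNGKey(0), bootstrap_len)
power = jax.vmap(power_one)(keys)
print(power.shape)  # (10, 4000)
```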
The problem:
I shard the frequency (omega) array equally across all 8 cores and replicate the observable and time arrays on every core. In total I have 2,500 points in my time series (time and observable) and 400,000 points in my frequency grid. This is one calculation, which I want to evaluate 10 (bootstrap_len) times in parallel.
At the end I should have a resulting power array of 4,000,000 elements.
When I run the memory profiler on my MacBook, memory peaks at 20 GB before the code crashes. From a rough calculation I would not expect the computation to need more than a few hundred MB. What am I missing?
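For context on the rough calculation: if the kernel broadcasts the frequency grid against the time series (a common periodogram pattern, assumed here since the code is not shown), the materialised intermediate dwarfs the inputs and the output, which would be consistent with a multi-GB peak:

```python
# Sizes taken from the post.
n_time, n_freq, bootstrap_len = 2_500, 400_000, 10
bytes_per_float = 4  # float32, JAX's default dtype

# Memory the inputs and the final output actually need:
inputs = (2 * n_time + n_freq) * bytes_per_float       # ~1.6 MB
output = bootstrap_len * n_freq * bytes_per_float      # ~16 MB

# Assumed kernel shape: broadcasting omega[:, None] * time[None, :]
# materialises an (n_freq, n_time) intermediate per bootstrap draw.
one_intermediate = n_freq * n_time * bytes_per_float   # 4 GB
all_at_once = bootstrap_len * one_intermediate         # 40 GB under vmap

print(f"inputs ~ {inputs / 1e6:.1f} MB, output ~ {output / 1e6:.1f} MB")
print(f"one broadcast intermediate ~ {one_intermediate / 1e9:.1f} GB")
print(f"all {bootstrap_len} draws at once ~ {all_at_once / 1e9:.1f} GB")
```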
Finally, the print statement at line 57 evaluates successfully, but as soon as I try to compute further with the power array (or even access its first element), the process crashes. What am I missing?
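This symptom is consistent with JAX's asynchronous dispatch: operations return immediately, and execution only blocks when values are actually read, so a failure can surface at the first element access rather than at the line that builds the array. A small illustration (generic arrays, not the poster's code):

```python
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)  # dispatched asynchronously; returns right away

# Printing metadata does not wait for the computation to finish...
print(y.shape, y.dtype)  # (1000, 1000) float32

# ...but reading a value (or calling block_until_ready) forces it to run.
y.block_until_ready()
print(float(y[0, 0]))  # 1000.0
```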
System information:
MacBook Pro 2020, M1 chip
8 GB RAM
Packages:
JAX version: 0.4.25
NumPy version: 1.26.4
Thank you very much.