Scan over layers with sharded parameters increases memory usage. #18735
I'm running into OOMs when scanning over layers whose parameters are sharded. My naive implementation loops over the layers in the forward pass, but JIT-compiling that forward pass takes a long time. Scanning over layers is recommended to improve compilation time, but it seems to increase memory usage. Below is a minimal repro. If
Here are the biggest allocations:
Is there a way to prevent this or debug what's going on?
Replies: 1 comment 1 reply
Thanks for the question!

This is just a guess, but as mentioned in the practical notes for `jax.remat`, using `scan` can defeat some XLA rematerialization optimizations and thus cause OOMs. Maybe try applying `jax.remat` to the `layer` function, and see how it affects performance? You can use the policies described in those docs to control things more precisely.