Dear JAX Team,

I have a question about combining `jax.checkpoint` and `jax.custom_vjp`, and I'd appreciate your guidance. In my research I need to solve a large, highly sparse system of linear equations, so I use sparse matrices to conserve memory. Because I solve the system with the `scipy` library, I defined the operation and its gradient with `jax.custom_vjp`. In some cases memory is still insufficient, so I'd like to use `checkpoint` to reduce memory overhead further.

Everything runs smoothly when I use `checkpoint` and `custom_vjp` separately. When I combine them, however, the program throws the error "The numpy.ndarray conversion method `__array__()` was called on traced array with shape int64[8190720]." I understand this is likely caused by my use of `scipy` inside the function. Even though I set `static_argnums` when calling `checkpoint`, it doesn't seem to take effect. Could you please advise how to address this issue?

Thank you very much for your time and assistance.
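The error means SciPy received a JAX tracer instead of a concrete NumPy array: inside `jax.checkpoint` (and `jax.jit`) the function arguments are traced, so SciPy cannot convert them. One commonly used pattern, sketched here under stated assumptions, is to (1) keep the sparse matrix as a host-side constant that is never passed through `checkpoint` as an argument, and (2) call SciPy only through `jax.pure_callback`, which hands the callback concrete NumPy arrays. The sketch assumes the matrix `A` is fixed (not differentiated) and only the right-hand side `b` is differentiated. `A_host`, `sparse_solve`, `layer`, and the problem size are hypothetical placeholders, not the original code.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

jax.config.update("jax_enable_x64", True)

# Hypothetical fixed sparse system, stored on the host as a SciPy CSC matrix.
# It is closed over (never passed through jax.checkpoint), so it is never traced.
n = 50
A_host = (sp.random(n, n, density=0.05, random_state=0) + sp.identity(n) * n).tocsc()
A_host_T = A_host.T.tocsc()  # precomputed transpose for the backward solve


def _solve_cb(b, transpose):
    # Runs on the host with a concrete NumPy array, so SciPy is safe here.
    b = np.asarray(b)
    M = A_host_T if transpose else A_host
    return np.asarray(spla.spsolve(M, b), dtype=b.dtype)


def _solve(b, transpose=False):
    # Tell JAX the output shape/dtype so it can trace through the callback.
    out = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(partial(_solve_cb, transpose=transpose), out, b)


@jax.custom_vjp
def sparse_solve(b):
    """x = A^{-1} b, solved by SciPy on the host."""
    return _solve(b)


def sparse_solve_fwd(b):
    # No residuals needed: the cotangent for b only depends on A and g.
    return _solve(b), None


def sparse_solve_bwd(_, g):
    # For x = A^{-1} b, the VJP w.r.t. b is A^{-T} g.
    return (_solve(g, transpose=True),)


sparse_solve.defvjp(sparse_solve_fwd, sparse_solve_bwd)


@jax.checkpoint  # forward activations inside are recomputed during backprop
def layer(b):
    return jnp.sin(sparse_solve(b))


def loss(b):
    return jnp.sum(layer(b) ** 2)


b = jnp.arange(n, dtype=jnp.float64) / n
grad = jax.jit(jax.grad(loss))(b)
```

Because only traced arrays with statically known shapes cross the `pure_callback` boundary, `checkpoint` never needs SciPy to see a tracer. If the matrix values themselves must be differentiated, this sketch would need extending (for example, passing the values as a traced argument to the callback and adding their cotangent in the backward rule).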