I am trying to calculate the Hessian of a fairly complex function, and a plain `jax.hessian` call runs out of memory.
My workaround is to compute the Hessian block by block, indexing into it inside a loop. This works, but it is slow: slower even than a numerical (finite-difference) approximation.
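For reference, here is a minimal sketch of a block-by-block approach, assuming `f` maps a 1-D array to a scalar. It is not the code from this post. Instead of indexing into a full `jax.hessian` result, it builds each block of rows from Hessian-vector products (forward-over-reverse: `jax.jvp` of `jax.grad`), so peak device memory should scale with the block size rather than the full dimension. `hessian_in_blocks`, `block_size` and the test function `f` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def hessian_in_blocks(f, x, block_size):
    """Hessian of scalar-valued f at 1-D x, computed `block_size` rows at a time."""
    n = x.shape[0]
    grad_f = jax.grad(f)

    @jax.jit
    def row_block(x, start):
        # Basis vectors e_start .. e_{start+block_size-1}; indices >= n
        # produce all-zero rows, so every block has the same shape.
        idx = start + jnp.arange(block_size)
        basis = jax.nn.one_hot(idx, n, dtype=x.dtype)
        # Forward-over-reverse HVP: the jvp of grad gives H @ v for each v.
        # H @ e_i is column i of H, which equals row i since H is symmetric.
        return jax.vmap(lambda v: jax.jvp(grad_f, (x,), (v,))[1])(basis)

    blocks = []
    for start in range(0, n, block_size):
        # Copy each block to host memory and drop the zero padding rows.
        block = np.asarray(row_block(x, start))
        blocks.append(block[: n - start])
    return np.concatenate(blocks, axis=0)


# Usage on a small, hypothetical test function
def f(x):
    return jnp.sum(jnp.sin(x) * x**2) + jnp.dot(x, x) ** 2


x = jnp.linspace(-1.0, 1.0, 7)
H = hessian_in_blocks(f, x, block_size=3)
print(H.shape)  # (7, 7)
```

Because every block has the same shape (the last one is padded with zero basis vectors), `row_block` is compiled once and reused. A loop that slices a smaller final block would trigger a second compilation for that shape, which could add to the run time of a loop-based version.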
- Row-by-row Hessian calculation (using a loop): 55 s
- Finite-difference approximation (differencing `grad` to get the Hessian): 37 s
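The post does not say which finite-difference scheme is used. As an assumption, here is a central-difference sketch that differentiates `jax.grad` numerically, one column at a time. `fd_hessian`, the step size `eps` and the test function `f` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fd_hessian(f, x, eps=1e-3):
    """Central-difference Hessian of scalar f at 1-D x, built from jax.grad."""
    grad_f = jax.jit(jax.grad(f))
    cols = []
    for i in range(x.shape[0]):
        step = jnp.zeros_like(x).at[i].set(eps)
        # Column i of H ~ (grad(x + eps*e_i) - grad(x - eps*e_i)) / (2*eps)
        cols.append((grad_f(x + step) - grad_f(x - step)) / (2 * eps))
    H = jnp.stack(cols, axis=1)
    # Average with the transpose, since the exact Hessian is symmetric.
    return 0.5 * (H + H.T)


# Usage on a small, hypothetical test function
def f(x):
    return jnp.sum(jnp.sin(x) * x**2) + jnp.dot(x, x) ** 2


x = jnp.linspace(-1.0, 1.0, 7)
H_fd = fd_hessian(f, x)
```

The step size trades truncation error against float32 rounding error, so in single precision the result typically agrees with the exact Hessian only to a few digits.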
Here is sample code demonstrating the logic (I use `vmap` here to make the comparison fairer, but in the actual implementation I use a for loop to avoid running out of memory):
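As a hedged sketch (not the post's sample code, and untested on the function in question), one middle ground between full `vmap` and a Python for loop is to keep the blocking but move the loop inside a single compiled function with `jax.lax.map`. It processes the blocks sequentially, so only one block's intermediates should be live at a time, while each block is still vectorized with `jax.vmap`. `hessian_lax_map`, `block_size` and the test function `f` are hypothetical names.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(0, 2))
def hessian_lax_map(f, x, block_size):
    """Hessian of scalar f at 1-D x; blocks of rows are computed sequentially."""
    n = x.shape[0]
    n_blocks = -(-n // block_size)  # ceiling division
    grad_f = jax.grad(f)

    def one_block(basis):
        # Hessian-vector products for one block of basis vectors, via vmap.
        return jax.vmap(lambda v: jax.jvp(grad_f, (x,), (v,))[1])(basis)

    # Identity padded with zero rows so it splits evenly into blocks.
    basis = jnp.eye(n_blocks * block_size, n, dtype=x.dtype)
    basis = basis.reshape(n_blocks, block_size, n)
    # lax.map loops over the leading axis inside the compiled program.
    rows = jax.lax.map(one_block, basis)
    return rows.reshape(-1, n)[:n]


# Usage on a small, hypothetical test function
def f(x):
    return jnp.sum(jnp.sin(x) * x**2) + jnp.dot(x, x) ** 2


x = jnp.linspace(-1.0, 1.0, 7)
H = hessian_lax_map(f, x, 3)  # f and block_size are static, pass positionally
```

This removes per-iteration Python dispatch, but the full n-by-n result still ends up in device memory, so it can only help if the out-of-memory error comes from the intermediates of `jax.hessian` rather than from the size of the Hessian itself.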