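You can use the ahead-of-time (AOT) cost analysis functionality, something like this:

```python
import jax

def f(x):
    return jax.numpy.linalg.svd(x)[1]  # singular values of x

# Abstract input: only shape and dtype, nothing is allocated.
x = jax.ShapeDtypeStruct((10000, 30000), 'float32')
# Lower without running and print XLA's cost estimates.
print(jax.jit(f).lower(x).cost_analysis())
```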
Read more at https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available
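Since the question is about memory rather than FLOPs, the same AOT workflow can be taken one step further: compile the lowered function and call `memory_analysis()` on the result, once per sharding strategy. This is a minimal sketch, assuming a recent JAX in which `jax.jit` accepts `in_shardings` (as `pjit` does) and `Compiled.memory_analysis()` is available; it may return `None` on backends that don't report memory. The function `f`, the input shape, and the two strategies are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def f(x):
    # Hypothetical stand-in for a model step; replace with your own function.
    return jnp.sum(jnp.tanh(x @ x.T))


def memory_report(fn, arg_spec, sharding):
    """Compile `fn` ahead of time for an abstract input with the given input
    sharding and return XLA's memory statistics (None if unavailable)."""
    lowered = jax.jit(fn, in_shardings=sharding).lower(arg_spec)
    return lowered.compile().memory_analysis()


# 1-D mesh over all visible devices; the axis name "d" is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("d",))

# Abstract input: shape and dtype only. For the row-sharded strategy,
# the number of devices must divide 1024.
x = jax.ShapeDtypeStruct((1024, 512), jnp.float32)

# Two hypothetical strategies to compare.
strategies = {
    "replicated": NamedSharding(mesh, P()),     # full copy on every device
    "rows": NamedSharding(mesh, P("d", None)),  # split dim 0 across devices
}

reports = {name: memory_report(f, x, s) for name, s in strategies.items()}
for name, stats in reports.items():
    # The printed fields (e.g. temp/argument/output sizes) vary by backend.
    print(name, stats)
```

Comparing fields such as `temp_size_in_bytes` and `argument_size_in_bytes` across strategies gives a rough picture of which layout needs less memory; treat them as compiler estimates rather than measured peak usage.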
I want to experiment with different sharding strategies for my model and find out which one results in the lowest memory consumption. I think the easiest approach is to get memory usage information directly from a pjit-compiled function, but I'm not sure whether that's possible.
Any pointers or alternative ideas on how to go about such an investigation would be much appreciated.