GPU memory #9875
Hi, I am worried that some of my code design choices might cause OOM errors.
Q1: Is there a way to print the current GPU memory usage (without setting
Q2: Suppose my code contains some dead code:
    def f(x):
        y = g()  # some dead code
        return x
I bet the dead code will be removed by the compilation (
Replies: 1 comment
Q1: I don't think there is a PyTorch-like API for interactively checking device memory usage. There is, however, a non-interactive but comprehensive method for device memory profiling: https://jax.readthedocs.io/en/latest/device_memory_profiling.html
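As a rough sketch of the profiling approach from the linked page: the snippet below runs a small computation and writes a device memory profile with jax.profiler.save_device_memory_profile. The array sizes and the output path are arbitrary choices for illustration. The report_memory helper is hypothetical, and it assumes that the installed JAX version provides Device.memory_stats(), which may return None on backends that do not report statistics (CPU, for example).

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler


def report_memory(device):
    """Print memory statistics for a device, if the backend reports any."""
    # Assumption: Device.memory_stats() exists in the installed JAX version.
    stats = device.memory_stats()
    if stats is None:
        print(f"{device}: no memory statistics available")
        return None
    # The available keys depend on the backend, so use .get() defensively.
    print(f"{device}: bytes_in_use={stats.get('bytes_in_use')}")
    return stats


# A small computation so that something is resident on the device.
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x).block_until_ready()

for device in jax.local_devices():
    report_memory(device)

# Write a pprof-format snapshot of the current device memory.
profile_path = os.path.join(tempfile.gettempdir(), "memory.prof")
jax.profiler.save_device_memory_profile(profile_path)
```

The resulting file can then be inspected with the pprof tool, as described on the linked documentation page.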
Q2: I think jit evaluates lazily by default: during compilation it traces the function using only abstract arrays (without data). See https://jax.readthedocs.io/en/latest/_autosummary/jax.ensure_compile_time_eval.html
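One way to see the abstract tracing in practice is sketched below. The body of g is a hypothetical stand-in for the unused computation in the question, and the input shape is arbitrary. jax.eval_shape traces f using only shape and dtype information, and jax.make_jaxpr prints the traced program, so it is possible to check what happens to the unused call rather than guess.

```python
import jax
import jax.numpy as jnp


def g():
    # Hypothetical stand-in for the unused computation from the question.
    ones = jnp.ones((1024, 1024))
    return ones @ ones


def f(x):
    y = g()  # the result is never used
    return x


# Trace f with an abstract input: only shape and dtype, no data.
abstract_x = jax.ShapeDtypeStruct((3,), jnp.float32)
out = jax.eval_shape(f, abstract_x)
print(out.shape, out.dtype)

# Print the traced program to inspect which operations were recorded.
print(jax.make_jaxpr(f)(jnp.zeros(3, jnp.float32)))

# The jitted function still just returns its input.
result = jax.jit(f)(jnp.arange(3, dtype=jnp.float32))
print(result)
```

To check what survives compilation, the optimized program can also be printed with jax.jit(f).lower(x).compile().as_text(); the exact output depends on the backend and the JAX version.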