My code is a CLI app that uses JAX. I often run multiple instances of it in parallel. Sometimes there's not enough GPU memory, so if I start one of the instances, it errors out with a message like jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1439504488 bytes.
I would like to add a flag to my command that waits for enough GPU memory to become available and then grabs it. Right now I use this:
    import subprocess
    import time as time_module

    import click


    def get_used_and_total_gpu_ram_mib() -> tuple[int, int]:
        # Ask nvidia-smi for one "used, total" line per GPU, in MiB.
        command = 'nvidia-smi --query-gpu=memory.used,memory.total --format=csv,nounits,noheader'
        output = subprocess.check_output(command.split()).decode('ascii').strip().split('\n')
        try:
            line, = output
        except ValueError as value_error:
            raise NotImplementedError('Multiple GPUs not supported') from value_error
        return tuple(map(int, line.split(',')))


    def wait_for_gpu_availability() -> None:
        """Wait until GPU memory usage falls below 10%."""
        used_ram_mib, total_ram_mib = get_used_and_total_gpu_ram_mib()
        memory_usage = used_ram_mib / total_ram_mib
        if memory_usage <= 0.1:
            click.echo('GPU is already available, continuing.')
            return
        click.echo('Waiting for GPU to become available...')
        while memory_usage > 0.1:
            time_module.sleep(1)
            used_ram_mib, total_ram_mib = get_used_and_total_gpu_ram_mib()
            memory_usage = used_ram_mib / total_ram_mib
        click.echo('GPU is available, continuing.')
It's not great, because it's not atomic. I want to launch a dozen instances of my program, each waiting for the GPU. If I use this function, they might all start running as soon as enough RAM clears out, but because they're all running at once, only one will grab the RAM and the others will error out. I want an atomic solution, i.e. the program waits until the memory is available, allocates it, and proceeds only if the allocation succeeded.
Is that possible with JAX?
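To make the "wait, allocate, proceed only on success" idea concrete, here is a minimal sketch that serializes the attempt across processes with an advisory file lock. It is not a JAX feature. The names `wait_and_allocate`, `try_allocate`, `lock_path` and `retry_on` are hypothetical. It also assumes that every instance uses the same lock file, that a failed allocation raises an exception that can be caught and retried, and that the platform is Unix-like, since `fcntl` is not available on Windows.

```python
import fcntl
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar('T')


def wait_and_allocate(
    try_allocate: Callable[[], T],
    lock_path: str,
    retry_on: Tuple[Type[BaseException], ...],
    poll_seconds: float = 1.0,
) -> T:
    """Retry `try_allocate` until it succeeds, one process at a time.

    Only the process holding the lock attempts the allocation; the others
    block inside `flock` until the holder has succeeded and released it.
    """
    with open(lock_path, 'w') as lock_file:
        # Blocks until no other cooperating process holds the lock.
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            while True:
                try:
                    # Hypothetical callback: allocate the memory and return
                    # a handle to it, or raise one of `retry_on` on failure.
                    return try_allocate()
                except retry_on:
                    time.sleep(poll_seconds)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
```

In the CLI, `try_allocate` would be whatever first touches the GPU, and `retry_on` would be the exception type seen in the traceback above. Whether JAX can cleanly retry after a `RESOURCE_EXHAUSTED` failure within the same process is an assumption I have not verified. The lock only coordinates processes that opt into it, so an unrelated program or an already-running instance that grows its memory use can still cause a failure later.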