Is it possible to simultaneously perform computations on a TPU device and transfer new data onto it? #13011
-
My intuition is that since JAX works asynchronously, you will get case 2 as a result if nothing blocks afterwards.
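To illustrate the point (a minimal sketch of my own, not part of the original reply): with asynchronous dispatch, a jitted call returns to the host before the TPU finishes, so the host can immediately issue the next transfer; only host-side reads of the result force a wait.

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Some device work; the call below returns before this has finished.
    return (x @ x.T).sum()

x = jnp.ones((4096, 4096))
y = step(x)  # returns immediately; the value of y may not be ready yet
# The host is now free to start transferring the next batch.
# Only host-side reads such as float(y), y.item(), or y.block_until_ready()
# actually wait for the device to finish.
print(float(y))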
-
@cgarciae Thank you! However, I realised that there might be some blocking operations:

import jax
import wandb
from jax import device_put

@jax.jit
def train_step(params, opt_states, batch_tpu):
    ...
    return params, opt_states, loss

...
for batch_cpu in batches_cpu:
    batch_tpu = device_put(batch_cpu, device_tpu)  # <- alpha: host-to-TPU transfer
    params, opt_states, loss = train_step(params, opt_states, batch_tpu)  # <- beta: dispatched asynchronously
    ...
    wandb.log({'loss': loss.item()})  # <- blocking: .item() waits for this step to finish on the TPU
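One common workaround for that last line (a sketch of my own, not from the thread): log the previous step's loss instead of the current one, so the .item() call reads a value that is already (or almost) ready and the host does not stall the pipeline.

prev_loss = None
for i, batch_cpu in enumerate(batches_cpu):
    batch_tpu = device_put(batch_cpu, device_tpu)
    params, opt_states, loss = train_step(params, opt_states, batch_tpu)
    if prev_loss is not None:
        # Step i-1 has had a whole step's worth of time to finish, so this
        # .item() typically returns without a long wait.
        wandb.log({'loss': prev_loss.item()}, step=i - 1)
    prev_loss = loss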
-
Also cc @mattjj @YouJiacheng
-
Besides, @Sea-Snell suggested https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device, which seems to be exactly what I want to achieve. However, the documentation also says:

I don't quite understand that paragraph. From my understanding, as long as the buffer size is reasonable (e.g. 2) and not too large, it will not cause OOM, so the function should still be useful on TPU.
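For reference, a minimal usage sketch of flax.jax_utils.prefetch_to_device (my own illustration; the generator and shapes are hypothetical). It assumes pmap-style batches with a leading local-device axis, since the utility shards each item across local devices while keeping up to size batches queued:

import jax
import numpy as np
from flax import jax_utils

def batch_iterator():
    # Hypothetical generator yielding numpy batches shaped
    # [num_local_devices, per_device_batch, feature_dim].
    n = jax.local_device_count()
    while True:
        yield np.zeros((n, 128, 32), dtype=np.float32)

# Keep up to 2 batches transferred and queued on device while the
# current training step is still running.
prefetched = jax_utils.prefetch_to_device(batch_iterator(), size=2)

for batch in prefetched:
    ...  # e.g. params, loss = p_train_step(params, batch)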
-
Previously discussed in https://twitter.com/ayaka14732/status/1585629469938569219.

Imagine that there is a training loop like this (pseudo code):

It would be more efficient if we do this (see the sketch after this question):

Is it possible to achieve this in JAX?
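The original pseudo-code snippets are not reproduced above; roughly, the two cases being contrasted could be sketched as follows (my own reconstruction with hypothetical names, not the original code):

# Case 1 (naive): transfer and compute alternate, so the TPU idles
# while each batch is being copied over.
for batch_cpu in batches_cpu:
    batch_tpu = jax.device_put(batch_cpu, device_tpu)   # transfer
    params, loss = train_step(params, batch_tpu)         # compute

# Case 2 (desired): while step i runs on the TPU, the transfer for
# step i + 1 is already in flight, so transfer and compute overlap.
next_batch = jax.device_put(next(batch_iter), device_tpu)
for _ in range(num_steps):
    batch_tpu = next_batch
    next_batch = jax.device_put(next(batch_iter), device_tpu)  # start next transfer
    params, loss = train_step(params, batch_tpu)               # compute current step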