Defrag hangs when using devicearray-based dataloader #12620
lucaslingle asked this question in Q&A
Hello,
I've noticed an interesting phenomenon when trying to run `jax.lib.xla_bridge.get_backend().defragment()` on a TPU v3-8, as suggested here, and was wondering if anyone here might have any insights as to what caused it.

Dataloader 1
I have a dataloader for sequences that uses jax for shuffling, which can be simplified to roughly the following:
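In sketch form (`data`, `batch_size`, `rng`, and the batch dictionary layout are placeholders; the relevant part is that the shuffle runs through `jax.random.permutation`, so the permuted tensors come back as DeviceArrays):

```python
import jax
import jax.numpy as jnp


def epoch_iter_v1(rng, data, batch_size):
    # data: token ids with shape [num_sequences, seq_len + 1].
    # Shuffle on-device; tensor_inp / tensor_tgt are jax DeviceArrays.
    perm = jax.random.permutation(rng, data.shape[0])
    shuffled = jnp.asarray(data)[perm]
    tensor_inp = shuffled[:, :-1]
    tensor_tgt = shuffled[:, 1:]
    num_batches = tensor_inp.shape[0] // batch_size
    for i in range(num_batches):
        lo, hi = i * batch_size, (i + 1) * batch_size
        yield {"inputs": tensor_inp[lo:hi], "targets": tensor_tgt[lo:hi]}


def data_iter_v1(rng, data, batch_size):
    # Start a fresh epoch generator each time the previous one is exhausted.
    while True:
        rng, subkey = jax.random.split(rng)
        yield from epoch_iter_v1(subkey, data, batch_size)
```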
When I use it, my training loop eventually runs into an issue with memory fragmentation at the beginning of the second epoch (judging by the step number). To resolve this, I rewrote my training loop to call defrag in a try/except wrapped around a pmapped `train_op`, roughly as follows:
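A simplified version of the loop (reusing `data_iter_v1`, `rng`, `data`, and `batch_size` from above; `train_op`, `state`, `num_steps`, and the exception type caught here are stand-ins for the real training code):

```python
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training import common_utils


def train_op(state, batch):
    # Stand-in for the real per-device train step (loss, grads, optimizer update).
    return state + 1, {"loss": jnp.mean(batch["inputs"])}


p_train_op = jax.pmap(train_op, axis_name="batch")
state = jax_utils.replicate(jnp.zeros(()))  # placeholder replicated train state
num_steps = 100_000                         # placeholder

for step, batch in zip(range(num_steps), data_iter_v1(rng, data, batch_size)):
    sharded = common_utils.shard(batch)  # host batch -> [num_devices, ...] on the accelerators
    try:
        state, metrics = p_train_op(state, sharded)
    except RuntimeError:
        # On a fragmentation / RESOURCE_EXHAUSTED error, defragment and retry the step once.
        # (The exception type raised on OOM may differ across jax / jaxlib versions.)
        jax.lib.xla_bridge.get_backend().defragment()
        state, metrics = p_train_op(state, sharded)
```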
Unfortunately, this call to defrag causes the script to hang for upwards of half an hour. Defrag calls placed elsewhere run in about 30 seconds, but they do not resolve this issue, which occurs on the first training step after `data_iter_v1` calls `epoch_iter_v1` to create a generator for the second epoch (see the first code snippet).

Dataloader 2
I wrote another version of my dataloader that calls `np.asarray` after permuting, thus converting the permuted tensors from `jax.DeviceArray`s back to numpy arrays:
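Sketch (same placeholders as before; the only change from `epoch_iter_v1` is the `np.asarray` call right after the permutation):

```python
import jax
import jax.numpy as jnp
import numpy as np


def epoch_iter_v2(rng, data, batch_size):
    # Same shuffle as v1, but immediately pull the permuted tensors back to
    # host memory as numpy arrays, so no DeviceArrays outlive this call.
    perm = jax.random.permutation(rng, data.shape[0])
    shuffled = np.asarray(jnp.asarray(data)[perm])
    tensor_inp = shuffled[:, :-1]
    tensor_tgt = shuffled[:, 1:]
    num_batches = tensor_inp.shape[0] // batch_size
    for i in range(num_batches):
        lo, hi = i * batch_size, (i + 1) * batch_size
        yield {"inputs": tensor_inp[lo:hi], "targets": tensor_tgt[lo:hi]}
```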
In this case, defrag is not even needed, and my script runs fine. However, I am wondering why this made such a difference. The two main DeviceArrays created by `epoch_iter_v1`, namely `tensor_inp` and `tensor_tgt`, presumably reside in TPU VM host memory, not accelerator memory, so I am not sure why the defrag time was so high. In my understanding, the only additional tensors in accelerator memory would be those moved there by `common_utils.shard(batch)`.
I was also monitoring the total number of "live buffers" and their total size with `jax.lib.xla_bridge.get_backend().live_buffers()`. These numbers stay constant in the former case prior to the script hanging, as well as in the latter case for the entirety of training. However, I observed four fewer live buffers in the latter case.
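For reference, the check was roughly this (I'm assuming the buffer objects expose `.shape` and `.dtype`; the exact attributes may vary across jaxlib versions):

```python
import jax
import numpy as np

backend = jax.lib.xla_bridge.get_backend()
bufs = backend.live_buffers()
# Count the live buffers and estimate their total size from shape and dtype.
total_bytes = sum(np.prod(b.shape) * np.dtype(b.dtype).itemsize for b in bufs)
print(f"live buffers: {len(bufs)}, total size: {total_bytes / 2**20:.1f} MiB")
```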
So my questions are: are `tensor_inp` and `tensor_tgt` somehow being passed to accelerator memory, and if not, why is the defrag so slow in the former case?

Thanks in advance for any insights you can share!