Accessing a registered buffer is very slow #11493
-
Hello, I implemented MoCo in PyTorch Lightning. I was surprised to see that my Lightning version was slower than my plain PyTorch implementation, so I ran the profiler to check which functions are slow. I can't share all my code, but here are the relevant parts:

```python
import torch
from torch import nn, Tensor
from pytorch_lightning import LightningModule


class MoCoModel(LightningModule):
    def __init__(
        ...
    ) -> None:
        ...
        self.register_buffer('queue', torch.randn(queue.feature_dim, queue.size))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _update_queue(self, x: Tensor) -> None:
        x = self.concat_all_gather_without_backprop(x)

        # batch_size = x.shape[0]
        batch_size = self._get_batch_size(x)

        # for simplicity
        # ptr = int(self.queue_ptr)
        ptr = self._get_ptr()

        # assert self.queue_size % batch_size == 0
        self._assert(batch_size)

        # replace the keys at ptr (dequeue and enqueue)
        # self.queue[:, ptr: ptr + batch_size] = x.T
        x = self._transpose(x)
        self._assign_in_queue(x, ptr, batch_size)

        # move pointer
        # ptr = (ptr + batch_size) % self.queue_size
        ptr = self._compute_ptr(ptr, batch_size)
        self._assign_ptr(ptr)

    def _get_batch_size(self, x):
        return x.shape[0]

    def _get_ptr(self):
        return int(self.queue_ptr)

    def _assert(self, batch_size):
        assert self.queue_size % batch_size == 0

    def _assign_ptr(self, ptr):
        self.queue_ptr[0] = ptr

    def _compute_ptr(self, ptr, batch_size):
        return (ptr + batch_size) % self.queue_size

    def _transpose(self, x):
        return x.T

    def _assign_in_queue(self, x, ptr, batch_size):
        self.queue[:, ptr: ptr + batch_size] = x

    def training_step(self, batch):
        ...
        self._update_queue(k)
```
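`concat_all_gather_without_backprop` isn't shown above; in MoCo-style code it is typically a no-grad all-gather across DDP processes, along the lines of the following sketch (this is an assumption about the helper, not the actual implementation):

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def concat_all_gather_without_backprop(x: torch.Tensor) -> torch.Tensor:
    """Gather `x` from every process and concatenate along the batch dim.

    Sketch only: wrapped in no_grad, so no gradients flow through the gather.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return x
    gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x)
    return torch.cat(gathered, dim=0)
```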
Running the simple profiler shows that a large amount of time is spent in the helper functions that access the registered buffers. I tested with both the DDP and SingleDevice strategies; both resulted in the same kind of slowdown on a SLURM cluster environment.
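For reference, here is a minimal sketch of enabling Lightning's simple profiler (the Trainer arguments are placeholders, not the exact configuration used above):

```python
from pytorch_lightning import Trainer

# Sketch: profiler="simple" makes Lightning print the mean and total duration
# of each profiled hook/function once fit() finishes. The other arguments are
# placeholders, not the configuration used in this thread.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1,
    profiler="simple",
)

# model = MoCoModel(...)             # the LightningModule from the snippet above
# trainer.fit(model, datamodule=dm)  # profiler summary is printed after fit()
```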
-
Fixed it: Lightning is now as fast as my previous implementation. The problem was elsewhere, but I didn't detect it with the profiler because GPU computation is asynchronous and was not synchronized during profiling.
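For anyone profiling GPU code and hitting the same trap: CUDA kernels are launched asynchronously, so a host-side timer can attribute the wait to whichever line happens to synchronize next. A minimal sketch of timing with explicit synchronization (the `timed` helper is hypothetical):

```python
import time
import torch


def timed(fn, *args, **kwargs):
    # Hypothetical helper: synchronize before and after the call so the measured
    # wall time covers the GPU work launched by `fn`, not just the kernel launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.perf_counter() - start
```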
-
@juliendenize