diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index debb585..32dbb52 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -59,6 +59,10 @@ async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor: async def write_from(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError() + async def drop(self) -> None: + """Clean up any resources held by this buffer. Override in subclasses if needed.""" + pass + class RDMATransportBuffer(TransportBuffer): # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor @@ -72,6 +76,31 @@ def __init__(self) -> None: self.shape: Optional[torch.Size] = None self.dtype: Optional[torch.dtype] = None + async def drop(self) -> None: + """Explicitly clean up RDMA buffers to prevent kernel memory leak. + + When RDMA buffers are created, they register memory regions with the RDMA + hardware which pins pages in kernel memory. Without explicit cleanup, these + pages remain pinned even after the Python objects are garbage collected, + leading to a memory leak that manifests as unbounded Inactive(anon) growth. 
+        """ +        if self.rdma_buffers is not None: +            for rdma_buf in self.rdma_buffers: +                try: +                    # Drop the RDMA buffer to deregister the memory region +                    await rdma_buf.drop() +                except Exception as e: +                    # Log but don't raise - cleanup should be best-effort +                    logging.warning(f"Failed to drop RDMA buffer during cleanup: {e}") +            self.rdma_buffers = None +            self.tensor_refs = None + +    def __del__(self) -> None: +        """No-op destructor; RDMA buffers must be released explicitly via drop().""" +        # Note: Not calling drop() here to avoid issues with destructor timing +        # and to make cleanup explicit only where we control the lifecycle +        pass +      def __getstate__(self) -> Dict[str, Any]:          # Any time that we serialize the transport buffer, the idea is          # that tensors will be transported via tensor_enginer.RDMABuffer, so it makes diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index f0d94fd..d869941 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -151,32 +151,44 @@ async def put_to_storage_volume(self, key, request: Request):          # transporting tensors is handled by the buffer, so we don't want to send it          # via monarch RPC since that would generate considerable overhead -        await self.storage_volume.put.call_one( -            key, transport_buffer, request.meta_only() -        ) +        try: +            await self.storage_volume.put.call_one( +                key, transport_buffer, request.meta_only() +            ) +        finally: +            # Clean up the transport buffer after the put operation completes +            # This is critical for RDMA buffers to deregister memory regions +            await transport_buffer.drop()      async def get_from_storage_volume(self, key, request: Request):          transport_buffer = self.create_transport_buffer() -        # Certain buffers (RDMA) need to know the size of the tensor -        # so we can allocate the right amount of memory locally. -        # This can be avoided if the request contains a tensor slice. -        # Could likely be optimized away in the future. 
- if transport_buffer.requires_meta and request.tensor_val is None: - meta = await self.storage_volume.get_meta.call_one(key, request.meta_only()) - transport_buffer.allocate(meta) - else: - transport_buffer.allocate(request.tensor_val) - - # TODO: consider placing the buffer inside the request or vice versa - transport_buffer.update( - await self.storage_volume.get.call_one( - key, transport_buffer, request.meta_only() + try: + # Certain buffers (RDMA) need to know the size of the tensor + # so we can allocate the right amount of memory locally. + # This can be avoided if the request contains a tensor slice. + # Could likely be optimized away in the future. + if transport_buffer.requires_meta and request.tensor_val is None: + meta = await self.storage_volume.get_meta.call_one( + key, request.meta_only() + ) + transport_buffer.allocate(meta) + else: + transport_buffer.allocate(request.tensor_val) + + # TODO: consider placing the buffer inside the request or vice versa + transport_buffer.update( + await self.storage_volume.get.call_one( + key, transport_buffer, request.meta_only() + ) ) - ) - if transport_buffer.is_object: - return transport_buffer.objects + if transport_buffer.is_object: + return transport_buffer.objects - return await transport_buffer.read_into(request.tensor_val) + return await transport_buffer.read_into(request.tensor_val) + finally: + # Clean up the transport buffer after the get operation completes + # This is critical for RDMA buffers to deregister memory regions + await transport_buffer.drop()