torchstore/transport/buffers.py (29 additions, 0 deletions)

@@ -59,6 +59,10 @@ async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor:
     async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
         raise NotImplementedError()
 
+    async def drop(self) -> None:
+        """Clean up any resources held by this buffer. Override in subclasses if needed."""
+        pass
+
 
 class RDMATransportBuffer(TransportBuffer):
     # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor
@@ -72,6 +76,31 @@ def __init__(self) -> None:
         self.shape: Optional[torch.Size] = None
         self.dtype: Optional[torch.dtype] = None
 
+    async def drop(self) -> None:
+        """Explicitly clean up RDMA buffers to prevent a kernel memory leak.
+
+        When RDMA buffers are created, they register memory regions with the RDMA
+        hardware, which pins pages in kernel memory. Without explicit cleanup, these
+        pages remain pinned even after the Python objects are garbage collected,
+        leading to a memory leak that manifests as unbounded Inactive(anon) growth.
+        """
+        if self.rdma_buffers is not None:
+            for rdma_buf in self.rdma_buffers:
+                try:
+                    # Drop the RDMA buffer to deregister the memory region
+                    await rdma_buf.drop()
+                except Exception as e:
+                    # Log but don't raise - cleanup should be best-effort
+                    logging.warning(f"Failed to drop RDMA buffer during cleanup: {e}")
+        self.rdma_buffers = None
+        self.tensor_refs = None
+
+    def __del__(self) -> None:
+        """Destructor; deliberately does not clean up RDMA buffers."""
+        # Note: not calling drop() here to avoid issues with destructor timing
+        # and to keep cleanup explicit, only where we control the lifecycle
+        pass
+
     def __getstate__(self) -> Dict[str, Any]:
         # Any time that we serialize the transport buffer, the idea is
         # that tensors will be transported via tensor_enginer.RDMABuffer, so it makes
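For context, the pattern the new drop() hook enables looks roughly like the sketch below. This is an illustration only: ExampleBuffer and its fake handles are hypothetical stand-ins, not torchstore or monarch APIs, and the real deregistration call is replaced with a no-op.

# Minimal, self-contained sketch of the explicit-cleanup pattern that the new
# drop() hook enables. ExampleBuffer and its "handles" are hypothetical; they
# are not part of torchstore or monarch.
import asyncio
import logging


class ExampleBuffer:
    def __init__(self) -> None:
        # Stand-in for registering RDMA memory regions (which pins pages).
        self.handles = ["mr-0", "mr-1"]

    async def drop(self) -> None:
        # Release every handle; log failures instead of raising so cleanup
        # stays best-effort, mirroring RDMATransportBuffer.drop().
        for handle in self.handles or []:
            try:
                await asyncio.sleep(0)  # placeholder for real deregistration
            except Exception as e:
                logging.warning(f"Failed to drop {handle}: {e}")
        self.handles = None


async def main() -> None:
    buf = ExampleBuffer()
    try:
        pass  # ... use the buffer for a put/get ...
    finally:
        # Always release pinned memory, even if the transfer failed.
        await buf.drop()


asyncio.run(main())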
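The no-op __del__ above reflects a general constraint: a destructor is synchronous and may run during interpreter shutdown, so an async drop() cannot reliably be awaited from it. A tiny illustration of why (hypothetical class, not torchstore code):

class LeakyBuffer:
    async def drop(self) -> None:
        print("deregistering memory region")  # never reached from __del__

    def __del__(self) -> None:
        # A destructor cannot await; at best it can create the coroutine,
        # which we close here to avoid a "never awaited" warning.
        self.drop().close()


buf = LeakyBuffer()
del buf  # prints nothing: cleanup has to be requested explicitly via `await buf.drop()`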
torchstore/transport/pipe.py (33 additions, 21 deletions)

@@ -151,32 +151,44 @@ async def put_to_storage_volume(self, key, request: Request):
 
         # transporting tensors is handled by the buffer, so we don't want to send it
         # via monarch RPC since that would generate considerable overhead
-        await self.storage_volume.put.call_one(
-            key, transport_buffer, request.meta_only()
-        )
+        try:
+            await self.storage_volume.put.call_one(
+                key, transport_buffer, request.meta_only()
+            )
+        finally:
+            # Clean up the transport buffer after the put operation completes
+            # This is critical for RDMA buffers to deregister memory regions
+            await transport_buffer.drop()
 
     async def get_from_storage_volume(self, key, request: Request):
 
         transport_buffer = self.create_transport_buffer()
 
-        # Certain buffers (RDMA) need to know the size of the tensor
-        # so we can allocate the right amount of memory locally.
-        # This can be avoided if the request contains a tensor slice.
-        # Could likely be optimized away in the future.
-        if transport_buffer.requires_meta and request.tensor_val is None:
-            meta = await self.storage_volume.get_meta.call_one(key, request.meta_only())
-            transport_buffer.allocate(meta)
-        else:
-            transport_buffer.allocate(request.tensor_val)
+        try:
+            # Certain buffers (RDMA) need to know the size of the tensor
+            # so we can allocate the right amount of memory locally.
+            # This can be avoided if the request contains a tensor slice.
+            # Could likely be optimized away in the future.
+            if transport_buffer.requires_meta and request.tensor_val is None:
+                meta = await self.storage_volume.get_meta.call_one(
+                    key, request.meta_only()
+                )
+                transport_buffer.allocate(meta)
+            else:
+                transport_buffer.allocate(request.tensor_val)
 
-        # TODO: consider placing the buffer inside the request or vice versa
-        transport_buffer.update(
-            await self.storage_volume.get.call_one(
-                key, transport_buffer, request.meta_only()
-            )
-        )
+            # TODO: consider placing the buffer inside the request or vice versa
+            transport_buffer.update(
+                await self.storage_volume.get.call_one(
+                    key, transport_buffer, request.meta_only()
+                )
+            )
 
-        if transport_buffer.is_object:
-            return transport_buffer.objects
+            if transport_buffer.is_object:
+                return transport_buffer.objects
 
-        return await transport_buffer.read_into(request.tensor_val)
+            return await transport_buffer.read_into(request.tensor_val)
+        finally:
+            # Clean up the transport buffer after the get operation completes
+            # This is critical for RDMA buffers to deregister memory regions
+            await transport_buffer.drop()
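As a rough way to observe the effect this change targets, one could watch Inactive(anon) in /proc/meminfo while driving repeated transfers; the leak described in the buffers.py docstring shows up as that counter growing without bound. This is only a sketch: it assumes Linux, and run_one_transfer() is a hypothetical placeholder for a real put/get through torchstore.

# Hedged verification sketch: sample Inactive(anon) from /proc/meminfo (Linux
# only) around repeated transfers. run_one_transfer() is hypothetical.
import re


def inactive_anon_kb() -> int:
    with open("/proc/meminfo") as f:
        for line in f:
            m = re.match(r"Inactive\(anon\):\s+(\d+) kB", line)
            if m:
                return int(m.group(1))
    raise RuntimeError("Inactive(anon) not found in /proc/meminfo")


before = inactive_anon_kb()
# for _ in range(1000):
#     run_one_transfer()  # hypothetical: repeated puts/gets through the pipe
after = inactive_anon_kb()
print(f"Inactive(anon) grew by {after - before} kB")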