diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py
index 5c7150e..64b0f7b 100644
--- a/torchstore/storage_volume.py
+++ b/torchstore/storage_volume.py
@@ -11,8 +11,10 @@
 import torch
 from monarch.actor import Actor, endpoint
 
-from torchstore.transport.buffers import TransportBuffer
-
+from torchstore.transport.buffers import (
+    create_default_transport_buffer,
+    TransportBuffer,
+)
 from torchstore.transport.pipe import Request, TensorSlice
 from torchstore.utils import assemble_global_tensor, spawn_actors
 
@@ -59,10 +61,8 @@ async def put(
         await self.store.put(key, transport_buffer, request)
 
     @endpoint
-    async def get(
-        self, key: str, transport_buffer: TransportBuffer, request: Request
-    ) -> TransportBuffer:
-        return await self.store.get(key, transport_buffer, request)
+    async def get(self, key: str, request: Request) -> TransportBuffer:
+        return await self.store.get(key, request)
 
     @endpoint
     async def get_meta(
@@ -86,9 +86,7 @@ async def put(
         """Store data in the storage backend."""
         raise NotImplementedError()
 
-    async def get(
-        self, key: str, transport_buffer: TransportBuffer, request: Request
-    ) -> TransportBuffer:
+    async def get(self, key: str, request: Request) -> TransportBuffer:
         """Retrieve data from the storage backend."""
         raise NotImplementedError()
 
@@ -202,13 +200,13 @@ async def put(
 
         self.kv[key] = tensor
 
-    async def get(
-        self, key: str, transport_buffer: TransportBuffer, request: Request
-    ) -> TransportBuffer:
+    async def get(self, key: str, request: Request) -> TransportBuffer:
        if key not in self.kv:
            raise KeyError(f"Key '{key}' not found. {list(self.kv.keys())=}")
 
+        transport_buffer = create_default_transport_buffer()
+
         # TODO: clean up
         val = self.kv[key]
         if isinstance(val, dict) and "obj" in val:
@@ -217,7 +215,7 @@ async def get(
             return transport_buffer
 
         if request.tensor_slice is None:
-            await transport_buffer.write_from(self.kv[key])
+            transport_buffer.from_contiguous_tensor(self.kv[key])
             return transport_buffer
 
         # TODO:
@@ -228,7 +226,7 @@
 
         for shard in self.kv[key].values():
             if shard["slice"] == request.tensor_slice:
-                await transport_buffer.write_from(shard["tensor"])
+                transport_buffer.from_contiguous_tensor(shard["tensor"])
                 return transport_buffer
 
         raise RuntimeError(f"Tensor slice {request.tensor_slice} not found in {key}")
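Reviewer note (not part of the patch): the storage-volume side of this change inverts buffer ownership. StorageVolume.get() and the StorageBackend.get() interface no longer accept a caller-allocated TransportBuffer; the backend now allocates one itself via create_default_transport_buffer() and fills it with from_contiguous_tensor(). A minimal sketch of what a conforming backend get() looks like after this patch; MyBackend and its self.kv dict are illustrative, not from the patch:

    from torchstore.transport.buffers import (
        create_default_transport_buffer,
        TransportBuffer,
    )
    from torchstore.transport.pipe import Request

    class MyBackend:
        """Hypothetical backend; only the get() contract matters here."""

        def __init__(self) -> None:
            self.kv = {}  # key -> contiguous torch.Tensor

        async def get(self, key: str, request: Request) -> TransportBuffer:
            # The backend, not the caller, allocates the buffer: RDMA-backed
            # when rdma_available() is True, plain Monarch RPC otherwise.
            buffer = create_default_transport_buffer()
            # Register the stored tensor's memory with the buffer. Per the
            # docstring added below, the tensor must stay alive until the
            # client has finished reading from the buffer.
            buffer.from_contiguous_tensor(self.kv[key])
            return buffer
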
diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py
index debb585..3a897b3 100644
--- a/torchstore/transport/buffers.py
+++ b/torchstore/transport/buffers.py
@@ -4,6 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from __future__ import annotations
+
+import functools
+
 import logging
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -28,6 +32,7 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any:
     )
 
 
+@functools.cache
 def rdma_available() -> bool:
     rdma_enabled = (
         os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1"
@@ -35,6 +40,13 @@
     )
     return rdma_enabled and monarch_rdma_available()
 
+def create_default_transport_buffer() -> TransportBuffer:
+    if rdma_available():
+        return RDMATransportBuffer()
+    else:
+        return MonarchTransportBuffer()
+
+
 class TransportBuffer:
     finalize: bool = False
     is_object: bool = False
@@ -47,10 +59,7 @@ def update(self, other_buffer: "TransportBuffer") -> None:
         self.objects = other_buffer.objects
         self.requires_meta = other_buffer.requires_meta
 
-    def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
-        """Allocates internal buffers based on either an existing tensor
-        or a Tuple of (shape, dtype)
-        """
+    def from_contiguous_tensor(self, tensor: torch.Tensor) -> None:
        raise NotImplementedError()
 
     async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor:
@@ -59,6 +68,9 @@
         raise NotImplementedError()
 
     async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
         raise NotImplementedError()
 
+    async def drop(self) -> None:
+        pass
+
 
 class RDMATransportBuffer(TransportBuffer):
     # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor
@@ -173,6 +185,12 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor
 
         return tensor
 
+    async def drop(self) -> None:
+        if self.rdma_buffers is not None:
+            for buffer in self.rdma_buffers:
+                await buffer.drop()
+        self.tensor_refs = None
+
     # recv
     async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
         if tensor is None:
@@ -195,6 +213,16 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
         for idx, chunk in enumerate(chunked_byte_view):
             await self.rdma_buffers[idx].write_from(chunk)
 
+    def from_contiguous_tensor(self, tensor: torch.Tensor) -> None:
+        """The caller must ensure that the tensor lives long enough until the buffer is used."""
+        assert tensor.is_contiguous(), "Tensor must be contiguous"
+        self.shape = tensor.shape
+        self.dtype = tensor.dtype
+        self.dim = tensor.dim()
+        self.rdma_buffers = [
+            RDMABuffer(chunk) for chunk in self._create_byte_views_from_tensor(tensor)
+        ]
+
 
 class MonarchTransportBuffer(TransportBuffer):
     """This interface is mostly a noop, intended to be used with Monarch's regular RPC.
@@ -229,3 +257,6 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
     def update(self, other_buffer: "TransportBuffer") -> None:
         super().update(other_buffer)
         self.tensor = other_buffer.tensor
+
+    def from_contiguous_tensor(self, tensor: torch.Tensor) -> None:
+        self.tensor = tensor
diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index f0d94fd..9c77133 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -145,9 +145,12 @@ def create_transport_buffer(self) -> TransportBuffer:
     async def put_to_storage_volume(self, key, request: Request):
         transport_buffer = self.create_transport_buffer()
         tensor = request.tensor_val
-
-        transport_buffer.allocate(tensor)
-        await transport_buffer.write_from(tensor)
+        if tensor is not None:
+            # TODO: investigate why RDMA fails on CUDA tensors
+            tensor = tensor.cpu()
+            if not tensor.is_contiguous():
+                tensor = tensor.contiguous()
+            transport_buffer.from_contiguous_tensor(tensor)
 
         # transporting tensors is handled by the buffer, so we don't want to send it
         # via monarch RPC since that would generate considerable overhead
@@ -155,28 +158,16 @@
             key, transport_buffer, request.meta_only()
         )
 
-    async def get_from_storage_volume(self, key, request: Request):
-
-        transport_buffer = self.create_transport_buffer()
-
-        # Certain buffers (RDMA) need to know the size of the tensor
-        # so we can allocate the right amount of memory locally.
-        # This can be avoided if the request contains a tensor slice.
-        # Could likely be optimized away in the future.
-        if transport_buffer.requires_meta and request.tensor_val is None:
-            meta = await self.storage_volume.get_meta.call_one(key, request.meta_only())
-            transport_buffer.allocate(meta)
-        else:
-            transport_buffer.allocate(request.tensor_val)
+        await transport_buffer.drop()
 
-        # TODO: consider placing the buffer inside the request or vice versa
-        transport_buffer.update(
-            await self.storage_volume.get.call_one(
-                key, transport_buffer, request.meta_only()
-            )
+    async def get_from_storage_volume(self, key, request: Request):
+        transport_buffer = await self.storage_volume.get.call_one(
+            key, request.meta_only()
         )
 
         if transport_buffer.is_object:
             return transport_buffer.objects
 
-        return await transport_buffer.read_into(request.tensor_val)
+        ret = await transport_buffer.read_into(request.tensor_val)
+        await transport_buffer.drop()
+        return ret
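Reviewer note (not part of the patch): on the pipe side, get is now a single round trip. The client no longer pre-allocates a buffer or fetches tensor metadata via get_meta first; on put it coerces tensors to CPU and contiguous layout before registering them; and it receives a ready-to-read TransportBuffer that it must release with drop(), which releases any RDMA buffer registrations held by the transport buffer (and is a no-op for the plain Monarch RPC buffer). A sketch of the resulting calling convention, assuming a storage_volume actor handle as in pipe.py; the fetch helper is illustrative, not part of the patch:

    import torch

    from torchstore.transport.pipe import Request

    async def fetch(storage_volume, key: str, request: Request) -> torch.Tensor:
        # One RPC replaces the old get_meta + allocate + get sequence.
        buffer = await storage_volume.get.call_one(key, request.meta_only())
        if buffer.is_object:
            return buffer.objects
        tensor = await buffer.read_into(request.tensor_val)
        # Always drop() after reading: on the RDMA path this releases the
        # registered buffers; it is a harmless no-op on the RPC path.
        await buffer.drop()
        return tensor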