diff --git a/torchstore/_async_utils.py b/torchstore/_async_utils.py
new file mode 100644
index 0000000..5388617
--- /dev/null
+++ b/torchstore/_async_utils.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import logging
+from typing import Any, Awaitable, Callable, cast, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class OnceCell(Generic[T]):
+    """An asyncio-friendly cell that can be initialized exactly once, loosely modeled on
+    tokio::sync::OnceCell. Not thread-safe: all access is expected to happen on a single event loop."""
+
+    def __init__(self):
+        self._lock = asyncio.Lock()
+        self._value: T | None = None
+        self._initialized = False
+
+    async def get_or_init(self, initializer: Callable[[], Awaitable[T]]) -> T:
+        if self._initialized:
+            return cast(T, self._value)
+
+        async with self._lock:
+            if not self._initialized:
+                self._value = await initializer()
+                self._initialized = True
+
+        return cast(T, self._value)
+
+    def get(self) -> T:
+        if not self._initialized:
+            raise ValueError("Value not initialized yet")
+        return cast(T, self._value)
+
+
+class SequentialExecutor:
+    """A simple executor that runs submitted calls sequentially on the current event loop.
+    This is mainly needed for RDMA operations, which will panic if concurrent requests are made.
+    """
+
+    def __init__(self):
+        self._queue = asyncio.Queue()
+        self._worker_task: asyncio.Task | None = None
+
+    async def start_worker(self):
+        self._worker_task = asyncio.create_task(self._worker())
+
+    async def _worker(self):
+        while True:
+            try:
+                func, args, kwargs, response = await self._queue.get()
+
+                if response.cancelled():
+                    continue  # Caller gave up
+
+                try:
+                    result = await func(*args, **kwargs)
+                    response.set_result(result)
+                except Exception as e:
+                    response.set_exception(e)
+
+            except Exception as outer_err:
+                # Keep the worker alive: record the failure and continue draining the queue.
+                logging.exception(f"[SequentialExecutor] Worker crashed: {outer_err}")
+
+    async def submit(self, func: Callable, *args, **kwargs) -> Any:
+        fut = asyncio.Future()
+        await self._queue.put((func, args, kwargs, fut))
+        return await fut
diff --git a/torchstore/api.py b/torchstore/api.py
index e38e7c6..9d10125 100644
--- a/torchstore/api.py
+++ b/torchstore/api.py
@@ -7,10 +7,10 @@
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-
 from monarch.actor import get_or_spawn_controller
 
 import torchstore.state_dict_utils
+from torchstore._async_utils import OnceCell, SequentialExecutor
 from torchstore.client import LocalClient
 from torchstore.controller import Controller
 from torchstore.storage_volume import StorageVolume
@@ -26,7 +26,7 @@
 DEFAULT_TORCHSTORE_NAME: str = "TorchStore"
 
 # cache for local clients
-_local_clent_map: Dict[str, LocalClient] = {}
+_local_client_map: Dict[str, OnceCell[LocalClient]] = {}
 
 
 async def initialize(
@@ -94,14 +94,14 @@ async def shutdown(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
     """
     controller = await _controller(store_name)
     await controller.teardown.call()
-    global _local_clent_map
-    _local_clent_map = {}
+    global _local_client_map
+    _local_client_map = {}
 
 
 def reset_client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
     """Reset the local client for a given store.
Useful for refreshing client state after shutdown.""" - global _local_clent_map - _local_clent_map.pop(store_name, None) + global _local_client_map + _local_client_map.pop(store_name, None) async def _controller(store_name: str = DEFAULT_TORCHSTORE_NAME) -> Controller: @@ -124,19 +124,23 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient: >>> store_client = await client() >>> await store_client.put("my_key", tensor) """ - if store_name in _local_clent_map: - return _local_clent_map[store_name] - - controller = await _controller(store_name) - controller_strategy = await controller.get_controller_strategy.call_one() - - local_client = LocalClient( - controller=controller, - strategy=controller_strategy, - ) - _local_clent_map[store_name] = local_client + if store_name not in _local_client_map: + _local_client_map[store_name] = OnceCell() + + async def initializer(): + controller = await _controller(store_name) + controller_strategy = await controller.get_controller_strategy.call_one() + + executor = SequentialExecutor() + await executor.start_worker() + local_client = LocalClient( + controller=controller, + strategy=controller_strategy, + rdma_executor=executor, + ) + return local_client - return local_client + return await _local_client_map[store_name].get_or_init(initializer) async def put( diff --git a/torchstore/client.py b/torchstore/client.py index 52a6c16..9326b25 100644 --- a/torchstore/client.py +++ b/torchstore/client.py @@ -11,6 +11,8 @@ import torch from torch.distributed.tensor import DTensor +from torchstore._async_utils import SequentialExecutor + from torchstore.controller import ObjectType from torchstore.logging import LatencyTracker from torchstore.transport import Pipe, Request, TensorSlice @@ -19,6 +21,19 @@ logger = getLogger(__name__) +def _limit_concurrency(method): + """ + Decorator to limit concurrency of async methods using the instance's semaphore. + Assumes the instance has a self._semaphore attribute (asyncio.Semaphore). + """ + + async def wrapper(self, *args, **kwargs): + async with self._semaphore: + return await method(self, *args, **kwargs) + + return wrapper + + class LocalClient: """This class represents the local store, which exists on every process. Remote storage is handled by the client. 
@@ -28,9 +43,14 @@ def __init__( self, controller, strategy, + *, + rdma_executor: SequentialExecutor | None = None, + max_concurrent_requests: int = 4, ): self._controller = controller self.strategy = strategy + self.rdma_executor = rdma_executor + self._semaphore = asyncio.Semaphore(max_concurrent_requests) async def _locate_volumes(self, key: str): """Helper method to call locate_volumes and convert any error to KeyError for missing keys.""" @@ -40,6 +60,7 @@ async def _locate_volumes(self, key: str): raise KeyError(str(e)) from e @torch.no_grad + @_limit_concurrency async def put(self, key: str, value: Union[torch.Tensor, Any]): latency_tracker = LatencyTracker(f"put:{key}") request = Request.from_any(value) @@ -51,7 +72,7 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]): pipe = Pipe(storage_volume) - await pipe.put_to_storage_volume(key, request) + await pipe.put_to_storage_volume(key, request, executor=self.rdma_executor) latency_tracker.track_step("put_to_storage_volume") await self._controller.notify_put.call(key, request.meta_only(), volume_id) @@ -59,6 +80,7 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]): latency_tracker.track_e2e() @torch.no_grad + @_limit_concurrency async def get( self, key: str, @@ -235,7 +257,9 @@ async def _get_object(self, key: str): storage_volume = self.strategy.get_storage_volume(volume_id) pipe = Pipe(storage_volume) request = Request.from_any(None) - return await pipe.get_from_storage_volume(key, request) + return await pipe.get_from_storage_volume( + key, request, executor=self.rdma_executor + ) async def _get_tensor(self, key: str) -> torch.Tensor: """Fetches the tensor which is stored in one volume storage""" @@ -248,7 +272,9 @@ async def _get_tensor(self, key: str) -> torch.Tensor: # TODO: consolidate the logic here - None indicates it is an object request, # which is sematically inappropriate here. 
request = Request.from_any(None) - return await pipe.get_from_storage_volume(key, request) + return await pipe.get_from_storage_volume( + key, request, executor=self.rdma_executor + ) async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor: """Fetches slices from all volume storages and stitch together to return the whole tensor""" @@ -267,7 +293,7 @@ async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor: tensor_slice_request = Request.from_tensor_slice(tensor_slice) local_tensor = await pipe.get_from_storage_volume( - key, tensor_slice_request + key, tensor_slice_request, executor=self.rdma_executor ) partial_results.append((local_tensor, tensor_slice)) diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index 355fea5..93552e7 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -11,8 +11,9 @@ import torch from monarch.actor import Actor, endpoint -from torchstore.transport.buffers import TransportBuffer +from torchstore._async_utils import OnceCell, SequentialExecutor +from torchstore.transport.buffers import TransportBuffer from torchstore.transport.pipe import Request, TensorSlice from torchstore.utils import assemble_global_tensor, spawn_actors @@ -33,6 +34,15 @@ def __init__( ) -> None: self.store: StorageImpl = InMemoryStore() self.volume_id: str = id_func() + self._executor = OnceCell[SequentialExecutor]() + + async def get_executor(self) -> SequentialExecutor: + async def initializer() -> SequentialExecutor: + executor = SequentialExecutor() + await executor.start_worker() + return executor + + return await self._executor.get_or_init(initializer=initializer) @classmethod async def spawn( @@ -56,13 +66,17 @@ async def get_id(self) -> str: async def put( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> None: - await self.store.put(key, transport_buffer, request) + await self.store.put( + key, transport_buffer, request, executor=await self.get_executor() + ) @endpoint async def get( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> TransportBuffer: - return await self.store.get(key, transport_buffer, request) + return await self.store.get( + key, transport_buffer, request, executor=await self.get_executor() + ) @endpoint async def get_meta( @@ -81,13 +95,23 @@ class StorageImpl: """Abstract base class for storage implementations.""" async def put( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> Optional[TransportBuffer]: """Store data in the storage backend.""" raise NotImplementedError() async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> TransportBuffer: """Retrieve data from the storage backend.""" raise NotImplementedError() @@ -112,7 +136,7 @@ def __init__(self) -> None: def _build_full_tensor(self, key: str) -> None: logger.debug(f"Building full tensor for {key}") # we can also consider in the future not requiring the full tensor to be - # assembled, and instead only that the requested offsets are available + # assembled, and instead only that the requested offs are available # this is a performance optimization, but could be tricky to implement. 
assert self._has_full_tensor(key) @@ -186,7 +210,12 @@ def _handle_dtensor( } async def put( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> None: if request.is_object: self.kv[key] = {"obj": request.objects} @@ -194,7 +223,7 @@ async def put( # since we pass tensor=None to the transport buffer, # we allocate on the fly - tensor = await transport_buffer.read_into(tensor=None) + tensor = await transport_buffer.read_into(tensor=None, executor=executor) if request.tensor_slice is not None: self._handle_dtensor(key, request.tensor_slice, tensor) return @@ -202,7 +231,12 @@ async def put( self.kv[key] = tensor async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> TransportBuffer: if key not in self.kv: @@ -216,7 +250,7 @@ async def get( return transport_buffer if request.tensor_slice is None: - await transport_buffer.write_from(self.kv[key]) + await transport_buffer.write_from(self.kv[key], executor=executor) return transport_buffer # TODO: diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index f29579a..3f0ec14 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -32,7 +32,7 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any: def rdma_available() -> bool: rdma_enabled = ( - os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1" + os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1" ) # TODO: enable on this build return rdma_enabled and monarch_rdma_available() @@ -55,10 +55,17 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: """ raise NotImplementedError() - async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor: + async def read_into( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> torch.Tensor: raise NotImplementedError() - async def write_from(self, tensor: Optional[torch.Tensor]) -> None: + async def write_from( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> None: + raise NotImplementedError() + + async def drop(self, *, executor=None) -> None: raise NotImplementedError() @@ -137,7 +144,10 @@ def update(self, other_buffer: "TransportBuffer") -> None: super().update(other_buffer) # send - async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + async def read_into( + self, tensor: Optional[torch.Tensor] = None, *, executor=None + ) -> torch.Tensor: + assert executor is not None, "RDMATransportBuffer requires an executor" if tensor is None: # allocate a tensor to return tensor = torch.empty(self.shape, dtype=self.dtype) @@ -159,7 +169,7 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor # TODO: gather instead of reading sequentially try: for idx, chunk in enumerate(chunked_byte_view): - await self.rdma_buffers[idx].read_into(chunk) + await executor.submit(self.rdma_buffers[idx].read_into, chunk) except Exception as e: logging.exception( f"Failed read_into, {tensor.shape=}, {tensor.dtype=}", exc_info=e @@ -169,7 +179,10 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor return tensor # recv - async def write_from(self, tensor: Optional[torch.Tensor]) -> None: + async def write_from( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> None: + assert executor is not None, "RDMATransportBuffer 
requires an executor" if tensor is None: return @@ -188,7 +201,14 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: # the rdma buffer # TODO: gather instead of reading sequentially for idx, chunk in enumerate(chunked_byte_view): - await self.rdma_buffers[idx].write_from(chunk) + await executor.submit(self.rdma_buffers[idx].write_from, chunk) + + async def drop(self, *, executor=None) -> None: + assert executor is not None, "RDMATransportBuffer requires an executor" + if self.rdma_buffers is None: + return + for buffer in self.rdma_buffers: + await executor.submit(buffer.drop) class MonarchTransportBuffer(TransportBuffer): @@ -206,7 +226,12 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: return None # send - async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + async def read_into( + self, tensor: Optional[torch.Tensor] = None, *, executor=None + ) -> torch.Tensor: + + _ = executor + if tensor is not None: # if there is a tensor here, likely this is the 'inplace' case, # and we should return back a ptr to the original tensor @@ -218,9 +243,19 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor return self.tensor # recv - async def write_from(self, tensor: Optional[torch.Tensor]) -> None: + async def write_from( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> None: + + _ = executor + self.tensor = tensor def update(self, other_buffer: "TransportBuffer") -> None: super().update(other_buffer) self.tensor = other_buffer.tensor + + async def drop(self, *, executor=None) -> None: + # no-op + _ = executor + pass diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index f0d94fd..234a3b1 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import asyncio import copy from dataclasses import dataclass from logging import getLogger @@ -137,17 +138,16 @@ def __init__(self, storage_volume) -> None: def create_transport_buffer(self) -> TransportBuffer: # TODO: eventually this should be dependent on the connections available to a storage_volume if rdma_available(): - buffer_cls = RDMATransportBuffer + return RDMATransportBuffer() else: - buffer_cls = MonarchTransportBuffer - return buffer_cls() + return MonarchTransportBuffer() - async def put_to_storage_volume(self, key, request: Request): + async def put_to_storage_volume(self, key, request: Request, *, executor=None): transport_buffer = self.create_transport_buffer() tensor = request.tensor_val - transport_buffer.allocate(tensor) - await transport_buffer.write_from(tensor) + await asyncio.to_thread(transport_buffer.allocate, tensor) + await transport_buffer.write_from(tensor, executor=executor) # transporting tensors is handled by the buffer, so we don't want to send it # via monarch RPC since that would generate considerable overhead @@ -155,7 +155,9 @@ async def put_to_storage_volume(self, key, request: Request): key, transport_buffer, request.meta_only() ) - async def get_from_storage_volume(self, key, request: Request): + await transport_buffer.drop(executor=executor) + + async def get_from_storage_volume(self, key, request: Request, *, executor=None): transport_buffer = self.create_transport_buffer() @@ -179,4 +181,6 @@ async def get_from_storage_volume(self, key, request: Request): if transport_buffer.is_object: return transport_buffer.objects - return await transport_buffer.read_into(request.tensor_val) + ret = await transport_buffer.read_into(request.tensor_val, executor=executor) + await transport_buffer.drop(executor=executor) + return ret
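
Illustrative usage (not part of the patch): a minimal sketch of how the two new helpers in torchstore/_async_utils.py are meant to compose, assuming this patch is applied. OnceCell.get_or_init memoizes an async initializer so concurrent callers share one instance (this is how api.client() now builds the LocalClient), and SequentialExecutor.submit funnels calls through a single worker task so they run one at a time and in order, which is what the RDMA buffer paths rely on.

# sketch_async_utils.py -- illustrative only, assumes the patched torchstore package
import asyncio

from torchstore._async_utils import OnceCell, SequentialExecutor


async def main() -> None:
    # OnceCell: the initializer runs exactly once, even with concurrent callers.
    cell: OnceCell[SequentialExecutor] = OnceCell()

    async def make_executor() -> SequentialExecutor:
        executor = SequentialExecutor()
        await executor.start_worker()
        return executor

    first, second = await asyncio.gather(
        cell.get_or_init(make_executor),
        cell.get_or_init(make_executor),
    )
    assert first is second  # the same executor instance is reused

    # SequentialExecutor: submitted calls execute one at a time, in FIFO order.
    order = []

    async def op(tag: str) -> str:
        order.append(tag)
        await asyncio.sleep(0)
        return tag

    results = await asyncio.gather(
        first.submit(op, "a"),
        first.submit(op, "b"),
        first.submit(op, "c"),
    )
    print(results, order)  # ['a', 'b', 'c'] ['a', 'b', 'c']


if __name__ == "__main__":
    asyncio.run(main())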
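
Also illustrative only: the semaphore pattern behind client.py's _limit_concurrency decorator, shown standalone with a hypothetical FakeClient so the bound on in-flight calls is easy to see. The real LocalClient applies the same decorator to put and get, with max_concurrent_requests defaulting to 4.

# sketch_limit_concurrency.py -- illustrative only; FakeClient is hypothetical
import asyncio


def _limit_concurrency(method):
    # Same shape as the decorator added in torchstore/client.py.
    async def wrapper(self, *args, **kwargs):
        async with self._semaphore:
            return await method(self, *args, **kwargs)

    return wrapper


class FakeClient:
    def __init__(self, max_concurrent_requests: int = 4) -> None:
        self._semaphore = asyncio.Semaphore(max_concurrent_requests)
        self.in_flight = 0
        self.peak = 0

    @_limit_concurrency
    async def put(self, key: str) -> None:
        self.in_flight += 1
        self.peak = max(self.peak, self.in_flight)
        await asyncio.sleep(0.01)  # stand-in for a storage-volume round trip
        self.in_flight -= 1


async def main() -> None:
    client = FakeClient(max_concurrent_requests=4)
    await asyncio.gather(*(client.put(f"key_{i}") for i in range(16)))
    print(client.peak)  # never exceeds 4


if __name__ == "__main__":
    asyncio.run(main())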