From 99016f665324244dd00e6e4aa04aa71ce7f24c6d Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Fri, 10 Oct 2025 07:06:03 +0000
Subject: [PATCH 01/13] enable rdma by default

---
 torchstore/transport/buffers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py
index f29579a..4c61bd4 100644
--- a/torchstore/transport/buffers.py
+++ b/torchstore/transport/buffers.py
@@ -32,7 +32,7 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any:
 
 def rdma_available() -> bool:
     rdma_enabled = (
-        os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1"
+        os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1"
    )  # TODO: enable on this build
     return rdma_enabled and monarch_rdma_available()
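Note on the default flip above: RDMA is now opt-out rather than opt-in. An unset
TORCHSTORE_RDMA_ENABLED counts as enabled, and the final decision still defers to
monarch_rdma_available(). A minimal sketch of the resulting gate (rdma_requested is
an illustrative stand-in, not a function in the tree):

    import os

    def rdma_requested() -> bool:
        # Mirrors the patched default: unset or "1" enables RDMA, "0" opts out.
        return os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1"

    assert rdma_requested()                      # env var unset -> enabled (new default)
    os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"  # explicit opt-out is still honored
    assert not rdma_requested()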
+ """ + + def __init__(self): + self._queue = asyncio.Queue() + self._worker_task = None + + async def start_worker(self): + self._worker_task = asyncio.create_task(self._worker()) + + async def _worker(self): + while True: + try: + func, args, kwargs, response = await self._queue.get() + + if response.cancelled(): + continue # Caller gave up + + try: + result = await func(*args, **kwargs) + response.set_result(result) + except Exception as e: + response.set_exception(e) + + except Exception as outer_err: + # Log or handle the error + print(f"[SequentialExecutor] Worker crashed: {outer_err}") + + await asyncio.sleep(0.001) # 1ms + + async def submit(self, func: Callable, *args, **kwargs) -> asyncio.Future: + fut = asyncio.Future() + await self._queue.put((func, args, kwargs, fut)) + return await fut diff --git a/torchstore/api.py b/torchstore/api.py index e38e7c6..9d10125 100644 --- a/torchstore/api.py +++ b/torchstore/api.py @@ -7,10 +7,10 @@ from typing import Any, Dict, List, Optional, Union import torch - from monarch.actor import get_or_spawn_controller import torchstore.state_dict_utils +from torchstore._async_utils import OnceCell, SequentialExecutor from torchstore.client import LocalClient from torchstore.controller import Controller from torchstore.storage_volume import StorageVolume @@ -26,7 +26,7 @@ DEFAULT_TORCHSTORE_NAME: str = "TorchStore" # cache for local clients -_local_clent_map: Dict[str, LocalClient] = {} +_local_client_map: Dict[str, OnceCell[LocalClient]] = {} async def initialize( @@ -94,14 +94,14 @@ async def shutdown(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None: """ controller = await _controller(store_name) await controller.teardown.call() - global _local_clent_map - _local_clent_map = {} + global _local_client_map + _local_client_map = {} def reset_client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None: """Reset the local client for a given store. 
Useful for refreshing client state after shutdown.""" - global _local_clent_map - _local_clent_map.pop(store_name, None) + global _local_client_map + _local_client_map.pop(store_name, None) async def _controller(store_name: str = DEFAULT_TORCHSTORE_NAME) -> Controller: @@ -124,19 +124,23 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient: >>> store_client = await client() >>> await store_client.put("my_key", tensor) """ - if store_name in _local_clent_map: - return _local_clent_map[store_name] - - controller = await _controller(store_name) - controller_strategy = await controller.get_controller_strategy.call_one() - - local_client = LocalClient( - controller=controller, - strategy=controller_strategy, - ) - _local_clent_map[store_name] = local_client + if store_name not in _local_client_map: + _local_client_map[store_name] = OnceCell() + + async def initializer(): + controller = await _controller(store_name) + controller_strategy = await controller.get_controller_strategy.call_one() + + executor = SequentialExecutor() + await executor.start_worker() + local_client = LocalClient( + controller=controller, + strategy=controller_strategy, + rdma_executor=executor, + ) + return local_client - return local_client + return await _local_client_map[store_name].get_or_init(initializer) async def put( diff --git a/torchstore/client.py b/torchstore/client.py index 52a6c16..52692fc 100644 --- a/torchstore/client.py +++ b/torchstore/client.py @@ -11,6 +11,8 @@ import torch from torch.distributed.tensor import DTensor +from torchstore._async_utils import SequentialExecutor + from torchstore.controller import ObjectType from torchstore.logging import LatencyTracker from torchstore.transport import Pipe, Request, TensorSlice @@ -19,6 +21,19 @@ logger = getLogger(__name__) +def _limit_concurrency(method): + """ + Decorator to limit concurrency of async methods using the instance's semaphore. + Assumes the instance has a self._semaphore attribute (asyncio.Semaphore). + """ + + async def wrapper(self, *args, **kwargs): + async with self._semaphore: + return await method(self, *args, **kwargs) + + return wrapper + + class LocalClient: """This class represents the local store, which exists on every process. Remote storage is handled by the client. 
@@ -28,9 +43,14 @@ def __init__( self, controller, strategy, + *, + rdma_executor: SequentialExecutor | None = None, + max_concurrent_requests: int = 32, ): self._controller = controller self.strategy = strategy + self.rdma_executor = rdma_executor + self._semaphore = asyncio.Semaphore(max_concurrent_requests) async def _locate_volumes(self, key: str): """Helper method to call locate_volumes and convert any error to KeyError for missing keys.""" @@ -40,6 +60,7 @@ async def _locate_volumes(self, key: str): raise KeyError(str(e)) from e @torch.no_grad + @_limit_concurrency async def put(self, key: str, value: Union[torch.Tensor, Any]): latency_tracker = LatencyTracker(f"put:{key}") request = Request.from_any(value) @@ -49,9 +70,9 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]): # TorchstoreStrategy defined during intiailization storage_volume, volume_id = self.strategy.select_storage_volume() - pipe = Pipe(storage_volume) + pipe = Pipe(storage_volume, executor=self.rdma_executor) - await pipe.put_to_storage_volume(key, request) + await pipe.put_to_storage_volume(key, request, executor=self.rdma_executor) latency_tracker.track_step("put_to_storage_volume") await self._controller.notify_put.call(key, request.meta_only(), volume_id) @@ -59,6 +80,7 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]): latency_tracker.track_e2e() @torch.no_grad + @_limit_concurrency async def get( self, key: str, @@ -233,9 +255,11 @@ async def _get_object(self, key: str): volume_map = await self._locate_volumes(key) volume_id, _ = volume_map.popitem() storage_volume = self.strategy.get_storage_volume(volume_id) - pipe = Pipe(storage_volume) + pipe = Pipe(storage_volume, executor=self.rdma_executor) request = Request.from_any(None) - return await pipe.get_from_storage_volume(key, request) + return await pipe.get_from_storage_volume( + key, request, executor=self.rdma_executor + ) async def _get_tensor(self, key: str) -> torch.Tensor: """Fetches the tensor which is stored in one volume storage""" @@ -244,11 +268,13 @@ async def _get_tensor(self, key: str) -> torch.Tensor: # if the storage is a Tensor instead of DTensor, just fetch and return it. for volume_id, _ in volume_map.items(): storage_volume = self.strategy.get_storage_volume(volume_id) - pipe = Pipe(storage_volume) + pipe = Pipe(storage_volume, executor=self.rdma_executor) # TODO: consolidate the logic here - None indicates it is an object request, # which is sematically inappropriate here. 
request = Request.from_any(None) - return await pipe.get_from_storage_volume(key, request) + return await pipe.get_from_storage_volume( + key, request, executor=self.rdma_executor + ) async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor: """Fetches slices from all volume storages and stitch together to return the whole tensor""" @@ -258,7 +284,7 @@ async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor: partial_results = [] for volume_id, storage_info in volume_map.items(): storage_volume = self.strategy.get_storage_volume(volume_id) - pipe = Pipe(storage_volume) + pipe = Pipe(storage_volume, executor=self.rdma_executor) # fetch from all storage volumes, something like this # TODO: fix so we can request all tensor slices from a storage volume @@ -267,7 +293,7 @@ async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor: tensor_slice_request = Request.from_tensor_slice(tensor_slice) local_tensor = await pipe.get_from_storage_volume( - key, tensor_slice_request + key, tensor_slice_request, executor=self.rdma_executor ) partial_results.append((local_tensor, tensor_slice)) diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index 355fea5..93552e7 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -11,8 +11,9 @@ import torch from monarch.actor import Actor, endpoint -from torchstore.transport.buffers import TransportBuffer +from torchstore._async_utils import OnceCell, SequentialExecutor +from torchstore.transport.buffers import TransportBuffer from torchstore.transport.pipe import Request, TensorSlice from torchstore.utils import assemble_global_tensor, spawn_actors @@ -33,6 +34,15 @@ def __init__( ) -> None: self.store: StorageImpl = InMemoryStore() self.volume_id: str = id_func() + self._executor = OnceCell[SequentialExecutor]() + + async def get_executor(self) -> SequentialExecutor: + async def initializer() -> SequentialExecutor: + executor = SequentialExecutor() + await executor.start_worker() + return executor + + return await self._executor.get_or_init(initializer=initializer) @classmethod async def spawn( @@ -56,13 +66,17 @@ async def get_id(self) -> str: async def put( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> None: - await self.store.put(key, transport_buffer, request) + await self.store.put( + key, transport_buffer, request, executor=await self.get_executor() + ) @endpoint async def get( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> TransportBuffer: - return await self.store.get(key, transport_buffer, request) + return await self.store.get( + key, transport_buffer, request, executor=await self.get_executor() + ) @endpoint async def get_meta( @@ -81,13 +95,23 @@ class StorageImpl: """Abstract base class for storage implementations.""" async def put( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> Optional[TransportBuffer]: """Store data in the storage backend.""" raise NotImplementedError() async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> TransportBuffer: """Retrieve data from the storage backend.""" raise NotImplementedError() @@ -112,7 +136,7 @@ def __init__(self) -> None: def _build_full_tensor(self, key: str) -> None: logger.debug(f"Building full tensor for 
{key}") # we can also consider in the future not requiring the full tensor to be - # assembled, and instead only that the requested offsets are available + # assembled, and instead only that the requested offs are available # this is a performance optimization, but could be tricky to implement. assert self._has_full_tensor(key) @@ -186,7 +210,12 @@ def _handle_dtensor( } async def put( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> None: if request.is_object: self.kv[key] = {"obj": request.objects} @@ -194,7 +223,7 @@ async def put( # since we pass tensor=None to the transport buffer, # we allocate on the fly - tensor = await transport_buffer.read_into(tensor=None) + tensor = await transport_buffer.read_into(tensor=None, executor=executor) if request.tensor_slice is not None: self._handle_dtensor(key, request.tensor_slice, tensor) return @@ -202,7 +231,12 @@ async def put( self.kv[key] = tensor async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request + self, + key: str, + transport_buffer: TransportBuffer, + request: Request, + *, + executor=None, ) -> TransportBuffer: if key not in self.kv: @@ -216,7 +250,7 @@ async def get( return transport_buffer if request.tensor_slice is None: - await transport_buffer.write_from(self.kv[key]) + await transport_buffer.write_from(self.kv[key], executor=executor) return transport_buffer # TODO: diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 4c61bd4..78dc41d 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -55,10 +55,14 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: """ raise NotImplementedError() - async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor: + async def read_into( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> torch.Tensor: raise NotImplementedError() - async def write_from(self, tensor: Optional[torch.Tensor]) -> None: + async def write_from( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> None: raise NotImplementedError() @@ -137,7 +141,10 @@ def update(self, other_buffer: "TransportBuffer") -> None: super().update(other_buffer) # send - async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + async def read_into( + self, tensor: Optional[torch.Tensor] = None, *, executor=None + ) -> torch.Tensor: + assert executor is not None, "RDMATransportBuffer requires an executor" if tensor is None: # allocate a tensor to return tensor = torch.empty(self.shape, dtype=self.dtype) @@ -159,7 +166,7 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor # TODO: gather instead of reading sequentially try: for idx, chunk in enumerate(chunked_byte_view): - await self.rdma_buffers[idx].read_into(chunk) + await executor.submit(self.rdma_buffers[idx].read_into, chunk) except Exception as e: logging.exception( f"Failed read_into, {tensor.shape=}, {tensor.dtype=}", exc_info=e @@ -169,7 +176,10 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor return tensor # recv - async def write_from(self, tensor: Optional[torch.Tensor]) -> None: + async def write_from( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> None: + assert executor is not None, "RDMATransportBuffer requires an executor" if tensor is None: return @@ -188,7 +198,7 @@ async def write_from(self, 
tensor: Optional[torch.Tensor]) -> None: # the rdma buffer # TODO: gather instead of reading sequentially for idx, chunk in enumerate(chunked_byte_view): - await self.rdma_buffers[idx].write_from(chunk) + await executor.submit(self.rdma_buffers[idx].write_from, chunk) class MonarchTransportBuffer(TransportBuffer): diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index f0d94fd..a700af7 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -131,23 +131,23 @@ class Pipe: Transport wrapper for communicating from local clients to storage volumes. """ - def __init__(self, storage_volume) -> None: + def __init__(self, storage_volume, *, executor=None) -> None: self.storage_volume = storage_volume + self._executor = executor def create_transport_buffer(self) -> TransportBuffer: # TODO: eventually this should be dependent on the connections available to a storage_volume if rdma_available(): - buffer_cls = RDMATransportBuffer + return RDMATransportBuffer() else: - buffer_cls = MonarchTransportBuffer - return buffer_cls() + return MonarchTransportBuffer() - async def put_to_storage_volume(self, key, request: Request): + async def put_to_storage_volume(self, key, request: Request, *, executor=None): transport_buffer = self.create_transport_buffer() tensor = request.tensor_val transport_buffer.allocate(tensor) - await transport_buffer.write_from(tensor) + await transport_buffer.write_from(tensor, executor=executor) # transporting tensors is handled by the buffer, so we don't want to send it # via monarch RPC since that would generate considerable overhead @@ -155,7 +155,7 @@ async def put_to_storage_volume(self, key, request: Request): key, transport_buffer, request.meta_only() ) - async def get_from_storage_volume(self, key, request: Request): + async def get_from_storage_volume(self, key, request: Request, *, executor=None): transport_buffer = self.create_transport_buffer() @@ -179,4 +179,4 @@ async def get_from_storage_volume(self, key, request: Request): if transport_buffer.is_object: return transport_buffer.objects - return await transport_buffer.read_into(request.tensor_val) + return await transport_buffer.read_into(request.tensor_val, executor=executor) From b35981de571cdafc0b24b5578e84fd8c07a0057a Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 10 Oct 2025 22:34:56 -0700 Subject: [PATCH 03/13] fix tcp buffer arguments --- torchstore/transport/buffers.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 78dc41d..8576751 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -216,7 +216,12 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: return None # send - async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + async def read_into( + self, tensor: Optional[torch.Tensor] = None, *, executor=None + ) -> torch.Tensor: + + _ = executor + if tensor is not None: # if there is a tensor here, likely this is the 'inplace' case, # and we should return back a ptr to the original tensor @@ -228,7 +233,12 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor return self.tensor # recv - async def write_from(self, tensor: Optional[torch.Tensor]) -> None: + async def write_from( + self, tensor: Optional[torch.Tensor], *, executor=None + ) -> None: + + _ = executor + self.tensor = tensor def update(self, other_buffer: "TransportBuffer") -> None: 
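The two primitives introduced in PATCH 02 carry the rest of the series: OnceCell makes
client construction race-free, and SequentialExecutor funnels every RDMA buffer operation
through a single worker task so requests never overlap. A minimal usage sketch under
those semantics (expensive_init and transfer are illustrative stand-ins):

    import asyncio

    from torchstore._async_utils import OnceCell, SequentialExecutor

    async def main() -> None:
        # OnceCell: many callers may race get_or_init, but the initializer
        # runs exactly once; everyone else gets the cached value.
        cell: OnceCell[int] = OnceCell()

        async def expensive_init() -> int:
            await asyncio.sleep(0.1)
            return 42

        values = await asyncio.gather(*(cell.get_or_init(expensive_init) for _ in range(8)))
        assert values == [42] * 8

        # SequentialExecutor: submit() awaits a future resolved by the lone
        # worker task, so submitted coroutines run strictly one at a time.
        executor = SequentialExecutor()
        await executor.start_worker()

        async def transfer(i: int) -> int:
            return i * 2

        results = await asyncio.gather(*(executor.submit(transfer, i) for i in range(4)))
        assert results == [0, 2, 4, 6]

    asyncio.run(main())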
From 9b5fb58499586feac96ce9a7eb83c8329821b2b8 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 00:12:49 -0700
Subject: [PATCH 04/13] clean up

---
 torchstore/client.py         | 8 ++++----
 torchstore/transport/pipe.py | 3 +--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/torchstore/client.py b/torchstore/client.py
index 52692fc..a8143a2 100644
--- a/torchstore/client.py
+++ b/torchstore/client.py
@@ -70,7 +70,7 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):
 
         # TorchstoreStrategy defined during intiailization
         storage_volume, volume_id = self.strategy.select_storage_volume()
-        pipe = Pipe(storage_volume, executor=self.rdma_executor)
+        pipe = Pipe(storage_volume)
 
         await pipe.put_to_storage_volume(key, request, executor=self.rdma_executor)
         latency_tracker.track_step("put_to_storage_volume")
@@ -255,7 +255,7 @@ async def _get_object(self, key: str):
         volume_map = await self._locate_volumes(key)
         volume_id, _ = volume_map.popitem()
         storage_volume = self.strategy.get_storage_volume(volume_id)
-        pipe = Pipe(storage_volume, executor=self.rdma_executor)
+        pipe = Pipe(storage_volume)
         request = Request.from_any(None)
         return await pipe.get_from_storage_volume(
             key, request, executor=self.rdma_executor
@@ -268,7 +268,7 @@ async def _get_tensor(self, key: str) -> torch.Tensor:
         # if the storage is a Tensor instead of DTensor, just fetch and return it.
         for volume_id, _ in volume_map.items():
             storage_volume = self.strategy.get_storage_volume(volume_id)
-            pipe = Pipe(storage_volume, executor=self.rdma_executor)
+            pipe = Pipe(storage_volume)
             # TODO: consolidate the logic here - None indicates it is an object request,
             # which is sematically inappropriate here.
             request = Request.from_any(None)
@@ -284,7 +284,7 @@ async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor:
         partial_results = []
         for volume_id, storage_info in volume_map.items():
             storage_volume = self.strategy.get_storage_volume(volume_id)
-            pipe = Pipe(storage_volume, executor=self.rdma_executor)
+            pipe = Pipe(storage_volume)
 
             # fetch from all storage volumes, something like this
             # TODO: fix so we can request all tensor slices from a storage volume
diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index a700af7..50d9eb6 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -131,9 +131,8 @@ class Pipe:
     Transport wrapper for communicating from local clients to storage volumes.
     """
 
-    def __init__(self, storage_volume, *, executor=None) -> None:
+    def __init__(self, storage_volume) -> None:
         self.storage_volume = storage_volume
-        self._executor = executor
 
     def create_transport_buffer(self) -> TransportBuffer:
         # TODO: eventually this should be dependent on the connections available to a storage_volume

From d7b4f006da775f01f5a27f9110d4d2b0b78797e6 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:14:43 -0700
Subject: [PATCH 05/13] tick

---
 torchstore/_async_utils.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/torchstore/_async_utils.py b/torchstore/_async_utils.py
index cc2d2af..5cbc30c 100644
--- a/torchstore/_async_utils.py
+++ b/torchstore/_async_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import asyncio
+import time
 from typing import Callable, cast, Generic, TypeVar
 
 T = TypeVar("T")
@@ -36,6 +37,26 @@ def get(self) -> T:
         return cast(T, self._value)
 
 
+class Interval:
+    """Poor man's version of tokio::time::Interval"""
+
+    def __init__(self, period: float):
+        self.period = period
+        self.next_tick = time.monotonic() + period
+
+    async def tick(self):
+        """Wait until the next tick instant"""
+        now = time.monotonic()
+        sleep_duration = self.next_tick - now
+
+        if sleep_duration < 0:
+            sleep_duration = 0
+
+        await asyncio.sleep(sleep_duration)
+
+        self.next_tick += self.period
+
+
 class SequentialExecutor:
     """A simple executor that runs tasks sequentially in the current event loop.
     This is mainly needed for RDMA operations, which will panic if concurrent requests are made (what the heck?).
@@ -49,6 +70,7 @@ async def start_worker(self):
         self._worker_task = asyncio.create_task(self._worker())
 
     async def _worker(self):
+        interval = Interval(0.001)  # 1ms
         while True:
             try:
                 func, args, kwargs, response = await self._queue.get()
@@ -66,7 +88,7 @@ async def _worker(self):
                 # Log or handle the error
                 print(f"[SequentialExecutor] Worker crashed: {outer_err}")
 
-            await asyncio.sleep(0.001)  # 1ms
+            await interval.tick()
 
     async def submit(self, func: Callable, *args, **kwargs) -> asyncio.Future:
         fut = asyncio.Future()

From 244bedde9ca698597f31387375c14013b543ebb7 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:15:51 -0700
Subject: [PATCH 06/13] tick

---
 torchstore/_async_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchstore/_async_utils.py b/torchstore/_async_utils.py
index 5cbc30c..6c63391 100644
--- a/torchstore/_async_utils.py
+++ b/torchstore/_async_utils.py
@@ -54,7 +54,7 @@ async def tick(self):
 
         await asyncio.sleep(sleep_duration)
 
-        self.next_tick += self.period
+        self.next_tick += now + self.period
 
 
 class SequentialExecutor:

From 6d08018c86c64f5ee95a6c7d38a8b84a09765b3d Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:23:33 -0700
Subject: [PATCH 07/13] drop buffers after use

---
 torchstore/transport/buffers.py | 15 +++++++++++++++
 torchstore/transport/pipe.py    |  6 +++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py
index 8576751..3f0ec14 100644
--- a/torchstore/transport/buffers.py
+++ b/torchstore/transport/buffers.py
@@ -65,6 +65,9 @@ async def write_from(
     ) -> None:
         raise NotImplementedError()
 
+    async def drop(self, *, executor=None) -> None:
+        raise NotImplementedError()
+
 
 class RDMATransportBuffer(TransportBuffer):
     # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor
@@ -200,6 +203,13 @@ async def write_from(
         for idx, chunk in enumerate(chunked_byte_view):
             await executor.submit(self.rdma_buffers[idx].write_from, chunk)
 
+    async def drop(self, *, executor=None) -> None:
+        assert executor is not None, "RDMATransportBuffer requires an executor"
+        if self.rdma_buffers is None:
+            return
+        for buffer in self.rdma_buffers:
+            await executor.submit(buffer.drop)
+
 
 class MonarchTransportBuffer(TransportBuffer):
     """This interface is mostly a noop, intended to be used with Monarch's regular RPC.
@@ -244,3 +254,8 @@ async def write_from(
     def update(self, other_buffer: "TransportBuffer") -> None:
         super().update(other_buffer)
         self.tensor = other_buffer.tensor
+
+    async def drop(self, *, executor=None) -> None:
+        # no-op
+        _ = executor
+        pass
diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index 50d9eb6..197af00 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -154,6 +154,8 @@ async def put_to_storage_volume(self, key, request: Request, *, executor=None):
             key, transport_buffer, request.meta_only()
         )
 
+        await transport_buffer.drop(executor=executor)
+
     async def get_from_storage_volume(self, key, request: Request, *, executor=None):
 
         transport_buffer = self.create_transport_buffer()
@@ -178,4 +180,6 @@ async def get_from_storage_volume(self, key, request: Request, *, executor=None):
         if transport_buffer.is_object:
             return transport_buffer.objects
 
-        return await transport_buffer.read_into(request.tensor_val, executor=executor)
+        ret = await transport_buffer.read_into(request.tensor_val, executor=executor)
+        await transport_buffer.drop(executor=executor)
+        return ret

From 6fea2b650f28d482289cdbe4e6f64f9d61a71d7a Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:24:13 -0700
Subject: [PATCH 08/13] decrease default concurrency

---
 torchstore/client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchstore/client.py b/torchstore/client.py
index a8143a2..ff9a2d8 100644
--- a/torchstore/client.py
+++ b/torchstore/client.py
@@ -45,7 +45,7 @@ def __init__(
         strategy,
         *,
         rdma_executor: SequentialExecutor | None = None,
-        max_concurrent_requests: int = 32,
+        max_concurrent_requests: int = 16,
     ):
         self._controller = controller
         self.strategy = strategy

From 634461ca00c07e8b0501a5b8ac1b69301973de4e Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:33:44 -0700
Subject: [PATCH 09/13] fix tick bug

---
 torchstore/_async_utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchstore/_async_utils.py b/torchstore/_async_utils.py
index 6c63391..996318f 100644
--- a/torchstore/_async_utils.py
+++ b/torchstore/_async_utils.py
@@ -46,15 +46,14 @@ def __init__(self, period: float):
 
     async def tick(self):
         """Wait until the next tick instant"""
-        now = time.monotonic()
-        sleep_duration = self.next_tick - now
+        sleep_duration = self.next_tick - time.monotonic()
 
         if sleep_duration < 0:
             sleep_duration = 0
 
         await asyncio.sleep(sleep_duration)
 
-        self.next_tick += now + self.period
+        self.next_tick += time.monotonic() + self.period
 
 
 class SequentialExecutor:

From 2dc4992da31aae8a71e6e4d5a05485459062609d Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:44:07 -0700
Subject: [PATCH 10/13] don't drop rdma buffers

---
 torchstore/transport/pipe.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index 197af00..9252ecd 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -154,7 +154,8 @@ async def put_to_storage_volume(self, key, request: Request, *, executor=None):
             key, transport_buffer, request.meta_only()
         )
 
-        await transport_buffer.drop(executor=executor)
+        # TODO: figure out why this hangs
+        # await transport_buffer.drop(executor=executor)
 
     async def get_from_storage_volume(self, key, request: Request, *, executor=None):
 
@@ -181,5 +182,6 @@ async def get_from_storage_volume(self, key, request: Request, *, executor=None):
         return transport_buffer.objects
 
         ret = await transport_buffer.read_into(request.tensor_val, executor=executor)
-        await transport_buffer.drop(executor=executor)
+        # TODO: figure out why this hangs
+        # await transport_buffer.drop(executor=executor)
         return ret

From 35f600c50a658a9289f2256b1c8e7652cefc1f98 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 01:51:16 -0700
Subject: [PATCH 11/13] no tick

---
 torchstore/_async_utils.py | 23 +----------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/torchstore/_async_utils.py b/torchstore/_async_utils.py
index 996318f..cc2d2af 100644
--- a/torchstore/_async_utils.py
+++ b/torchstore/_async_utils.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import asyncio
-import time
 from typing import Callable, cast, Generic, TypeVar
 
 T = TypeVar("T")
@@ -37,25 +36,6 @@ def get(self) -> T:
         return cast(T, self._value)
 
 
-class Interval:
-    """Poor man's version of tokio::time::Interval"""
-
-    def __init__(self, period: float):
-        self.period = period
-        self.next_tick = time.monotonic() + period
-
-    async def tick(self):
-        """Wait until the next tick instant"""
-        sleep_duration = self.next_tick - time.monotonic()
-
-        if sleep_duration < 0:
-            sleep_duration = 0
-
-        await asyncio.sleep(sleep_duration)
-
-        self.next_tick += time.monotonic() + self.period
-
-
 class SequentialExecutor:
     """A simple executor that runs tasks sequentially in the current event loop.
     This is mainly needed for RDMA operations, which will panic if concurrent requests are made (what the heck?).
@@ -69,7 +49,6 @@ async def start_worker(self):
         self._worker_task = asyncio.create_task(self._worker())
 
     async def _worker(self):
-        interval = Interval(0.001)  # 1ms
         while True:
             try:
                 func, args, kwargs, response = await self._queue.get()
@@ -87,7 +66,7 @@ async def _worker(self):
                 # Log or handle the error
                 print(f"[SequentialExecutor] Worker crashed: {outer_err}")
 
-            await interval.tick()
+            await asyncio.sleep(0.001)  # 1ms
 
     async def submit(self, func: Callable, *args, **kwargs) -> asyncio.Future:
         fut = asyncio.Future()

From bf36583d5c218c13b41e3c3dcebcea32fe090a91 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 02:19:24 -0700
Subject: [PATCH 12/13] enable drop again

---
 torchstore/transport/pipe.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index 9252ecd..197af00 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -154,8 +154,7 @@ async def put_to_storage_volume(self, key, request: Request, *, executor=None):
             key, transport_buffer, request.meta_only()
        )
 
-        # TODO: figure out why this hangs
-        # await transport_buffer.drop(executor=executor)
+        await transport_buffer.drop(executor=executor)
 
     async def get_from_storage_volume(self, key, request: Request, *, executor=None):
 
@@ -182,6 +181,5 @@ async def get_from_storage_volume(self, key, request: Request, *, executor=None):
         return transport_buffer.objects
 
         ret = await transport_buffer.read_into(request.tensor_val, executor=executor)
-        # TODO: figure out why this hangs
-        # await transport_buffer.drop(executor=executor)
+        await transport_buffer.drop(executor=executor)
         return ret
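Re-enabling drop works here because, unlike PATCH 07 as first written, every RDMA-buffer
operation on a given transfer — write_from/read_into and the final drop — now flows
through the same SequentialExecutor, so a drop can never interleave with an in-flight
read or write. The put path in outline (a sketch of the call order only, not the exact
code; error handling omitted):

    async def put_path_outline(pipe, key, request, executor):
        buf = pipe.create_transport_buffer()        # RDMA or Monarch buffer
        tensor = request.tensor_val
        buf.allocate(tensor)                        # chunk and register the byte views
        await buf.write_from(tensor, executor=executor)   # serialized RDMA writes
        await pipe.storage_volume.put.call_one(key, buf, request.meta_only())
        await buf.drop(executor=executor)           # release buffers via the same queue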
From a10d427a412e35d3505beff44846e3a693bfcd42 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Sat, 11 Oct 2025 15:10:32 -0700
Subject: [PATCH 13/13] tune

---
 torchstore/_async_utils.py   | 2 --
 torchstore/client.py         | 2 +-
 torchstore/transport/pipe.py | 3 ++-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/torchstore/_async_utils.py b/torchstore/_async_utils.py
index cc2d2af..5388617 100644
--- a/torchstore/_async_utils.py
+++ b/torchstore/_async_utils.py
@@ -66,8 +66,6 @@ async def _worker(self):
                 # Log or handle the error
                 print(f"[SequentialExecutor] Worker crashed: {outer_err}")
 
-            await asyncio.sleep(0.001)  # 1ms
-
     async def submit(self, func: Callable, *args, **kwargs) -> asyncio.Future:
         fut = asyncio.Future()
         await self._queue.put((func, args, kwargs, fut))
diff --git a/torchstore/client.py b/torchstore/client.py
index ff9a2d8..9326b25 100644
--- a/torchstore/client.py
+++ b/torchstore/client.py
@@ -45,7 +45,7 @@ def __init__(
         strategy,
         *,
         rdma_executor: SequentialExecutor | None = None,
-        max_concurrent_requests: int = 16,
+        max_concurrent_requests: int = 4,
     ):
         self._controller = controller
         self.strategy = strategy
diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index 197af00..234a3b1 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import copy
 from dataclasses import dataclass
 from logging import getLogger
@@ -145,7 +146,7 @@ async def put_to_storage_volume(self, key, request: Request, *, executor=None):
         transport_buffer = self.create_transport_buffer()
 
         tensor = request.tensor_val
-        transport_buffer.allocate(tensor)
+        await asyncio.to_thread(transport_buffer.allocate, tensor)
         await transport_buffer.write_from(tensor, executor=executor)
 
         # transporting tensors is handled by the buffer, so we don't want to send it
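Taken together, the series leaves the client with three layers of protection: a OnceCell
so concurrent client() calls build exactly one LocalClient, a per-client semaphore
(max_concurrent_requests, tuned down to 4 above), and a SequentialExecutor that
serializes all RDMA traffic. An end-to-end sketch, assuming initialize/client/shutdown
are re-exported at the package top level as defined in api.py:

    import asyncio

    import torch
    import torchstore as ts

    async def main() -> None:
        await ts.initialize()
        store = await ts.client()   # cached via OnceCell; owns the RDMA executor

        # Concurrent puts are now safe: the semaphore bounds them, and the
        # executor keeps the underlying RDMA reads/writes strictly sequential.
        tensors = {f"shard_{i}": torch.randn(1024) for i in range(8)}
        await asyncio.gather(*(store.put(k, v) for k, v in tensors.items()))

        fetched = await store.get("shard_0")
        assert torch.equal(fetched, tensors["shard_0"])

        await ts.shutdown()

    asyncio.run(main())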