diff --git a/tests/test_models.py b/tests/test_models.py index ee39559..41a8baf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -122,24 +122,23 @@ async def do_get(self): @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_basic(strategy_params, use_rdma): +async def test_basic(strategy_params, transport_type): # FSDP put_mesh_shape = (1,) get_mesh_shape = (1,) - await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma) + await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type) @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_resharding(strategy_params, use_rdma): +async def test_resharding(strategy_params, transport_type): # FSDP put_mesh_shape = (4,) get_mesh_shape = (8,) - await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma) + await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type) -async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma): - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" +async def _do_test(put_mesh_shape, get_mesh_shape, strategy, transport_type): ts.init_logging() logger.info(f"Testing with strategy: {strategy}") @@ -147,7 +146,9 @@ async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma): put_world_size = math.prod(put_mesh_shape) await ts.initialize( num_storage_volumes=put_world_size if strategy is not None else 1, - strategy=strategy, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, ) try: with tempfile.TemporaryDirectory() as tmpdir: diff --git a/tests/test_resharding.py b/tests/test_resharding.py index 9b749d0..4b434ff 100644 --- a/tests/test_resharding.py +++ b/tests/test_resharding.py @@ -8,16 +8,16 @@ import os import tempfile from logging import getLogger -from typing import List, Tuple, Union +from typing import List, Tuple, Type, Union import pytest import torch import torchstore as ts - from torch.distributed._tensor import Replicate, Shard from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset +from torchstore.transport import TransportType from torchstore.utils import get_local_tensor, spawn_actors from .utils import DTensorActor, main, transport_plus_strategy_params @@ -27,7 +27,7 @@ @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_1d_resharding(strategy_params, use_rdma): +async def test_1d_resharding(strategy_params, transport_type): _, strategy = strategy_params for put_mesh_shape, get_mesh_shape in [ @@ -47,13 +47,13 @@ async def test_1d_resharding(strategy_params, use_rdma): get_mesh_shape=get_mesh_shape, get_placements=[Shard(get_sharding_dim)], strategy=strategy, - use_rdma=use_rdma, + transport_type=transport_type, ) @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_2d_to_2d_resharding(strategy_params, use_rdma): +async def test_2d_to_2d_resharding(strategy_params, transport_type): _, strategy = strategy_params put_mesh_shape = get_mesh_shape = (2, 2) @@ -69,13 +69,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma): get_mesh_shape=get_mesh_shape, get_placements=[Shard(dim) for dim in get_sharding_dims], strategy=strategy, - use_rdma=use_rdma, + transport_type=transport_type, ) @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_1d_to_2d_resharding(strategy_params, use_rdma): 
+async def test_1d_to_2d_resharding(strategy_params, transport_type): _, strategy = strategy_params put_mesh_shape = (4,) @@ -92,13 +92,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma): get_mesh_shape=get_mesh_shape, get_placements=[Shard(dim) for dim in get_sharding_dims], strategy=strategy, - use_rdma=use_rdma, + transport_type=transport_type, ) @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_2d_to_1d_resharding(strategy_params, use_rdma): +async def test_2d_to_1d_resharding(strategy_params, transport_type): _, strategy = strategy_params put_mesh_shape = (2, 2) @@ -115,13 +115,13 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma): get_mesh_shape=get_mesh_shape, get_placements=[Shard(dim) for dim in get_sharding_dims], strategy=strategy, - use_rdma=use_rdma, + transport_type=transport_type, ) @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_data_parallel(strategy_params, use_rdma): +async def test_data_parallel(strategy_params, transport_type): _, strategy = strategy_params # # 1d @@ -134,7 +134,7 @@ async def test_data_parallel(strategy_params, use_rdma): get_mesh_shape=get_mesh_shape, get_placements=placements, strategy=strategy, - use_rdma=use_rdma, + transport_type=transport_type, ) # 2d -> 1d @@ -149,7 +149,7 @@ async def test_data_parallel(strategy_params, use_rdma): get_mesh_shape=get_mesh_shape, get_placements=[Shard(1)], strategy=strategy, - use_rdma=use_rdma, + transport_type=transport_type, ) @@ -158,8 +158,8 @@ async def _test_resharding( put_placements: List[Union[Replicate, Shard]], get_mesh_shape: Tuple[int], get_placements: List[Union[Replicate, Shard]], - strategy: ts.TorchStoreStrategy, - use_rdma: bool, + strategy: Type[ts.TorchStoreStrategy], + transport_type: TransportType, ): """Given a "put" mesh shape and a "get" mesh shape. 1. Create separate worlds for each mesh shape, running on different devices /PGs. @@ -183,8 +183,6 @@ async def _test_resharding( # Rank0: dtensor._local_tensor == [0,1], Rank1: dtensor._local_tensor == [2,3] """ - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" - put_world_size = math.prod(put_mesh_shape) get_world_size = math.prod(get_mesh_shape) assert ( @@ -206,7 +204,9 @@ async def _test_resharding( ) # 8x8 square, with ([[0...7],[8...15],[...]]) await ts.initialize( num_storage_volumes=put_world_size if strategy is not None else 1, - strategy=strategy, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, ) with tempfile.TemporaryDirectory() as filesystem_store_dir: # each actor mesh represents a group of processes. 
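For context, the updated `transport_plus_strategy_params()` fixture (see `tests/utils.py` later in this diff) now pairs each strategy with every available `TransportType` instead of a boolean RDMA flag. A minimal sketch of the parametrization the tests above consume (illustration only, not part of the diff):

```python
from itertools import product

import torchstore as ts
from torchstore.transport import TransportType

# strategy_params is (volume_world_size, strategy_class_or_None); the tests
# instantiate the class themselves via strategy(transport_type=transport_type).
strategies = [(2, ts.LocalRankStrategy), (1, None), (1, ts.ControllerStorageVolumes)]
params = list(product(strategies, list(TransportType)))
# e.g. ((2, ts.LocalRankStrategy), TransportType.TorchDistributed)
```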
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index caf3c53..9c0c028 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -167,9 +167,7 @@ async def do_get(self): @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_state_dict(strategy_params, use_rdma): - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" - +async def test_state_dict(strategy_params, transport_type): class Trainer(Actor): # Monarch RDMA does not work outside of an actor, so we need # to wrapp this test first @@ -200,7 +198,12 @@ async def do_test(self): return state_dict, fetched_state_dict _, strategy = strategy_params - await ts.initialize(num_storage_volumes=1, strategy=strategy) + await ts.initialize( + num_storage_volumes=1, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, + ) trainer = await spawn_actors(1, Trainer, "trainer") try: state_dict, fetched_state_dict = await trainer.do_test.call_one() @@ -212,8 +215,7 @@ async def do_test(self): @pytest.mark.skip("TODO(kaiyuan-li@): fix this test") @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_dcp_sharding_parity(strategy_params, use_rdma): - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" +async def test_dcp_sharding_parity(strategy_params, transport_type): for save_mesh_shape, get_mesh_shape in [ ((2,), (4,)), diff --git a/tests/test_store.py b/tests/test_store.py index 5c94300..736b101 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -31,9 +31,8 @@ @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_basic(strategy_params, use_rdma): +async def test_basic(strategy_params, transport_type): """Test basic put/get functionality for multiple processes""" - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" class PutGetActor(Actor): """Each instance of this actor represents a single process.""" @@ -60,7 +59,12 @@ async def get(self, rank_offset=0): return await ts.get(f"key_{other_rank}") volume_world_size, strategy = strategy_params - await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy) + await ts.initialize( + num_storage_volumes=volume_world_size, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, + ) # each actor mesh represents a group of processes. actor_mesh_0 = await spawn_actors( volume_world_size, PutGetActor, "actor_mesh_0", world_size=volume_world_size @@ -91,9 +95,8 @@ async def get(self, rank_offset=0): @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_objects(strategy_params, use_rdma): +async def test_objects(strategy_params, transport_type): """Test put/get on arbitrary object""" - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" class ObjectActor(Actor): """Each instance of this actor represents a single process.""" @@ -118,7 +121,12 @@ async def get(self, rank_offset=0): return await ts.get(f"key_{other_rank}") volume_world_size, strategy = strategy_params - await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy) + await ts.initialize( + num_storage_volumes=volume_world_size, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, + ) # each actor mesh represents a group of processes. 
actor_mesh_0 = await spawn_actors( volume_world_size, ObjectActor, "actor_mesh_0", world_size=volume_world_size @@ -154,9 +162,8 @@ def __eq__(self, other: object) -> bool: @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_exists(strategy_params, use_rdma): +async def test_exists(strategy_params, transport_type): """Test the exists() API functionality""" - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" class ExistsTestActor(Actor): """Actor for testing exists functionality.""" @@ -177,7 +184,12 @@ async def exists(self, key): return await ts.exists(key) volume_world_size, strategy = strategy_params - await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy) + await ts.initialize( + num_storage_volumes=volume_world_size, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, + ) # Spawn test actors actor_mesh = await spawn_actors( @@ -222,9 +234,8 @@ async def exists(self, key): @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_delete(strategy_params, use_rdma): +async def test_delete(strategy_params, transport_type): """Test the delete() API functionality""" - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" class DeleteTestActor(Actor): """Actor for testing delete functionality.""" @@ -253,7 +264,12 @@ async def get(self, key): return await ts.get(key) volume_world_size, strategy = strategy_params - await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy) + await ts.initialize( + num_storage_volumes=volume_world_size, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, + ) # Spawn test actors actor_mesh = await spawn_actors( @@ -303,9 +319,8 @@ async def get(self, key): @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio -async def test_get_tensor_slice(strategy_params, use_rdma): +async def test_get_tensor_slice(strategy_params, transport_type): """Test tensor slice API functionality""" - os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" class TensorSlicePutActor(Actor): """Actor for putting tensors.""" @@ -322,7 +337,12 @@ async def put(self, key, tensor): await ts.put(key, tensor) volume_world_size, strategy = strategy_params - await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy) + await ts.initialize( + num_storage_volumes=volume_world_size, + strategy=strategy(transport_type=transport_type) + if strategy is not None + else None, + ) # Spawn test actors - separate meshes for put and get to test cross-process communication put_actor_mesh = await spawn_actors( @@ -405,7 +425,7 @@ class LargeTensorActor(Actor): step_size: int = 100 # -> 400mb max_step: int = 600 # 4mb -> 2gb - def __init__(self, generate_benchmark=False) -> None: + def __init__(self, generate_benchmark=True) -> None: self.generate_benchmark = generate_benchmark init_logging() diff --git a/tests/utils.py b/tests/utils.py index 77c8006..f5719f9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,8 @@ from monarch.actor import Actor, current_rank, endpoint from torch.distributed._tensor import distribute_tensor from torch.distributed.device_mesh import init_device_mesh +from torchstore.transport import TransportType +from torchstore.transport.buffers import monarch_rdma_available logger = getLogger(__name__) @@ -26,17 +28,17 @@ def main(file): def transport_plus_strategy_params(): strategies = [ - (2, 
ts.LocalRankStrategy()), + (2, ts.LocalRankStrategy), (1, None), # ts.SingletonStrategy - (1, ts.ControllerStorageVolumes()), + (1, ts.ControllerStorageVolumes), ] - rdma_options = ( - [True, False] - if os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1" - else [False] - ) - return "strategy_params, use_rdma", list(product(strategies, rdma_options)) + transport_types = list(TransportType) + if not monarch_rdma_available(): + print("Removing rdma tests since rdma is not available") + transport_types.remove(TransportType.MonarchRDMA) + + return "strategy_params, transport_type", list(product(strategies, transport_types)) class DTensorActor(Actor): diff --git a/torchstore/client.py b/torchstore/client.py index e7471ef..a15a3ce 100644 --- a/torchstore/client.py +++ b/torchstore/client.py @@ -11,8 +11,9 @@ import torch from torch.distributed.tensor import DTensor -from torchstore.controller import ObjectType +from torchstore.controller import Controller, ObjectType from torchstore.logging import LatencyTracker +from torchstore.strategy import TorchStoreStrategy from torchstore.transport import Pipe, Request, TensorSlice from torchstore.utils import assemble_global_tensor, get_local_tensor @@ -24,11 +25,7 @@ class LocalClient: is handled by the client. """ - def __init__( - self, - controller, - strategy, - ): + def __init__(self, controller: Controller, strategy: TorchStoreStrategy): self._controller = controller self.strategy = strategy @@ -41,7 +38,6 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]): # it will never be dynamic. e.g. it's always based on the # TorchstoreStrategy defined during intiailization storage_volume, volume_id = self.strategy.select_storage_volume() - pipe = Pipe(storage_volume) await pipe.put_to_storage_volume(key, request) @@ -220,14 +216,17 @@ async def _get_stored_object_type(self, key: str) -> ObjectType | None: volume_map = await self._controller.locate_volumes.call_one(key) for storage_info in volume_map.values(): return storage_info.object_type + raise ValueError(f"Unable to get stored object type for key `{key}`") async def _get_object(self, key: str): volume_map = await self._controller.locate_volumes.call_one(key) volume_id, _ = volume_map.popitem() + storage_volume = self.strategy.get_storage_volume(volume_id) pipe = Pipe(storage_volume) request = Request.from_any(None) + return await pipe.get_from_storage_volume(key, request) async def _get_tensor(self, key: str) -> torch.Tensor: @@ -243,6 +242,8 @@ async def _get_tensor(self, key: str) -> torch.Tensor: request = Request.from_any(None) return await pipe.get_from_storage_volume(key, request) + raise ValueError(f"Unable to get tensor for key `{key}`") + async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor: """Fetches slices from all volume storages and stitch together to return the whole tensor""" diff --git a/torchstore/controller.py b/torchstore/controller.py index 32615b3..c5a5d53 100644 --- a/torchstore/controller.py +++ b/torchstore/controller.py @@ -6,15 +6,16 @@ from dataclasses import dataclass, field from enum import auto, Enum -from typing import Dict, List, Mapping, Optional, Set +from typing import Dict, List, Mapping, Optional, Set, TYPE_CHECKING from monarch.actor import Actor, endpoint from torchstore.storage_utils.trie import Trie -from torchstore.storage_volume import StorageVolume from torchstore.strategy import TorchStoreStrategy from torchstore.transport.pipe import Request, TensorSlice +if TYPE_CHECKING: + from torchstore.storage_volume import 
StorageVolume # TODO: move this into request as a field class ObjectType(Enum): @@ -52,7 +53,7 @@ def __init__( self.keys_to_storage_volumes = Trie() self.is_initialized: bool = False self.strategy: Optional[TorchStoreStrategy] = None - self.storage_volumes: Optional[StorageVolume] = None + self.storage_volumes: Optional["StorageVolume"] = None self.num_storage_volumes: Optional[int] = None self.strategy: Optional[TorchStoreStrategy] = None @@ -66,7 +67,7 @@ async def init( self, strategy: TorchStoreStrategy, num_storage_volumes: int, - storage_volumes: StorageVolume, + storage_volumes: "StorageVolume", ) -> None: if self.is_initialized: raise RuntimeError("TorchStore is already initialized") diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index 3a87a9f..fe21252 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -11,11 +11,13 @@ import torch from monarch.actor import Actor, endpoint +from torchstore.logging import init_logging from torchstore.transport.buffers import TransportBuffer - from torchstore.transport.pipe import Request, TensorSlice + from torchstore.utils import assemble_global_tensor, spawn_actors + logger = getLogger(__name__) @@ -31,8 +33,10 @@ def __init__( self, id_func, ) -> None: + init_logging() self.store: StorageImpl = InMemoryStore() self.volume_id: str = id_func() + self.transport_context = {} @classmethod async def spawn( @@ -56,12 +60,19 @@ async def get_id(self) -> str: async def put( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> None: + # something like + # transport_buffer.set_context(self.transport_context) + transport_buffer.transport_context = self.transport_context + transport_buffer.remote_rank = 0 await self.store.put(key, transport_buffer, request) @endpoint async def get( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> TransportBuffer: + # transport_buffer.set_context(self.transport_context) + transport_buffer.transport_context = self.transport_context + transport_buffer.remote_rank = 0 return await self.store.get(key, transport_buffer, request) @endpoint @@ -72,9 +83,16 @@ async def get_meta( ) -> Union[Tuple[torch.Size, torch.dtype], str]: return await self.store.get_meta(key, request) + @endpoint async def delete(self, key: str) -> None: await self.store.delete(key) + @endpoint + async def setup_comms(self, transport_buffer) -> None: + logger.info("Initiating handshake on volume side") + await transport_buffer.storage_volume_setup_comms(self.transport_context) + logger.info("Finished initiating handshake on volume side") + class StorageImpl: """Abstract base class for storage implementations.""" @@ -194,6 +212,7 @@ async def put( # since we pass tensor=None to the transport buffer, # we allocate on the fly tensor = await transport_buffer.read_into(tensor=None) + await transport_buffer.finish() if request.tensor_slice is not None: self._handle_dtensor(key, request.tensor_slice, tensor) return @@ -216,6 +235,7 @@ async def get( if request.tensor_slice is None: await transport_buffer.write_from(self.kv[key]) + await transport_buffer.finish() return transport_buffer # TODO: @@ -223,10 +243,10 @@ async def get( # but this goes entire the value prop of torchstore. StorageVolume must # support requesting a subset of the regions which exist locally in the # store. 
- for shard in self.kv[key].values(): if shard["slice"] == request.tensor_slice: await transport_buffer.write_from(shard["tensor"]) + await transport_buffer.finish() return transport_buffer raise RuntimeError(f"Tensor slice {request.tensor_slice} not found in {key}") @@ -264,7 +284,6 @@ async def get_meta( f"Could not find shard slice with {request.tensor_slice=} Slices:{stored_object}" ) - raise RuntimeError(f"Unknown type for {key} type={type(val)} {val=}") raise RuntimeError(f"Unknown type for {key} type={type(val)}") async def delete(self, key: str) -> None: diff --git a/torchstore/strategy.py b/torchstore/strategy.py index 4c6bd1a..d9c4231 100644 --- a/torchstore/strategy.py +++ b/torchstore/strategy.py @@ -11,10 +11,34 @@ """ import os +from typing import Any, Dict, TYPE_CHECKING from monarch.actor import current_rank -from torchstore.storage_volume import StorageVolume +from torchstore.transport import TransportType + +if TYPE_CHECKING: + from torchstore.storage_volume import StorageVolume + + +class StorageVolumeRef: + def __init__( + self, + volume: "StorageVolume", + volume_id: str, + transport_type: "TransportType", + transport_context: Dict, + ): + self.volume = volume + self.volume_id = volume_id + self.transport_type = transport_type + # useful for caching transport objects that should survive the lifetime of the client/volume + self.transport_context = transport_context + + def __getattr__(self, name: str) -> Any: + if name not in ("volume_id", "transport_type"): + return getattr(self.volume, name) + return super().__getattribute__(name) class TorchStoreStrategy: @@ -28,9 +52,11 @@ class TorchStoreStrategy: Subclasses must implement get_volume_id() and get_client_id() methods. """ - def __init__(self): + def __init__(self, transport_type: TransportType = TransportType.TorchDistributed): self.storage_volumes = None self.volume_id_to_coord = {} + self.transport_type = transport_type + self.transport_context = {} def __str__(self) -> str: storage_vol_len = ( @@ -84,7 +110,7 @@ def select_storage_volume(self): client_id, ) # client_id == volume_id for this strategy - def get_storage_volume(self, volume_id: str) -> StorageVolume: + def get_storage_volume(self, volume_id: str) -> StorageVolumeRef: """Retrieves storage volume actor for a given volume ID. Args: @@ -94,7 +120,12 @@ def get_storage_volume(self, volume_id: str) -> StorageVolume: StorageVolume: The storage volume actor for the given ID. """ volume_coord = self.volume_id_to_coord[volume_id] - return self.storage_volumes.slice(**volume_coord) + return StorageVolumeRef( + self.storage_volumes.slice(**volume_coord), + volume_id, + self.transport_type, + self.transport_context, + ) class SingletonStrategy(TorchStoreStrategy): @@ -205,5 +236,10 @@ def select_storage_volume(self): client_id, ) # client_id == volume_id for this strategy - def get_storage_volume(self, volume_id: str) -> StorageVolume: - return self.storage_volumes + def get_storage_volume(self, volume_id: str) -> StorageVolumeRef: + return StorageVolumeRef( + self.storage_volumes, + volume_id, + self.transport_type, + self.transport_context, + ) diff --git a/torchstore/transport/__init__.py b/torchstore/transport/__init__.py index c5e09dc..03d7d2c 100644 --- a/torchstore/transport/__init__.py +++ b/torchstore/transport/__init__.py @@ -4,6 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from enum import auto, Enum
+
+from torchstore.transport.buffers import MonarchTransportBuffer, RDMATransportBuffer
+
 from torchstore.transport.pipe import Pipe, Request, TensorSlice
+from torchstore.transport.torch_distributed_buffer import TorchDistributedBuffer
+
+
+class TransportType(Enum):
+    MonarchRPC = auto()
+    MonarchRDMA = auto()
+    TorchDistributed = auto()
+
+    def buffer_cls(self):
+        return {
+            TransportType.MonarchRPC: MonarchTransportBuffer,
+            TransportType.MonarchRDMA: RDMATransportBuffer,
+            TransportType.TorchDistributed: TorchDistributedBuffer,
+        }[self]
+
-__all__ = ["Pipe", "Request", "TensorSlice"]
+__all__ = ["Pipe", "Request", "TensorSlice", "TransportType"]
diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py
index 12b921b..e9d7be4 100644
--- a/torchstore/transport/buffers.py
+++ b/torchstore/transport/buffers.py
@@ -51,6 +51,7 @@ class TransportBuffer:
     is_object: bool = False
     objects: Optional[Any] = None
     requires_meta: bool = False
+    read_ahead: bool = False
 
     def update(self, other_buffer: "TransportBuffer") -> None:
         self.finalize = other_buffer.finalize
@@ -58,6 +59,16 @@ def update(self, other_buffer: "TransportBuffer") -> None:
         self.objects = other_buffer.objects
         self.requires_meta = other_buffer.requires_meta
 
+    async def setup_comms(self, storage_volume) -> None:
+        """Initiate comms handshake with storage_volume"""
+        pass
+
+    async def storage_volume_setup_comms(
+        self, transport_context: Dict[str, Any]
+    ) -> None:
+        """Mirror of setup_comms, but run on the storage volume side"""
+        raise NotImplementedError("Must implement storage_volume_setup_comms")
+
     def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
         """Allocates internal buffers based on either an existing tensor
         or a Tuple of (shape, dtype)
@@ -70,6 +81,10 @@ async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor:
     async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
         raise NotImplementedError()
 
+    async def finish(self) -> None:
+        """Finalize the transport buffer"""
+        pass
+
 
 class RDMATransportBuffer(TransportBuffer):
     # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor
diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index f0d94fd..1aea3ee 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -13,12 +13,8 @@
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
 
-from torchstore.transport.buffers import (
-    MonarchTransportBuffer,
-    rdma_available,
-    RDMATransportBuffer,
-    TransportBuffer,
-)
+from torchstore.logging import LatencyTracker
+from torchstore.transport.buffers import TransportBuffer
 
 logger = getLogger(__name__)
 
@@ -134,16 +130,19 @@ class Pipe:
     def __init__(self, storage_volume) -> None:
         self.storage_volume = storage_volume
 
-    def create_transport_buffer(self) -> TransportBuffer:
+    async def create_transport_buffer(self) -> TransportBuffer:
         # TODO: eventually this should be dependent on the connections available to a storage_volume
-        if rdma_available():
-            buffer_cls = RDMATransportBuffer
-        else:
-            buffer_cls = MonarchTransportBuffer
-        return buffer_cls()
-
-    async def put_to_storage_volume(self, key, request: Request):
-        transport_buffer = self.create_transport_buffer()
+        buffer_cls = self.storage_volume.transport_type.buffer_cls()
+        buffer = buffer_cls()
+        await buffer.setup_comms(self.storage_volume)
+
+        return buffer
+
+    async def put_to_storage_volume(self, key, request: Request) -> None:
+        latency_tracker = LatencyTracker(f"put_to_storage_volume:{key}")
+
+        transport_buffer = await self.create_transport_buffer()
 
         tensor = request.tensor_val
         transport_buffer.allocate(tensor)
@@ -155,9 +154,13 @@ async def put_to_storage_volume(self, key, request: Request):
             key, transport_buffer, request.meta_only()
         )
 
-    async def get_from_storage_volume(self, key, request: Request):
+        await transport_buffer.finish()
+        latency_tracker.track_step("finish")
+        latency_tracker.track_e2e()
 
-        transport_buffer = self.create_transport_buffer()
+    async def get_from_storage_volume(self, key, request: Request):
+        latency_tracker = LatencyTracker(f"get_from_storage_volume:{key}")
+        transport_buffer = await self.create_transport_buffer()
 
         # Certain buffers (RDMA) need to know the size of the tensor
         # so we can allocate the right amount of memory locally.
@@ -169,6 +172,13 @@ async def get_from_storage_volume(self, key, request: Request):
         else:
             transport_buffer.allocate(request.tensor_val)
 
+        latency_tracker.track_step("allocate")
+
+        # TODO: re-evaluate this logic for better polymorphism
+        t = None
+        if transport_buffer.read_ahead:
+            t = await transport_buffer.read_into(request.tensor_val)
+
         # TODO: consider placing the buffer inside the request or vice versa
         transport_buffer.update(
             await self.storage_volume.get.call_one(
@@ -179,4 +189,11 @@ async def get_from_storage_volume(self, key, request: Request):
         if transport_buffer.is_object:
             return transport_buffer.objects
 
+        if transport_buffer.read_ahead:
+            assert (
+                t is not None
+            ), "transport_buffer read ahead is true but no tensor to return"
+            await transport_buffer.finish()
+            return t
+
         return await transport_buffer.read_into(request.tensor_val)
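A rough sketch (not part of the diff) of the get-path control flow introduced above. Names mirror the diff; `storage_volume`, `key`, and `request` are assumed to already exist. The `read_ahead` flag lets a buffer post its local receive before the storage volume is asked to send, which the point-to-point transport below relies on:

```python
buffer_cls = storage_volume.transport_type.buffer_cls()  # e.g. TorchDistributedBuffer
transport_buffer = buffer_cls()
await transport_buffer.setup_comms(storage_volume)       # one-time process-group handshake

if transport_buffer.read_ahead:
    # Pre-post the local recv so the volume's send has a matching receive
    # outstanding while the get() RPC is still in flight.
    t = await transport_buffer.read_into(request.tensor_val)

transport_buffer.update(
    await storage_volume.get.call_one(key, transport_buffer, request)
)
await transport_buffer.finish()  # wait for the pending recv future to complete
# non-read_ahead buffers instead call read_into() after the RPC returns
```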
diff --git a/torchstore/transport/torch_distributed_buffer.py b/torchstore/transport/torch_distributed_buffer.py
new file mode 100644
index 0000000..cee08d7
--- /dev/null
+++ b/torchstore/transport/torch_distributed_buffer.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import uuid
+from datetime import timedelta
+from logging import getLogger
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch.distributed import ProcessGroup, ProcessGroupGloo, Store
+
+from torchstore.transport.buffers import TransportBuffer
+
+# from torchstore.strategy import StorageVolumeRef
+
+local_pgs: Dict[str, torch.distributed.ProcessGroup] = {}
+file_store_names: Dict[str, str] = {}
+
+logger = getLogger(__name__)
+
+
+def _gloo_factory(
+    store: Store,
+    rank: int,
+    world_size: int,
+    timeout: timedelta,
+    device: Optional[torch.device] = None,
+    **kwargs: object,
+) -> ProcessGroup:
+    """
+    From:
+    https://github.com/pytorch/pytorch/blob/92284fb2ff44f09a9c7df0d8cf6cac9903e376a4/torch/distributed/_dist2.py#L64
+
+    """
+
+    assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"
+
+    backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
+    backend_class._set_sequence_number_for_group()
+
+    pg = ProcessGroup(store, rank, world_size)
+    pg._set_default_backend(ProcessGroup.BackendType.GLOO)
+
+    # register devices
+    if device is not None:
+        pg._register_backend(device, ProcessGroup.BackendType.GLOO, backend_class)
+
+    pg._register_backend(
+        torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+    )
+    if torch.cuda.is_available():
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+    return pg
+
+
+class TorchDistributedBuffer(TransportBuffer):
+    requires_meta: bool = True
+    read_ahead: bool = True
+
+    def __init__(self) -> None:
+        self.shape: Optional[torch.Size] = None
+        self.dtype: Optional[torch.dtype] = None
+        self.fut: Optional[torch.futures.Future] = None
+
+    def __getstate__(self) -> Dict[str, Any]:
+        # The pending future and the cached transport context are process-local
+        # (they reference a live process group), so it makes no sense to hold
+        # these references when the buffer is serialized and sent to the peer.
+        state = self.__dict__.copy()
+        state["fut"] = None
+        state["transport_context"] = None
+        return state
+
+    # TODO: ensure this is only called once
+    async def setup_comms(self, storage_volume):
+
+        # transport context is actually stored in the strategy,
+        # but is passed along here so we can cache PG's.
+
+        # TODO: file store name is wrong
+        if storage_volume.volume_id not in file_store_names:
+            # TODO: TCPStore
+            file_store_name = f"/tmp/lpasqualin/comms_test{str(uuid.uuid4())[:8]}"
+            logger.info(
+                f"Initiating pg handshake with StorageVolume:[{storage_volume.volume_id}]"
+                f" using id={file_store_name}"
+            )
+
+            self.file_store_name = file_store_name
+            handshake_fut = storage_volume.setup_comms.call(self)
+
+            try:
+                file_store = torch.distributed.FileStore(file_store_name, 2)
+                pg = _gloo_factory(
+                    store=file_store,
+                    rank=0,
+                    world_size=2,
+                    timeout=timedelta(seconds=120),
+                    device=torch.device("cpu"),
+                )
+                file_store_names[storage_volume.volume_id] = file_store_name
+                storage_volume.transport_context[file_store_name] = pg
+            finally:
+                await handshake_fut
+
+            logger.info(
+                f"Finished pg handshake with StorageVolume:[{storage_volume.volume_id}]"
+                f" using id={file_store_name}"
+            )
+
+        self.file_store_name = file_store_names[storage_volume.volume_id]
+        self.transport_context = storage_volume.transport_context
+        self.remote_rank = 1
+
+    async def storage_volume_setup_comms(
+        self, transport_context: Dict[str, Any]
+    ) -> None:
+
+        if self.file_store_name in transport_context:
+            return
+
+        file_store = torch.distributed.FileStore(self.file_store_name, 2)
+        pg = _gloo_factory(
+            store=file_store,
+            rank=1,
+            world_size=2,
+            timeout=timedelta(seconds=120),
+            device=torch.device("cpu"),
+        )
+        transport_context[self.file_store_name] = pg
+
+    def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
+        """Allocates internal buffers based on either an existing tensor
+        or a Tuple of (shape, dtype)
+        """
+        if isinstance(tensor_like, str) or tensor_like is None:
+            # tensor is just an object, nothing to allocate
+            self.is_object = True
+            return
+        elif isinstance(tensor_like, Tuple):
+            # we know the size of the tensor from fetching metadata
+            self.shape = tensor_like[0]
+            self.dtype = tensor_like[1]
+        else:
+            # we have an inplace tensor, allocate a copy
+            assert isinstance(tensor_like, torch.Tensor)
+            self.shape = tensor_like.shape
+            self.dtype = tensor_like.dtype
+
+    # posts a recv that is matched by the remote rank's send
+    async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if self.is_object:
+            return
+
+        if tensor is None:
+            tensor = torch.empty(self.shape, dtype=self.dtype)
+
+        assert self.fut is None
+        pg = self.transport_context[self.file_store_name]
+        self.fut = pg.recv([tensor], srcRank=self.remote_rank, tag=0)
+
+        return tensor
+
+    # posts a send that is matched by the remote rank's recv
+    async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
+        if self.is_object:
+            return
+
+        assert self.fut is None
+        pg = self.transport_context[self.file_store_name]
+        self.fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)
+
+    async def finish(self):
+        if self.fut is not None:
+            while not self.fut.done():
+                await asyncio.sleep(0.005)
diff --git a/torchstore/utils.py b/torchstore/utils.py
index 8c07c9a..1bc334f 100644
--- a/torchstore/utils.py
+++ b/torchstore/utils.py
@@ -26,9 +26,7 @@ async def spawn_actors(num_processes, actor_cls, name, mesh=None, **init_args):
         logger.debug("Spawning actors on the local host")
         mesh = this_host().spawn_procs(per_host={"gpus": num_processes})
         await mesh.initialized
-        actors = await mesh.spawn(
-            f"{name}_{str(uuid.uuid4())[:8]}", actor_cls, **init_args
-        )
+        actors = mesh.spawn(f"{name}_{str(uuid.uuid4())[:8]}", actor_cls, **init_args)
         return actors
 
     assert isinstance(mesh, ProcMesh)
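End-to-end, the transport is now chosen when the strategy is constructed rather than via the `TORCHSTORE_RDMA_ENABLED` environment variable. A usage sketch based on the calls exercised by the tests above (assumes an async context; the key and tensor are placeholders):

```python
import torch
import torchstore as ts
from torchstore.transport import TransportType


async def example():
    await ts.initialize(
        num_storage_volumes=2,
        strategy=ts.LocalRankStrategy(transport_type=TransportType.TorchDistributed),
    )
    await ts.put("example_key", torch.rand(4, 4))  # placeholder key/tensor
    fetched = await ts.get("example_key")
    return fetched
```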