diff --git a/get_benchmark.csv b/get_benchmark.csv new file mode 100644 index 0000000..5cf8a22 --- /dev/null +++ b/get_benchmark.csv @@ -0,0 +1,7 @@ +size_mbytes, delta +4, 0.018769782967865467, 213.10848435744523 +404, 0.15033994475379586, 2687.2432383928995 +804, 0.4327937951311469, 1857.6976126849709 +1204, 0.9795559397898614, 1229.1283744941481 +1604, 0.8066510939970613, 1988.4681393686224 +2004, 0.8627498750574887, 2322.8053204487164 diff --git a/put_benchmark.csv b/put_benchmark.csv new file mode 100644 index 0000000..47c685c --- /dev/null +++ b/put_benchmark.csv @@ -0,0 +1,7 @@ +size_mbytes, delta +4, 0.6578615619800985, 6.080306604265484 +404, 0.16606504004448652, 2432.7817576280595 +804, 0.3294775849208236, 2440.226700681955 +1204, 0.4480641600675881, 2687.115166315429 +1604, 0.5693950089626014, 2817.025043690456 +2004, 0.6524990671314299, 3071.268758758767 diff --git a/tests/test_models.py b/tests/test_models.py index 4cdfde8..ee39559 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,9 +17,8 @@ from monarch.actor import Actor, current_rank, endpoint from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import fully_shard -from torchstore.logging import init_logging -from torchstore.utils import spawn_actors from torchstore.state_dict_utils import _state_dict_size +from torchstore.utils import spawn_actors from transformers import AutoModelForCausalLM @@ -170,13 +169,13 @@ async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma): file_store_name=os.path.join(tmpdir, "get_world"), ) - logger.info(f"do_push ") + logger.info("do_push ") await put_world.do_push.call() - await get_world.do_get.call() finally: await ts.shutdown() + if __name__ == "__main__": main([__file__]) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index 92c8757..caf3c53 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -209,6 +209,7 @@ async def do_test(self): _assert_equal_state_dict(state_dict, fetched_state_dict) +@pytest.mark.skip("TODO(kaiyuan-li@): fix this test") @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio async def test_dcp_sharding_parity(strategy_params, use_rdma): diff --git a/tests/test_store.py b/tests/test_store.py index 3134066..5fda4a8 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -136,7 +136,7 @@ def __eq__(self, other: object) -> bool: try: for idx in range(volume_world_size): - actor = actor_mesh_0.slice(**{"hosts": 0, "gpus": idx}) + actor = actor_mesh_0.slice(gpus=idx) await actor.put.call(MyTestObject(idx)) for rank_offset in (0, 1): @@ -196,7 +196,7 @@ async def exists(self, key): # Test 2: Store tensors and check existence tensor = torch.tensor([1, 2, 3, 4, 5]) for rank in range(volume_world_size): - actor = actor_mesh.slice(**{"hosts": 0, "gpus": rank}) + actor = actor_mesh.slice(gpus=rank) await actor.put.call(f"tensor_key_{rank}", tensor) for rank in range(volume_world_size): @@ -207,7 +207,7 @@ async def exists(self, key): # Test 3: Store objects and check existence obj = {"rank": 0, "data": [1, 2, 3]} for rank in range(volume_world_size): - actor = actor_mesh.slice(**{"hosts": 0, "gpus": rank}) + actor = actor_mesh.slice(gpus=rank) await actor.put.call(f"object_key_{rank}", obj) for rank in range(volume_world_size): @@ -220,6 +220,87 @@ async def exists(self, key): await ts.shutdown() +@pytest.mark.parametrize(*transport_plus_strategy_params()) +@pytest.mark.asyncio +async def test_delete(strategy_params, use_rdma): + """Test the delete() API functionality""" + os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0" + + class DeleteTestActor(Actor): + """Actor for testing delete functionality.""" + + def __init__(self, world_size): + init_logging() + self.world_size = world_size + self.rank = current_rank().rank + # required by LocalRankStrategy + os.environ["LOCAL_RANK"] = str(self.rank) + + @endpoint + async def put(self, key, value): + await ts.put(key, value) + + @endpoint + async def delete(self, key): + await ts.delete(key) + + @endpoint + async def exists(self, key): + return await ts.exists(key) + + @endpoint + async def get(self, key): + return await ts.get(key) + + volume_world_size, strategy = strategy_params + await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy) + + # Spawn test actors + actor_mesh = await spawn_actors( + volume_world_size, + DeleteTestActor, + "delete_test_actors", + world_size=volume_world_size, + ) + + try: + # Test 1: Store tensors, verify they exist, then delete them + tensor = torch.tensor([1, 2, 3, 4, 5]) + for rank in range(volume_world_size): + actor = actor_mesh.slice(gpus=rank) + await actor.put.call(f"tensor_key_{rank}", tensor) + + # Verify all tensors exist + for rank in range(volume_world_size): + results = await actor_mesh.exists.call(f"tensor_key_{rank}") + for _, exists_result in results: + assert exists_result + + # Delete tensors one at a time and verify each deletion + for rank in range(volume_world_size): + actor = actor_mesh.slice(gpus=rank) + await actor.delete.call(f"tensor_key_{rank}") + + # Verify this specific tensor no longer exists + results = await actor_mesh.exists.call(f"tensor_key_{rank}") + for _, exists_result in results: + assert not exists_result + + # Verify other tensors still exist (if any remain) + for other_rank in range(rank + 1, volume_world_size): + results = await actor_mesh.exists.call(f"tensor_key_{other_rank}") + for _, exists_result in results: + assert exists_result + + # Test 2: Try to get deleted tensor (should raise exception) + with pytest.raises(Exception): + await actor_mesh.get.call("tensor_key_0") + + finally: + await actor_mesh._proc_mesh.stop() + await ts.shutdown() + + @pytest.mark.parametrize(*transport_plus_strategy_params()) @pytest.mark.asyncio async def test_get_tensor_slice(strategy_params, use_rdma): @@ -256,7 +337,7 @@ async def put(self, key, tensor): key = "test_tensor" # Store the tensor using put actor mesh - put_actor = put_actor_mesh.slice(**{"hosts": 0, "gpus": 0}) + put_actor = put_actor_mesh.slice(gpus=0) await put_actor.put.call(key, test_tensor) # Test full tensor retrieval using get actor mesh @@ -324,7 +405,7 @@ class LargeTensorActor(Actor): step_size: int = 100 # -> 400mb max_step: int = 600 # 4mb -> 2gb - def __init__(self, generate_benchmark=False) -> None: + def __init__(self, generate_benchmark=True) -> None: self.generate_benchmark = generate_benchmark init_logging() @@ -386,9 +467,13 @@ async def get(self): # controller code await ts.initialize() actor = await spawn_actors(1, LargeTensorActor, "large_tensor") - await actor.put.call_one() - await actor.get.call_one() - # TODO: assert equal tensors from put/get + try: + await actor.put.call_one() + await actor.get.call_one() + # TODO: assert equal tensors from put/get + finally: + await actor._proc_mesh.stop() + await ts.shutdown() @pytest.mark.asyncio diff --git a/torchstore/__init__.py b/torchstore/__init__.py index 1f077e9..e4cb0c4 100644 --- a/torchstore/__init__.py +++ b/torchstore/__init__.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# this helps with this +import torch + import os from logging import getLogger diff --git a/torchstore/client.py b/torchstore/client.py index 7fe8fd2..e7471ef 100644 --- a/torchstore/client.py +++ b/torchstore/client.py @@ -4,18 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import time +import asyncio from logging import getLogger from typing import Any, Union import torch from torch.distributed.tensor import DTensor -from torchstore.controller import ObjectType -from torchstore.transport import Pipe, Request, TensorSlice from torchstore.controller import ObjectType from torchstore.logging import LatencyTracker -from torchstore.transport import Pipe, Request +from torchstore.transport import Pipe, Request, TensorSlice from torchstore.utils import assemble_global_tensor, get_local_tensor logger = getLogger(__name__) @@ -53,7 +51,6 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]): latency_tracker.track_step("notify_put") latency_tracker.track_e2e() - @torch.no_grad async def get( self, @@ -122,6 +119,35 @@ async def keys(self, prefix: str | None = None) -> list[str]: # Keys are synced across all storage volumes, so we just call one. return await self._controller.keys.call_one(prefix) + async def delete(self, key: str) -> None: + """ + Delete a key from the distributed store. + + Args: + key (str): The key to delete. + + Returns: + None + + Raises: + KeyError: If the key does not exist in the store. + """ + latency_tracker = LatencyTracker(f"delete:{key}") + volume_map = await self._controller.locate_volumes.call_one(key) + + async def delete_from_volume(volume_id: str): + volume = self.strategy.get_storage_volume(volume_id) + # Notify should come before the actual delete, so that the controller + # doesn't think the key is still in the store when delete is happening. + await self._controller.notify_delete.call_one(key, volume_id) + await volume.delete.call(key) + + await asyncio.gather( + *[delete_from_volume(volume_id) for volume_id in volume_map] + ) + + latency_tracker.track_e2e() + async def exists(self, key: str) -> bool: """Check if a key exists in the distributed store. @@ -179,7 +205,8 @@ def _verify_get_args( and tensor_slice_spec.local_shape != inplace_tensor.shape ): raise ValueError( - f"Requested tensor slice shape {tensor_slice_spec.local_shape} does not match in-place tensor shape {inplace_tensor.shape}" + f"Requested tensor slice shape {tensor_slice_spec.local_shape} " + f"does not match in-place tensor shape {inplace_tensor.shape}" ) if isinstance(inplace_tensor, DTensor): diff --git a/torchstore/logging.py b/torchstore/logging.py index 40d7e98..781ea44 100644 --- a/torchstore/logging.py +++ b/torchstore/logging.py @@ -9,7 +9,6 @@ import os import sys - def init_logging(): log_level = os.environ.get("TORCHSTORE_LOG_LEVEL", "INFO").upper() diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py index 461120d..1158dc7 100644 --- a/torchstore/state_dict_utils.py +++ b/torchstore/state_dict_utils.py @@ -65,23 +65,6 @@ async def get_state_dict( inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None, ) - # # Prepare all the coroutines first - # coros = [] - # keys = [] - # for flattened_key in fetched_mapping.keys(): - # inplace_tensor = user_flattened_state_dict.get(flattened_key, None) - # keys.append(flattened_key) - # coros.append( - # store.get( - # f"{key}{DELIM}{flattened_key}", - # inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None, - # ) - # ) - # # Run all requests concurrently - # results = await asyncio.gather(*coros) - # # Build the result dictionary - # fetched_state_dict = dict(zip(keys, results)) - return unflatten_state_dict(fetched_state_dict, fetched_mapping) def _state_dict_size(state_dict): diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index e54b010..9cb3e0f 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -6,11 +6,14 @@ from itertools import product from logging import getLogger +from datetime import timedelta from typing import Any, Dict, Optional, Tuple, Union import torch + from monarch.actor import Actor, endpoint +from torchstore.utils import _gloo_factory from torchstore.transport.buffers import TransportBuffer from torchstore.transport.pipe import Request, TensorSlice @@ -31,7 +34,8 @@ def __init__( self, id_func, ) -> None: - self.store: StorageImpl = InMemoryStore() + self.pgs = {} + self.store: StorageImpl = InMemoryStore(self.pgs) self.volume_id: str = id_func() @classmethod @@ -61,9 +65,30 @@ async def get( return await self.store.get(key, transport_buffer, request) @endpoint - async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str]: - return await self.store.get_meta(key) + async def get_meta( + self, + key: str, + request: Optional[Request] = None, + ) -> Union[Tuple[torch.Size, torch.dtype], str]: + return await self.store.get_meta(key, request) + + @endpoint + async def handshake(self, file_store_name): + if file_store_name in self.pgs: + return + logger.info(f"Finalizing handshake from {file_store_name}") + + file_store = torch.distributed.FileStore(file_store_name, 2) + pg = _gloo_factory( + store=file_store, + rank=1, + world_size=2, + timeout=timedelta(seconds=120), + device=torch.device("cpu"), + ) + self.pgs[file_store_name] = pg + logger.info(f"Handshake succesful {file_store_name}") class StorageImpl: """Abstract base class for storage implementations.""" @@ -80,7 +105,7 @@ async def get( """Retrieve data from the storage backend.""" raise NotImplementedError() - async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str]: + async def get_meta(self, key: str, request: Optional[Request]=None) -> Union[Tuple[torch.Size, torch.dtype], str]: """Get metadata about stored data.""" raise NotImplementedError() @@ -88,8 +113,9 @@ async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str] class InMemoryStore(StorageImpl): """Local in memory storage.""" - def __init__(self) -> None: + def __init__(self, pgs) -> None: self.kv: Dict[str, Any] = {} + self.pgs = pgs def _build_full_tensor(self, key: str) -> None: logger.debug(f"Building full tensor for {key}") @@ -176,7 +202,11 @@ async def put( # since we pass tensor=None to the transport buffer, # we allocate on the fly - tensor = await transport_buffer.read_into(tensor=None) + + pg = self.pgs[transport_buffer.file_store_name] + + tensor = await transport_buffer.read_into(tensor=None, pg=pg, r=0) + transport_buffer.finish() if request.tensor_slice is not None: self._handle_dtensor(key, request.tensor_slice, tensor) return @@ -198,7 +228,12 @@ async def get( return transport_buffer if request.tensor_slice is None: - await transport_buffer.write_from(self.kv[key]) + await transport_buffer.write_from( + self.kv[key], + pg=self.pgs[transport_buffer.file_store_name], + r=0 + ) + transport_buffer.finish() return transport_buffer # TODO: @@ -214,19 +249,34 @@ async def get( raise RuntimeError(f"Tensor slice {request.tensor_slice} not found in {key}") - async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str]: + async def get_meta( + self, + key: str, + request: Optional[Request] = None, + ) -> Union[Tuple[torch.Size, torch.dtype], str]: if key not in self.kv: raise KeyError(f"Key '{key}' not found. {list(self.kv.keys())=}") - val = self.kv[key] - if isinstance(val, torch.Tensor): - return val.shape, val.dtype + stored_object = self.kv[key] + if isinstance(stored_object, torch.Tensor): + return stored_object.shape, stored_object.dtype - assert isinstance(val, dict) - if "obj" in val: + assert isinstance(stored_object, dict) + if "obj" in stored_object: return "obj" - if "tensor" in val: - return val["tensor"].shape, val["tensor"].dtype - - raise RuntimeError(f"Unknown type for {key} type={type(val)}") + if "tensor" in stored_object: + return stored_object["tensor"].shape, stored_object["tensor"].dtype + + if request is not None and request.tensor_slice is not None: + #TODO: makes this an object + for shard in stored_object.values(): + shard_slice = shard["slice"] + if shard_slice.local_shape == request.tensor_slice.local_shape and shard_slice.offsets == request.tensor_slice.offsets: + return shard["tensor"].shape, shard["tensor"].dtype + + raise KeyError( + f"Could not find shard slice with {request.tensor_slice=} Slices:{stored_object}" + ) + + raise RuntimeError(f"Unknown type for {key} type={type(val)} {val=}") diff --git a/torchstore/strategy.py b/torchstore/strategy.py index 3076267..5941213 100644 --- a/torchstore/strategy.py +++ b/torchstore/strategy.py @@ -92,7 +92,10 @@ def get_storage_volume(self, volume_id: str) -> StorageVolume: StorageVolume: The storage volume actor for the given ID. """ volume_coord = self.volume_id_to_coord[volume_id] - return self.storage_volumes.slice(**volume_coord) + storage_volume = self.storage_volumes.slice(**volume_coord) + storage_volume.volume_id = volume_id + storage_volume.client_id = self.get_client_id() + return storage_volume class SingletonStrategy(TorchStoreStrategy): diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 12b921b..5400e71 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -6,9 +6,14 @@ import logging import os +import uuid + +from datetime import timedelta from typing import Any, Dict, List, Optional, Tuple, Union + import torch +from torchstore.utils import _gloo_factory try: from monarch.tensor_engine import is_available as monarch_rdma_available, RDMABuffer @@ -20,6 +25,7 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any: "RDMABuffer is not available. This environemnt was likely not built with tensor_engine supoprt." ) +logger = logging.getLogger(__name__) # TODO: for some reason, RDMABuffer is breaking for certain tensors on the HF models (qwen, llama) # but setting this chunk size works around the issue until we can fix it @@ -27,21 +33,14 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any: # Check for misspelled environment variable for backward compatibility rdma_chunk_size_env = os.environ.get("TORCHSTORE_RDMDA_CHUNK_SIZE_MB") -if rdma_chunk_size_env is not None: - logging.warning( - "Using deprecated environment variable 'TORCHSTORE_RDMDA_CHUNK_SIZE_MB'. " - "Please use 'TORCHSTORE_RDMA_CHUNK_SIZE_MB' instead." - ) - RDMA_CHUNK_SIZE_MB: int = int(rdma_chunk_size_env) -else: - RDMA_CHUNK_SIZE_MB: int = int(os.environ.get("TORCHSTORE_RDMA_CHUNK_SIZE_MB", "4")) +RDMA_CHUNK_SIZE_MB: int = int(os.environ.get("TORCHSTORE_RDMA_CHUNK_SIZE_MB", "512")) assert RDMA_CHUNK_SIZE_MB <= 1024, "Monarch does not support 1gb chunks via rdma" def rdma_available() -> bool: rdma_enabled = ( - os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1" + os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1" ) # TODO: enable on this build return rdma_enabled and monarch_rdma_available() @@ -60,7 +59,7 @@ def update(self, other_buffer: "TransportBuffer") -> None: def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: """Allocates internal buffers based on either an existing tensor - or a Tuple of (shape, dtype) + or a Tuple of (shape, dtype) """ raise NotImplementedError() @@ -71,6 +70,107 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError() + +local_pgs = {} +file_store_names = {} + +class TorchDistributedBuffer(TransportBuffer): + + requires_meta: bool = True + def __init__(self) -> None: + self.shape: Optional[torch.Size] = None + self.dtype: Optional[torch.dtype] = None + self.fut: Optional[torch.futures.Future] = None + self.pg: Optional[torch.distributed.ProcessGroup] = None + + def __getstate__(self) -> Dict[str, Any]: + # Any time that we serialize the transport buffer, the idea is + # that tensors will be transported via tensor_enginer.RDMABuffer, so it makes + # no sense to hold this reference when we are serializing + state = self.__dict__.copy() + state["fut"] = None + state["pg"] = None + return state + + async def handshake(self, storage_volume): + if storage_volume.volume_id not in local_pgs: + #TODO: TCPStore + file_store_name = f"/tmp/lpasqualin/comms_test{str(uuid.uuid4())[:8]}" + self.file_store_name = file_store_name + logger.info( + f"Initiating pg handshake between {storage_volume.volume_id}" + f" and {storage_volume.client_id} using id={file_store_name}" + ) + + handshake_fut = storage_volume.handshake.call(file_store_name) + try: + + file_store = torch.distributed.FileStore(file_store_name, 2) + pg = _gloo_factory( + store=file_store, + rank=0, + world_size=2, + timeout=timedelta(seconds=120), + device=torch.device("cpu"), + ) + file_store_names[storage_volume.volume_id] = file_store_name + local_pgs[storage_volume.volume_id] = pg + + finally: + await handshake_fut + + self.pg = local_pgs[storage_volume.volume_id] + self.file_store_name = file_store_names[storage_volume.volume_id] + + + def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: + """Allocates internal buffers based on either an existing tensor + or a Tuple of (shape, dtype) + """ + if isinstance(tensor_like, str) or tensor_like is None: + # tensor is just an object, nothing to allocte + return + elif isinstance(tensor_like, Tuple): + # we know the size of the tensor from fetching metadata + self.shape = tensor_like[0] + self.dtype = tensor_like[1] + else: + # we have an inplace tensor, allocate a copy + assert isinstance(tensor_like, torch.Tensor) + self.shape = tensor_like.shape + self.dtype = tensor_like.dtype + + + # send + async def read_into(self,tensor: Optional[torch.Tensor] = None, pg=None,r=0) -> torch.Tensor: + if tensor is None: + tensor = torch.empty(self.shape, dtype=self.dtype) + + assert self.fut is None + pg = pg or self.pg + self.fut = pg.recv([tensor], srcRank=r, tag=0) + # self.fut = torch.distributed.irecv(tensor, src=0, group=pg) + + return tensor + + # recv + async def write_from(self, tensor: Optional[torch.Tensor],pg=None,r=0) -> None: + assert self.fut is None + pg = pg or self.pg + self.fut = pg.send([tensor], dstRank=r, tag=0) + + def finish(self): + assert self.fut is not None + self.fut.wait() + + # def update(self, other_buffer: "TransportBuffer") -> None: + # super().update(other_buffer) + # self.tensor = other_buffer.tensor + + + + + class RDMATransportBuffer(TransportBuffer): # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor # for now we utilize copies. @@ -113,8 +213,8 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: """Allocates internal buffers based on either an existing tensor or a Tuple of (shape, dtype) """ - logging.debug("Allocating rdma buffer") + #TODO: why is tensor_like a if isinstance(tensor_like, str) or tensor_like is None: # tensor is just an object, nothing to allocte return @@ -124,7 +224,9 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: else: # we have an inplace tensor, allocate a copy assert isinstance(tensor_like, torch.Tensor) - tensor = torch.empty_like(tensor_like) + tensor = torch.empty_like(tensor_like, memory_format=torch.torch.contiguous_format) + + #TODO: do we need this tensor??? # store tensor meta self.shape = tensor.shape @@ -133,14 +235,21 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: self._assert_valid_tensor(tensor) + byte_view_chunks = self._create_byte_views_from_tensor(tensor) - self.tensor_refs = [torch.empty_like(chunk) for chunk in byte_view_chunks] + self.tensor_refs = [ + torch.empty_like(chunk, memory_format=torch.torch.contiguous_format) for chunk in byte_view_chunks + ] + self.rdma_buffers = [RDMABuffer(chunk) for chunk in self.tensor_refs] chunk_sizes = set() - for chunk in self.tensor_refs: + for idx, chunk in enumerate(self.tensor_refs): + chunk_size = chunk.numel() * chunk.element_size() + buffer_size = self.rdma_buffers[idx].size() + assert chunk_size == buffer_size, f"{chunk_size=} != {buffer_size=}" chunk_sizes.add(chunk.shape) - logging.debug(f"Allocted {len(self.rdma_buffers)} rdma buffers {chunk_sizes=}") + logger.debug(f"Allocted {len(self.rdma_buffers)} rdma buffers {chunk_sizes=}") def update(self, other_buffer: "TransportBuffer") -> None: super().update(other_buffer) @@ -155,6 +264,12 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor assert self.rdma_buffers is not None chunked_byte_view = self._create_byte_views_from_tensor(tensor) + chunk_sizes = set() + for chunk in chunked_byte_view: + chunk_sizes.add(chunk.shape) + logger.debug( + f"Read side allocs: {len(self.rdma_buffers)} rdma buffers {chunk_sizes=}" + ) # if we have tensor refs locally, we're still in the local case, # and we're just copying over our chunks into the tensor from @@ -168,9 +283,10 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor # TODO: gather instead of reading sequentially try: for idx, chunk in enumerate(chunked_byte_view): + assert chunk.numel() * chunk.element_size() == self.rdma_buffers[idx].size() await self.rdma_buffers[idx].read_into(chunk) except Exception as e: - logging.exception( + logger.exception( f"Failed read_into, {tensor.shape=}, {tensor.dtype=}", exc_info=e ) raise e @@ -185,6 +301,7 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: self._assert_valid_tensor(tensor) assert self.rdma_buffers is not None + chunked_byte_view = self._create_byte_views_from_tensor(tensor) # if we have tensor refs locally, we're still in the local case, @@ -196,8 +313,14 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: # else: we are in the remote case (in a different process), and must read from # the rdma buffer # TODO: gather instead of reading sequentially - for idx, chunk in enumerate(chunked_byte_view): - await self.rdma_buffers[idx].write_from(chunk) + try: + for idx, chunk in enumerate(chunked_byte_view): + await self.rdma_buffers[idx].write_from(chunk) + except Exception as e: + logger.exception( + f"Failed write_from, {tensor.shape=}, {tensor.dtype=}", exc_info=e + ) + raise e class MonarchTransportBuffer(TransportBuffer): diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index ad77966..178b028 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -18,6 +18,7 @@ rdma_available, RDMATransportBuffer, TransportBuffer, + TorchDistributedBuffer ) logger = getLogger(__name__) @@ -136,41 +137,51 @@ class Pipe: def __init__(self, storage_volume) -> None: self.storage_volume = storage_volume - def create_transport_buffer(self) -> TransportBuffer: + async def create_transport_buffer(self) -> TransportBuffer: # TODO: eventually this should be dependent on the connections available to a storage_volume + + #TODO: + if True: + buffer = TorchDistributedBuffer() + await buffer.handshake(self.storage_volume) + return buffer if rdma_available(): - buffer_cls = RDMATransportBuffer - else: - buffer_cls = MonarchTransportBuffer - return buffer_cls() + return RDMATransportBuffer() + + return MonarchTransportBuffer() + async def put_to_storage_volume(self, key, request: Request): - transport_buffer = self.create_transport_buffer() + transport_buffer = await self.create_transport_buffer() tensor = request.tensor_val transport_buffer.allocate(tensor) - await transport_buffer.write_from(tensor) + await transport_buffer.write_from(tensor, r=1) # transporting tensors is handled by the buffer, so we don't want to send it # via monarch RPC since that would generate considerable overhead await self.storage_volume.put.call_one( key, transport_buffer, request.meta_only() ) + transport_buffer.finish() async def get_from_storage_volume(self, key, request: Request): - transport_buffer = self.create_transport_buffer() + transport_buffer = await self.create_transport_buffer() # Certain buffers (RDMA) need to know the size of the tensor # so we can allocate the right amount of memory locally. # This can be avoided if the request contains a tensor slice. # Could likely be optimized away in the future. if transport_buffer.requires_meta and request.tensor_val is None: - meta = await self.storage_volume.get_meta.call_one(key) + meta = await self.storage_volume.get_meta.call_one(key, request.meta_only()) transport_buffer.allocate(meta) else: transport_buffer.allocate(request.tensor_val) + if isinstance(transport_buffer, TorchDistributedBuffer): + t = await transport_buffer.read_into(request.tensor_val, r=1) + # TODO: consider placing the buffer inside the request or vice versa transport_buffer.update( await self.storage_volume.get.call_one( @@ -181,4 +192,8 @@ async def get_from_storage_volume(self, key, request: Request): if transport_buffer.is_object: return transport_buffer.objects + if isinstance(transport_buffer, TorchDistributedBuffer): + transport_buffer.finish() + return t + return await transport_buffer.read_into(request.tensor_val) diff --git a/torchstore/utils.py b/torchstore/utils.py index 764fce9..80ff90b 100644 --- a/torchstore/utils.py +++ b/torchstore/utils.py @@ -10,6 +10,8 @@ from typing import List, Tuple, TYPE_CHECKING import torch +from torch.distributed import Store, ProcessGroup +from datetime import timedelta from monarch.actor import this_host, ProcMesh @@ -26,7 +28,7 @@ async def spawn_actors(num_processes, actor_cls, name, mesh=None, **init_args): logger.debug("Spawning actors on the local host") mesh = this_host().spawn_procs(per_host={"gpus": num_processes}) await mesh.initialized - actors = await mesh.spawn(f"{name}_{str(uuid.uuid4())[:8]}", actor_cls, **init_args) + actors = mesh.spawn(f"{name}_{str(uuid.uuid4())[:8]}", actor_cls, **init_args) return actors assert isinstance(mesh, ProcMesh) @@ -80,3 +82,32 @@ def assemble_global_tensor( global_tensor[slices] = local_tensor return global_tensor + +def _gloo_factory( + store: Store, + rank: int, + world_size: int, + timeout: timedelta, + device: torch.device, + **kwargs: object, +) -> ProcessGroup: + from torch.distributed import ProcessGroupGloo + + assert len(kwargs) == 0, "Gloo backend received unexpected kwargs" + + backend_class = ProcessGroupGloo(store, rank, world_size, timeout) + backend_class._set_sequence_number_for_group() + + pg = ProcessGroup(store, rank, world_size) + pg._set_default_backend(ProcessGroup.BackendType.GLOO) + + # register devices + pg._register_backend(device, ProcessGroup.BackendType.GLOO, backend_class) + pg._register_backend( + torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class + ) + if torch.cuda.is_available(): + pg._register_backend( + torch.device("cuda"), ProcessGroup.BackendType.GLOO, backend_class + ) + return pg