diff --git a/example/dtensor.py b/example/dtensor.py
index e2db226..8c301a7 100644
--- a/example/dtensor.py
+++ b/example/dtensor.py
@@ -122,9 +122,9 @@ async def dtensor_put_get_example():
     puts it with Shard(0) and gets it with Shard(1).
     """
     # Configuration variables
-    size = 3  # 100 unit size => 2.4 MB Tensor Size
-    n_put_actors = 8
-    n_get_actors = 8
+    size = 1  # 100 unit size => 2.4 MB Tensor Size
+    n_put_actors = 2
+    n_get_actors = 1
 
     print("Starting DTensor put/get example with:")
     print(f" size = {size}")
diff --git a/tests/test_models.py b/tests/test_models.py
index 680cf7e..a087e3d 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -94,8 +94,8 @@ def build_model(self):
 
     def rlog(self, msg):
         print(f"rank: {self.rank} {msg}")
-        self.logger.info(f"rank: {self.rank} {msg}")
-        logger.info(f"rank: {self.rank} {msg}")
+        # self.logger.info(f"rank: {self.rank} {msg}")
+        # logger.info(f"rank: {self.rank} {msg}")
 
     @endpoint
     async def do_push(self):
@@ -123,19 +123,38 @@ async def do_get(self):
         if self.world_size > 1:
             torch.distributed.barrier()
 
+        import time
+
+        t = time.perf_counter()
+        await ts.get_state_dict("v0", state_dict)
+        self.rlog(f"BEFORE defrag got state dict in {time.perf_counter() - t} seconds")
+        keys = await ts.keys("v0/model")
+
+        start = time.perf_counter()
+        count = 0
+        for key in keys:
+            try:
+                await ts.defrag(key)
+                count += 1
+            except Exception as e:
+                print(f"Exception in defrag for key {key}: {e}")
+                continue
+
+        self.rlog(f"defrag {count} keys took {time.perf_counter() - start} seconds")
+
         self.rlog("getting state dict")
         t = time.perf_counter()
         await ts.get_state_dict("v0", state_dict)
-        self.rlog(f"got state dict in {time.perf_counter() - t} seconds")
+        self.rlog(f"AFTER defrag got state dict in {time.perf_counter() - t} seconds")
 
 
 @pytest.mark.parametrize(*transport_plus_strategy_params())
 @pytest.mark.asyncio
 async def test_basic(strategy_params, use_rdma):
     # FSDP
-    put_mesh_shape = (1,)
+    put_mesh_shape = (8,)
     get_mesh_shape = (1,)
-    await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
+    await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], True)
 
 
 @pytest.mark.parametrize(*transport_plus_strategy_params())
diff --git a/tests/utils.py b/tests/utils.py
index 4f5d7e2..0f4c08a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -28,9 +28,9 @@ def main(file):
 
 def transport_plus_strategy_params():
     strategies = [
-        (2, ts.LocalRankStrategy()),
+        # (2, ts.LocalRankStrategy()),
         (1, None),  # ts.SingletonStrategy
-        (1, ts.ControllerStorageVolumes()),
+        # (1, ts.ControllerStorageVolumes()),
     ]
     rdma_options = (
         [True, False]
diff --git a/torchstore/__init__.py b/torchstore/__init__.py
index c2ada94..2911f67 100644
--- a/torchstore/__init__.py
+++ b/torchstore/__init__.py
@@ -9,6 +9,7 @@
 
 from torchstore.api import (
     client,
+    defrag,
     delete,
     exists,
     get,
@@ -45,6 +46,7 @@
     "put",
     "get",
     "delete",
+    "defrag",
     "keys",
     "exists",
     "client",
diff --git a/torchstore/api.py b/torchstore/api.py
index 7f3d0ac..802d0c0 100644
--- a/torchstore/api.py
+++ b/torchstore/api.py
@@ -207,6 +207,19 @@ async def get(
     return await cl.get(key, inplace_tensor, tensor_slice_spec)
 
 
+async def defrag(key: str) -> None:
+    """Defragment the tensor stored under `key`.
+
+    This triggers a defragmentation pass for `key` on its storage volume, assembling the stored
+    tensor shards into a single shard so that subsequent gets read one contiguous tensor.
+
+    Example:
+        >>> await defrag("my_key")
+    """
+    cl = await client()
+    return await cl.defrag(key)
+
+
 async def delete(
     key: str,
     *,
diff --git a/torchstore/client.py b/torchstore/client.py
index d2acb83..6894446 100644
--- a/torchstore/client.py
+++ b/torchstore/client.py
@@ -19,6 +19,38 @@
 logger = getLogger(__name__)
 
 
+def convert_to_single_shard_request(full_tensor: torch.Tensor) -> Request:
+    """
+    Convert a full tensor to a Request object that represents it as a single-shard DTensor.
+
+    This creates a Request with proper tensor_val and tensor_slice fields to represent
+    the full tensor as if it were a DTensor with a single shard containing all the data.
+    This avoids the overhead of actually creating a DTensor with distributed initialization.
+
+    Args:
+        full_tensor: The complete assembled tensor
+
+    Returns:
+        Request: A Request object representing a single-shard DTensor
+    """
+    # Create a tensor slice that represents the entire tensor as a single shard
+    tensor_slice = TensorSlice(
+        offsets=(0,) * len(full_tensor.shape),  # Start at origin for all dimensions
+        coordinates=(0,),  # Single device at coordinate (0,)
+        global_shape=full_tensor.shape,  # Global shape is the full tensor shape
+        local_shape=full_tensor.shape,  # Local shape equals global (single shard)
+        mesh_shape=(1,),  # Single device mesh
+    )
+
+    # Create and return the Request object
+    return Request(
+        tensor_val=full_tensor,
+        tensor_slice=tensor_slice,
+        objects=None,
+        is_object=False,
+    )
+
+
 class LocalClient:
     """This class represents the local store, which exists on every process.
     Remote storage is handled by the client.
@@ -96,6 +128,10 @@ async def get(
             Request.from_any(inplace_tensor).tensor_slice or tensor_slice_spec
         )
         # Here full tensor should be the part of interest.
+        if key == "v0/model.model.norm.weight":
+            print(
+                f"\033[92mgetting tensor slice {tensor_slice} for key {key}\033[0m"
+            )
         fetched_tensor = await self._get_and_assemble_tensor(key, tensor_slice)
 
         # Pipe does not have support for inplace copies of fetched tensors yet,
@@ -128,6 +164,24 @@ async def keys(self, prefix: str | None = None) -> list[str]:
         # Keys are synced across all storage volumes, so we just call one.
         return await self._controller.keys.call_one(prefix)
 
+    async def defrag(self, key: str) -> None:
+        # Check that the stored value is a tensor slice; raise if not.
+        stored_object_type = await self._get_stored_object_type(key)
+
+        if stored_object_type is not ObjectType.TENSOR_SLICE:
+            raise ValueError(
+                f"Cannot defragment key `{key}` because value type is {stored_object_type}, expected TENSOR_SLICE"
+            )
+
+        # Defragment on the storage volume, then put the single-shard representation back.
+        storage_volume, volume_id = self.strategy.select_storage_volume()
+        await self._controller.notify_delete.call_one(key, volume_id)
+        tensor_slice = await storage_volume.defrag.call_one(key)
+        if key == "v0/model.model.norm.weight":
+            print(f"tensor_slice: {tensor_slice}")
+        request = Request.from_tensor_slice(tensor_slice)
+        await self._controller.notify_put.call(key, request, volume_id)
+
     async def delete(self, key: str) -> None:
         """
        Delete a key from the distributed store.
@@ -266,10 +320,15 @@ async def _get_and_assemble_tensor(
             The assembled tensor from all storage volumes
         """
         volume_map = await self._locate_volumes(key)
+        if key == "v0/model.model.norm.weight":
+            print(f"\033[92mvolume map for key {key}: {volume_map}\033[0m")
         # Handle the tensor case
         partial_results = []
         for volume_id, storage_info in volume_map.items():
             storage_volume = self.strategy.get_storage_volume(volume_id)
+            if key == "v0/model.model.norm.weight":
+                print(f"storage volume: {storage_volume}")
+                # print(f"stored val: {storage_volume.store.kv[key]}")
             pipe = Pipe(storage_volume)
 
             # fetch from all storage volumes, something like this
diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py
index 71fc41e..63e967f 100644
--- a/torchstore/storage_volume.py
+++ b/torchstore/storage_volume.py
@@ -64,6 +64,10 @@ async def get(
     ) -> TransportBuffer:
         return await self.store.get(key, transport_buffer, request)
 
+    @endpoint
+    async def defrag(self, key: str) -> TensorSlice:
+        return await self.store.defrag(key)
+
     @endpoint
     async def get_meta(
         self,
@@ -96,6 +100,10 @@ async def get(
         """Retrieve data from the storage backend."""
         raise NotImplementedError()
 
+    async def defrag(self, key: str) -> TensorSlice:
+        """Defragment the tensor slices stored under `key` into a single tensor slice."""
+        raise NotImplementedError()
+
     async def get_meta(
         self, key: str, request: Optional[Request] = None
     ) -> Union[Tuple[torch.Size, torch.dtype], str]:
@@ -189,6 +197,44 @@ def _handle_dtensor(
             "tensor": tensor,
         }
 
+    async def defrag(self, key: str) -> TensorSlice:
+        # Gather the local tensors, global shape, and global offsets from kv[key]
+        local_tensors = []
+        global_offsets = []
+        global_shape = None
+        for shard in self.kv[key].values():
+
+            local_tensors.append(shard["tensor"])
+            tensor_shard = shard["slice"]
+
+            global_offsets.append(tensor_shard.offsets)
+            if global_shape is None:
+                global_shape = tensor_shard.global_shape
+            else:
+                assert global_shape == tensor_shard.global_shape
+
+        full_tensor = assemble_tensor(
+            local_tensors,
+            global_shape,
+            global_offsets,
+        )
+
+        # Convert the assembled tensor to a single-shard tensor and store it in kv[key]
+        tensor_slice = TensorSlice(
+            offsets=(0,) * len(full_tensor.shape),  # Start at origin for all dimensions
+            coordinates=(0,),  # Single device at coordinate (0,)
+            global_shape=full_tensor.shape,  # Global shape is the full tensor shape
+            local_shape=full_tensor.shape,  # Local shape equals global (single shard)
+            mesh_shape=(1,),  # Single device mesh
+        )
+        self.kv[key] = {
+            tensor_slice.coordinates: {
+                "slice": tensor_slice,
+                "tensor": full_tensor,
+            }
+        }
+        return tensor_slice
+
     async def put(
         self, key: str, transport_buffer: TransportBuffer, request: Request
     ) -> None:
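
Usage sketch for the defrag API added above, assuming an initialized torchstore client and that the keys under the chosen prefix hold sharded tensors; the prefix and helper name here are illustrative only:

import torchstore as ts


async def defrag_prefix(prefix: str = "v0/model") -> None:
    # List every key under the prefix, then consolidate each key's shards
    # into a single slice so later gets read one contiguous tensor.
    for key in await ts.keys(prefix):
        try:
            await ts.defrag(key)
        except Exception as exc:
            # Mirrors the test above: values that are not tensor slices raise, so skip them.
            print(f"skipping {key}: {exc}")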