diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py
index caf3c53..8cb1688 100644
--- a/tests/test_state_dict.py
+++ b/tests/test_state_dict.py
@@ -14,20 +14,27 @@
 import pytest
 import torch
+
+import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-
 import torchstore as ts
 from monarch.actor import Actor, current_rank, endpoint
+
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_optimizer_state_dict,
 )
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import fully_shard
-from torch.distributed.tensor import DTensor
+from torch.distributed.tensor import DTensor, Replicate
+from torchstore.state_dict_utils import (
+    TensorReference,
+    TORCHSTORE_TSSD_ENABLED_FLAG,
+    TorchStoreStateDict,
+)
 from torchstore.utils import spawn_actors
 
 from .utils import main, transport_plus_strategy_params
@@ -38,6 +45,100 @@ MODEL_LINER_LENGTH = 10
+
+def _setup_process_group():
+    """Set up a minimal distributed environment for DTensor testing."""
+    if dist.is_initialized():
+        return
+
+    # Minimal environment variables for a single process; use a non-default
+    # port to avoid conflicts with other tests.
+    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+    os.environ.setdefault("MASTER_PORT", "29501")
+    os.environ.setdefault("RANK", "0")
+    os.environ.setdefault("WORLD_SIZE", "1")
+
+    # Initialize a single-process group on the gloo (CPU) backend
+    dist.init_process_group(backend="gloo", rank=0, world_size=1)
+
+
+def _verify_tensor_references(torchstore_state_dict, flattened_original):
+    """Verify the TensorReference objects in a flattened state dict."""
+    for key, original_value in flattened_original.items():
+        torchstore_value = torchstore_state_dict.flattened_state_dict[key]
+
+        if isinstance(original_value, torch.Tensor):
+            if hasattr(original_value, "_local_tensor"):  # DTensor check
+                # DTensor should be converted to a TensorReference with tensor_slice
+                assert isinstance(torchstore_value, TensorReference)
+                assert (
+                    torchstore_value.tensor_slice is not None
+                ), f"DTensor at {key} should have tensor_slice"
+                assert (
+                    torchstore_value.device_mesh is not None
+                ), f"DTensor at {key} should have device_mesh"
+                assert (
+                    torchstore_value.placements is not None
+                ), f"DTensor at {key} should have placements"
+
+                # Verify local tensor metadata
+                local_tensor = original_value._local_tensor
+                assert torchstore_value.shape == tuple(local_tensor.shape)
+                assert torchstore_value.dtype == local_tensor.dtype
+            else:
+                # Regular tensor should not have tensor_slice
+                assert isinstance(torchstore_value, TensorReference)
+                assert (
+                    torchstore_value.tensor_slice is None
+                ), f"Regular tensor at {key} should not have tensor_slice"
+                assert torchstore_value.shape == tuple(original_value.shape)
+                assert torchstore_value.dtype == original_value.dtype
+
+
+def _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed):
+    """Verify that a reconstructed state dict matches the original."""
+    for key, original_value in flattened_original.items():
+        reconstructed_value = flattened_reconstructed[key]
+
+        if hasattr(original_value, "_local_tensor"):  # DTensor check
+            # Should be reconstructed as a DTensor
+            assert hasattr(
+                reconstructed_value, "_local_tensor"
+            ), f"Expected DTensor for {key}"
+
+            # Verify local tensor data matches
+            assert torch.equal(
+                original_value._local_tensor, reconstructed_value._local_tensor
+            ), f"Local tensor data mismatch for {key}"
+
+            # Verify global shape matches
+            assert (
+                original_value.shape == reconstructed_value.shape
+            ), f"Global shape mismatch for {key}"
+
+            # Verify placements match
+            assert (
+                original_value.placements == reconstructed_value.placements
+            ), f"Placements mismatch for {key}"
+
+        elif isinstance(original_value, torch.Tensor):
+            # Regular tensors should match exactly
+            assert torch.equal(
+                original_value, reconstructed_value
+            ), f"Regular tensor mismatch for {key}"
+        else:
+            # Non-tensor values should be preserved
+            assert (
+                original_value == reconstructed_value
+            ), f"Non-tensor value mismatch for {key}"
+
+
 class UnitModule(nn.Module):
     def __init__(self, device: torch.device):
         super().__init__()
@@ -167,8 +268,9 @@ async def do_get(self):
 
 @pytest.mark.parametrize(*transport_plus_strategy_params())
 @pytest.mark.asyncio
 async def test_state_dict(strategy_params, use_rdma):
     os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
+    os.environ[TORCHSTORE_TSSD_ENABLED_FLAG] = "1"
 
     class Trainer(Actor):
         # Monarch RDMA does not work outside of an actor, so we need
@@ -296,5 +398,182 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
     ), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}"
 
 
+def test_torchstore_state_dict():
+    """Test TorchStoreStateDict with various tensor types and reconstruction."""
+
+    # Create a state dict with various tensor types and shapes
+    original_state_dict = {
+        # Scalar tensor (0D)
+        "scalar": torch.tensor(42.5, dtype=torch.float32),
+        # 1D tensors with different dtypes
+        "vector_float": torch.randn(10, dtype=torch.float32),
+        "vector_int": torch.randint(0, 100, (5,), dtype=torch.int64),
+        "vector_half": torch.randn(8, dtype=torch.float16),
+        # 2D tensors with different dtypes
+        "matrix_float": torch.randn(3, 4, dtype=torch.float32),
+        "matrix_double": torch.randn(2, 3, dtype=torch.float64),
+        "matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32),
+        # Nested structure
+        "model": {
+            "layer1": {
+                "weight": torch.randn(5, 3, dtype=torch.float32),
+                "bias": torch.randn(5, dtype=torch.float32),
+            },
+            "layer2": {
+                "weight": torch.randn(2, 5, dtype=torch.float32),
+                "bias": torch.randn(2, dtype=torch.float32),
+            },
+        },
+        # Mixed with non-tensor data
+        "metadata": {
+            "epoch": 10,
+            "learning_rate": 0.001,
+            "optimizer_state": torch.randn(3, 3, dtype=torch.float32),
+        },
+        # List with tensors (note: flatten_state_dict indexes into lists rather
+        # than preserving the list structure)
+        "layer_weights": [
+            torch.randn(2, 2, dtype=torch.float32),
+            torch.tensor(123, dtype=torch.int32),
+        ],
+    }
+
+    # Create the TorchStoreStateDict
+    torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)
+
+    # Verify blob properties
+    blob = torchstore_state_dict.tensor_blob
+    assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}"
+    assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D"
+
+    # 1. Flatten the original state dict
+    original_flattened, _ = flatten_state_dict(original_state_dict)
+
+    # 2. Verify keys match between the original and torchstore flattened state dicts
+    assert set(original_flattened.keys()) == set(
+        torchstore_state_dict.flattened_state_dict.keys()
+    ), "Keys don't match between original and torchstore flattened state dicts"
+
+    # 3. Verify tensor references and calculate the total size
+    _verify_tensor_references(torchstore_state_dict, original_flattened)
+
+    # Calculate the expected total size for blob verification
+    total_size = 0
+    for original_value in original_flattened.values():
+        if isinstance(original_value, torch.Tensor):
+            tensor_to_size = (
+                original_value._local_tensor
+                if hasattr(original_value, "_local_tensor")
+                else original_value
+            )
+            total_size += tensor_to_size.numel() * tensor_to_size.element_size()
+
+    # Verify the tensor blob size matches the total size
+    assert (
+        len(blob) == total_size
+    ), f"Tensor blob size {len(blob)} doesn't match expected total size {total_size}"
+
+    # Reconstruct the state dict
+    reconstructed_state_dict = torchstore_state_dict.to_state_dict()
+
+    # Compare flattened versions - simpler than a recursive comparison
+    original_flattened, original_mapping = flatten_state_dict(original_state_dict)
+    reconstructed_flattened, reconstructed_mapping = flatten_state_dict(
+        reconstructed_state_dict
+    )
+
+    # Verify mappings are identical (structure preserved)
+    assert (
+        original_mapping == reconstructed_mapping
+    ), "State dict structure mappings don't match"
+
+    # Verify keys match
+    assert set(original_flattened.keys()) == set(
+        reconstructed_flattened.keys()
+    ), "Flattened keys don't match"
+
+    # Verify reconstruction using the utility function
+    _verify_reconstructed_state_dict(original_flattened, reconstructed_flattened)
+
+
+def test_torchstore_state_dict_edge_cases():
+    """Test edge cases for TorchStoreStateDict."""
+
+    # Test an empty state dict
+    empty_dict = {}
+    torchstore_state_dict = TorchStoreStateDict.from_state_dict(empty_dict)
+    reconstructed = torchstore_state_dict.to_state_dict()
+    assert reconstructed == {}
+
+    # Test a state dict with no tensors
+    no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}}
+    torchstore_state_dict = TorchStoreStateDict.from_state_dict(no_tensors)
+    reconstructed = torchstore_state_dict.to_state_dict()
+    assert reconstructed == no_tensors
+
+    # Test the scalar (0D) tensor edge case
+    scalar_dict = {"scalar": torch.tensor(3.14159)}
+    torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict)
+    # Check that the flattened state dict holds a TensorReference
+    assert isinstance(
+        torchstore_state_dict.flattened_state_dict["scalar"], TensorReference
+    )
+    reconstructed = torchstore_state_dict.to_state_dict()
+    assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"])
+
+    # Test different dtypes
+    dtype_dict = {
+        "bool": torch.tensor([True, False, True]),
+        "uint8": torch.randint(0, 255, (5,), dtype=torch.uint8),
+        "int8": torch.randint(-128, 127, (3,), dtype=torch.int8),
+        "int16": torch.randint(-1000, 1000, (4,), dtype=torch.int16),
+        "bfloat16": torch.randn(3, dtype=torch.bfloat16),
+    }
+
+    torchstore_state_dict = TorchStoreStateDict.from_state_dict(dtype_dict)
+    reconstructed = torchstore_state_dict.to_state_dict()
+
+    for key in dtype_dict:
+        assert torch.equal(
+            dtype_dict[key], reconstructed[key]
+        ), f"Mismatch for dtype {key}"
+
+
+def test_torchstore_state_dict_with_dtensor():
+    """Test TorchStoreStateDict with DTensor support."""
+    _setup_process_group()
+
+    # Create a single-device mesh (CPU only)
+    device_mesh = DeviceMesh("cpu", [0])
+
+    # Create a DTensor from a local tensor
+    local_tensor = torch.randn(4, 6, dtype=torch.float32)
+    dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])
+
+    # Create a state dict with a DTensor and a regular tensor
+    original_state_dict = {
+        "regular_tensor": torch.randn(3, 3),
+        "dtensor": dtensor,
+        "nested": {
+            "another_dtensor": DTensor.from_local(
+                torch.ones(2, 3), device_mesh, [Replicate()]
+            ),
+            "metadata": {"test": "value"},
+        },
+    }
+
+    # Test serialization
+    torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)
+
+    # Verify DTensor metadata is preserved using the utility function
+    flattened_original, _ = flatten_state_dict(original_state_dict)
+    _verify_tensor_references(torchstore_state_dict, flattened_original)
+
+    # Test deserialization
+    reconstructed_state_dict = torchstore_state_dict.to_state_dict()
+
+    # Verify reconstruction using the utility function
+    flattened_reconstructed, _ = flatten_state_dict(reconstructed_state_dict)
+    _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed)
+
+    dist.destroy_process_group()
+
+
 if __name__ == "__main__":
     main(__file__)
diff --git a/torchstore/client.py b/torchstore/client.py
index 52a6c16..7973249 100644
--- a/torchstore/client.py
+++ b/torchstore/client.py
@@ -13,6 +13,7 @@
 from torchstore.controller import ObjectType
 from torchstore.logging import LatencyTracker
+from torchstore.state_dict_utils import DELIM, FLATTENED_STATE_DICT, get_state_dict_key
 from torchstore.transport import Pipe, Request, TensorSlice
 from torchstore.utils import assemble_global_tensor, get_local_tensor
 
@@ -54,7 +55,18 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):
         await pipe.put_to_storage_volume(key, request)
         latency_tracker.track_step("put_to_storage_volume")
 
-        await self._controller.notify_put.call(key, request.meta_only(), volume_id)
+        if key.endswith(DELIM + FLATTENED_STATE_DICT):
+            # TSSD put: register every flattened entry with the controller so that
+            # later gets can route to the per-tensor keys written by the storage volume.
+            state_dict_key = get_state_dict_key(key)
+            for flattened_key in value.keys():
+                notify_key = f"{state_dict_key}{DELIM}{flattened_key}"
+                await self._controller.notify_put.call(
+                    notify_key,
+                    request.meta_only(),
+                    volume_id,
+                )
+        else:
+            await self._controller.notify_put.call(key, request.meta_only(), volume_id)
+
         latency_tracker.track_step("notify_put")
 
         latency_tracker.track_e2e()
diff --git a/torchstore/dtensor_utils.py b/torchstore/dtensor_utils.py
new file mode 100644
index 0000000..64100d3
--- /dev/null
+++ b/torchstore/dtensor_utils.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple, TYPE_CHECKING
+
+import torch
+from torch.distributed.tensor import DTensor, Placement
+from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
+
+if TYPE_CHECKING:
+    from torchstore.transport.pipe import TensorSlice
+
+
+def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
+    """
+    Create a TensorSlice from a DTensor.
+
+    Args:
+        dtensor: The DTensor to extract metadata from
+
+    Returns:
+        TensorSlice containing the distributed tensor metadata
+    """
+    # Imported here to avoid a circular import: transport.pipe imports this module.
+    from torchstore.transport.pipe import TensorSlice
+
+    coordinates = dtensor.device_mesh.get_coordinate()
+    _, offsets = _compute_local_shape_and_global_offset(
+        dtensor.shape,
+        mesh_shape=dtensor.device_mesh.shape,
+        my_coordinate=coordinates,
+        placements=dtensor.placements,
+    )
+
+    return TensorSlice(
+        offsets=offsets,
+        coordinates=coordinates,
+        global_shape=dtensor.shape,
+        local_shape=dtensor._local_tensor.shape,
+        mesh_shape=dtensor.device_mesh.shape,
+    )
+
+
+def reconstruct_dtensor_from_local_tensor(
+    local_tensor: torch.Tensor,
+    tensor_slice: "TensorSlice",
+    device_mesh: torch.distributed.DeviceMesh,
+    placements: Tuple[Placement, ...],
+) -> DTensor:
+    """
+    Reconstruct a DTensor from local tensor data and TensorSlice metadata.
+
+    Args:
+        local_tensor: The local tensor shard
+        tensor_slice: TensorSlice containing distributed metadata (currently unused
+            by this helper; DTensor.from_local only needs the mesh and placements)
+        device_mesh: The device mesh for the DTensor
+        placements: The placements for the DTensor
+
+    Returns:
+        Reconstructed DTensor
+    """
+    return DTensor.from_local(
+        local_tensor=local_tensor,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py
index 563abb9..e10ac3e 100644
--- a/torchstore/state_dict_utils.py
+++ b/torchstore/state_dict_utils.py
@@ -4,85 +4,175 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+from dataclasses import dataclass
 from logging import getLogger
-from typing import Optional
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import torch
 from torch.distributed.checkpoint._nested_dict import (
     flatten_state_dict,
     unflatten_state_dict,
 )
+from torch.distributed.tensor import DTensor
+
+from torchstore.dtensor_utils import (
+    create_tensor_slice_from_dtensor,
+    reconstruct_dtensor_from_local_tensor,
+)
+from torchstore.transport.pipe import TensorSlice
 
 DELIM = "/"
-MAPPING = "MAPPING"
+
+# Reserved key segments for torchstore-internal handling of state dicts
+MAPPING = "__TORCHSTORE_STATE_DICT_MAPPING"
+TENSOR_BLOB = "__TORCHSTORE_STATE_DICT_TENSOR_BLOB"
+FLATTENED_STATE_DICT = "__TORCHSTORE_STATE_DICT_FLATTENED_STATE_DICT"
+
+TORCHSTORE_TSSD_ENABLED_FLAG = "TORCHSTORE_TSSD_ENABLED"
 
 logger = getLogger(__name__)
 
 
-async def put_state_dict(store, state_dict, key):
+def tssd_enabled() -> bool:
+    """
+    Check whether TorchStoreStateDict (TSSD) is enabled for put and get. When enabled,
+    the tensors in a state dict are batched into a single blob and transferred
+    together, which is more efficient than one transfer per tensor.
+    """
+    return os.environ.get(TORCHSTORE_TSSD_ENABLED_FLAG, "0") == "1"
+
+
+def is_tssd_key(key: str) -> bool:
+    """
+    Check whether a key is a TorchStoreStateDict-internal key, i.e. whether it ends
+    in one of the reserved segments above.
+    """
+    return (
+        key.endswith(DELIM + MAPPING)
+        or key.endswith(DELIM + TENSOR_BLOB)
+        or key.endswith(DELIM + FLATTENED_STATE_DICT)
+    )
+
+
+def tssd_keys(state_dict_key: str) -> Set[str]:
+    """
+    Return all TorchStoreStateDict-internal keys derived from a state dict key.
+    Args:
+        state_dict_key: The key of the whole state dict, without any internal segments.
+    """
+    return {
+        state_dict_key + DELIM + MAPPING,
+        state_dict_key + DELIM + TENSOR_BLOB,
+        state_dict_key + DELIM + FLATTENED_STATE_DICT,
+    }
+
+
+def get_state_dict_key(key: str) -> str:
     """
-    We have an option here. Either we can "flatten state dict", by turning state dict names into a single key,
-    or I can actually just maintain the dictionary representation of the state dict, and we can allow some
-    recursive behavior in the store.
+    Recover the key of the whole state dict from a TorchStoreStateDict-internal key.
+    Args:
+        key: An internal key, i.e. the state dict key plus a trailing reserved segment.
+    """
+    # Strip only the trailing segment: the user's state dict key may itself contain DELIM.
+    return key.rsplit(DELIM, 1)[0]
+
 
-    Overall, this might not even be something we want to solve for in the TorchStore, but I'm adding this
-    utility so we can test sharding models.
+async def put_state_dict(store, state_dict, key):
+    """
+    Store a state dict using either the original per-tensor method or TorchStoreStateDict.
+    Args:
+        store: The torchstore instance to store data in
+        state_dict: The state dictionary to store
+        key: The key prefix to store under
     """
-    flattened_state_dict, mapping = flatten_state_dict(state_dict)
-    for flattened_key, value in flattened_state_dict.items():
-        await store.put(f"{key}{DELIM}{flattened_key}", value)
+    if tssd_enabled():
+        # Use the TorchStoreStateDict method for efficient tensor serialization
+        torchstore_state_dict = TorchStoreStateDict.from_state_dict(state_dict)
+
+        # Store the tensor blob
+        await store.put(f"{key}{DELIM}{TENSOR_BLOB}", torchstore_state_dict.tensor_blob)
+
+        # Store the flattened state dict (contains TensorReferences and non-tensor data)
+        await store.put(
+            f"{key}{DELIM}{FLATTENED_STATE_DICT}",
+            torchstore_state_dict.flattened_state_dict,
+        )
+
+        # Store the mapping last (it serves as the completion indicator)
+        await store.put(f"{key}{DELIM}{MAPPING}", torchstore_state_dict.mapping)
+    else:
+        # Original method: flatten and store each tensor individually
+        flattened_state_dict, mapping = flatten_state_dict(state_dict)
+        for flattened_key, value in flattened_state_dict.items():
+            await store.put(f"{key}{DELIM}{flattened_key}", value)
 
-    await store.put(f"{key}{DELIM}{MAPPING}", mapping)
+        await store.put(f"{key}{DELIM}{MAPPING}", mapping)
 
 
 async def get_state_dict(
-    store, key, user_state_dict: Optional[dict] = None, strict=True
+    store,
+    key,
+    user_state_dict: Optional[dict] = None,
+    strict=True,
 ):
-    """Unflatten the state dict from the store"""
+    """
+    Get a state dict from the store using either the original per-tensor method or
+    TorchStoreStateDict.
+    Args:
+        store: The torchstore instance to get data from
+        key: The key prefix to retrieve from
+        user_state_dict: Optional user state dict providing in-place tensors and a
+            mapping to validate against
+        strict: Whether to strictly validate mappings
+    """
     try:
-        # Since the mapping is the last thing we write out, it also gaurantees the state dict is not pending
+        # Since the mapping is the last thing we write out, it also guarantees the state dict is not pending
         fetched_mapping = await store.get(f"{key}{DELIM}{MAPPING}")
     except Exception as e:
         raise RuntimeError(
             f"Mapping is missing from the store. This most likely means there is no matching 'push' call for this key: {key=}"
         ) from e
 
-    user_flattened_state_dict, user_mapping = (
-        flatten_state_dict(user_state_dict)
-        if user_state_dict is not None
-        else ({}, None)
-    )
-    if strict and user_mapping is not None:
-        assert user_mapping == fetched_mapping
-
-    fetched_state_dict = {}
-    for flattened_key in fetched_mapping.keys():
-        inplace_tensor = user_flattened_state_dict.get(flattened_key, None)
-        fetched_state_dict[flattened_key] = await store.get(
-            f"{key}{DELIM}{flattened_key}",
-            inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None,
+    if tssd_enabled():
+        # Use the TorchStoreStateDict method for efficient retrieval. Note that
+        # in-place tensors from user_state_dict are not used on this path.
+        try:
+            # Get the tensor blob and the flattened state dict
+            tensor_blob = await store.get(f"{key}{DELIM}{TENSOR_BLOB}")
+            flattened_state_dict = await store.get(
+                f"{key}{DELIM}{FLATTENED_STATE_DICT}"
+            )
+
+            # Reconstruct the TorchStoreStateDict and convert it back to a state dict
+            torchstore_state_dict = TorchStoreStateDict(
+                tensor_blob=tensor_blob,
+                flattened_state_dict=flattened_state_dict,
+                mapping=fetched_mapping,
+            )
+
+            return torchstore_state_dict.to_state_dict()
+
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to retrieve TorchStoreStateDict data for key: {key=}"
+            ) from e
+    else:
+        # Original method: get each tensor individually
+        user_flattened_state_dict, user_mapping = (
+            flatten_state_dict(user_state_dict)
+            if user_state_dict is not None
+            else ({}, None)
         )
+        if strict and user_mapping is not None:
+            assert user_mapping == fetched_mapping
 
-    # # Prepare all the coroutines first
-    # coros = []
-    # keys = []
-    # for flattened_key in fetched_mapping.keys():
-    #     inplace_tensor = user_flattened_state_dict.get(flattened_key, None)
-    #     keys.append(flattened_key)
-    #     coros.append(
-    #         store.get(
-    #             f"{key}{DELIM}{flattened_key}",
-    #             inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None,
-    #         )
-    #     )
-    # # Run all requests concurrently
-    # results = await asyncio.gather(*coros)
-    # # Build the result dictionary
-    # fetched_state_dict = dict(zip(keys, results))
-
-    return unflatten_state_dict(fetched_state_dict, fetched_mapping)
+        fetched_state_dict = {}
+        for flattened_key in fetched_mapping.keys():
+            inplace_tensor = user_flattened_state_dict.get(flattened_key, None)
+            fetched_state_dict[flattened_key] = await store.get(
+                f"{key}{DELIM}{flattened_key}",
+                inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None,
+            )
+
+        return unflatten_state_dict(fetched_state_dict, fetched_mapping)
 
 
 def _state_dict_size(state_dict):
@@ -95,3 +185,162 @@ def _state_dict_size(state_dict):
         size += tensor.numel() * tensor.element_size()
 
     return size // (1024 * 1024)
+
+
+@dataclass
+class TensorReference:
+    """Metadata for a tensor stored in a tensor blob."""
+
+    shape: Tuple[int, ...]
+    dtype: torch.dtype
+    offset: int  # Byte offset in the blob
+    size: int  # Size in bytes
+    tensor_slice: TensorSlice | None = None  # TensorSlice for DTensor reconstruction
+    device_mesh: Any | None = None  # DeviceMesh for DTensor reconstruction
+    placements: Tuple[Any, ...] | None = None  # Placements for DTensor reconstruction
+
+
+class TorchStoreStateDict:
+    """
+    A torchstore representation of a state dict. It holds a flattened state dict and a
+    tensor blob; all tensors in the flattened state dict are replaced with TensorReference objects.
+ """ + + def __init__( + self, + tensor_blob: torch.Tensor, + flattened_state_dict: Dict[str, Any], + mapping: Dict[str, Any], + ): + self.tensor_blob = tensor_blob + self.flattened_state_dict = flattened_state_dict + self.mapping = mapping + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict": + """ + Create a TorchStoreStateDict from a state_dict. All tensors in the state_dict are replaced with + TensorReference objects. The tensor blob is created by concatenating all tensors in the state_dict. + """ + # 1. flatten the state dict + flattened_state_dict, mapping = flatten_state_dict(state_dict) + + # 2. iterate through the flattened state dict, collect all tensors and replace them with TensorReference objects + tensor_list: List[Tuple[torch.Tensor, TensorReference]] = [] + modified_flattened_state_dict = {} + current_offset = 0 + + for key, value in flattened_state_dict.items(): + if isinstance(value, DTensor): + # Handle DTensor: store local tensor and add TensorSlice metadata + local_tensor = value._local_tensor + tensor_size = local_tensor.numel() * local_tensor.element_size() + tensor_slice = create_tensor_slice_from_dtensor(value) + + ref = TensorReference( + shape=tuple(local_tensor.shape), + dtype=local_tensor.dtype, + offset=current_offset, + size=tensor_size, + tensor_slice=tensor_slice, + device_mesh=value.device_mesh, + placements=value.placements, + ) + tensor_list.append((local_tensor, ref)) + modified_flattened_state_dict[key] = ref + current_offset += tensor_size + elif isinstance(value, torch.Tensor): + # Handle regular tensor + tensor_size = value.numel() * value.element_size() + ref = TensorReference( + shape=tuple(value.shape), + dtype=value.dtype, + offset=current_offset, + size=tensor_size, + ) + tensor_list.append((value, ref)) + modified_flattened_state_dict[key] = ref + current_offset += tensor_size + else: + modified_flattened_state_dict[key] = value + + # 3. create the tensor blob by concatenating all tensors + if not tensor_list: + blob = torch.empty(0, dtype=torch.uint8) + else: + blob = torch.empty(current_offset, dtype=torch.uint8) + + # Copy tensor data + for tensor, ref in tensor_list: + # Handle scalar tensors + tensor_cpu = tensor.detach().cpu() + if tensor_cpu.dim() == 0: + tensor_cpu = tensor_cpu.unsqueeze(0) + + byte_view = tensor_cpu.view(torch.uint8).flatten() + + # Copy to blob + blob[ref.offset : ref.offset + ref.size] = byte_view + + # 4. return the TorchStoreStateDict object + return cls(blob, modified_flattened_state_dict, mapping) + + def to_state_dict(self) -> Dict[str, Any]: + """ + Convert the TorchStoreStateDict back to a state_dict. All TensorReference objects are replaced with + the corresponding tensors from the tensor blob. DTensors are reconstructed using stored metadata. + """ + state_dict = unflatten_state_dict( + deref_flattened_state_dict(self.flattened_state_dict, self.tensor_blob), + self.mapping, + ) + + # 3. return the state dict + return state_dict + + +def deref_flattened_state_dict( + flattened_state_dict: Dict[str, Any], + tensor_blob: torch.Tensor, +) -> Dict[str, Any]: + from torchstore.dtensor_utils import reconstruct_dtensor_from_local_tensor + + """ + Dereference a flattened state dict. All TensorReference objects are replaced with + the corresponding tensors from the tensor blob. 
+ """ + derefed_flattened_state_dict = {} + + for key, value in flattened_state_dict.items(): + if isinstance(value, TensorReference): + # Pre-allocate tensor with correct shape and dtype (TorchStore approach) + tensor = torch.empty(value.shape, dtype=value.dtype) + + # Get byte view of the allocated tensor + if tensor.dim() == 0: + tensor_unsqueezed = tensor.unsqueeze(0) + byte_view = tensor_unsqueezed.view(torch.uint8).flatten() + else: + byte_view = tensor.view(torch.uint8).flatten() + + # Copy bytes from blob into tensor's byte view + tensor_bytes = tensor_blob[value.offset : value.offset + value.size] + byte_view.copy_(tensor_bytes) + + # Check if this should be reconstructed as a DTensor + if ( + value.tensor_slice is not None + and value.device_mesh is not None + and value.placements is not None + ): + tensor = reconstruct_dtensor_from_local_tensor( + local_tensor=tensor, + tensor_slice=value.tensor_slice, + device_mesh=value.device_mesh, + placements=value.placements, + ) + + derefed_flattened_state_dict[key] = tensor + else: + derefed_flattened_state_dict[key] = value + return derefed_flattened_state_dict diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index 355fea5..9bc7a57 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -11,6 +11,17 @@ import torch from monarch.actor import Actor, endpoint +from torchstore.state_dict_utils import ( + DELIM, + deref_flattened_state_dict, + FLATTENED_STATE_DICT, + get_state_dict_key, + is_tssd_key, + TENSOR_BLOB, + tssd_enabled, + tssd_keys, +) + from torchstore.transport.buffers import TransportBuffer from torchstore.transport.pipe import Request, TensorSlice @@ -174,7 +185,7 @@ def _has_full_tensor(self, key: str) -> bool: return True - def _handle_dtensor( + def _handle_put_dtensor( self, key: str, tensor_slice: TensorSlice, tensor: torch.Tensor ) -> None: if key not in self.kv: @@ -185,9 +196,58 @@ def _handle_dtensor( "tensor": tensor, } + async def _handle_put_tssd( + self, key: str, transport_buffer: TransportBuffer, request: Request + ) -> None: + # 1. verify again that this is a tssd key + if not is_tssd_key(key): + raise ValueError( + f"{key} is not an internal key for Torchstore state dict handling" + ) + + # 2. early return if not all necessary parts are available + state_dict_key = get_state_dict_key(key) + if not tssd_keys(state_dict_key).issubset(self.kv.keys()): + return + + # 3. update flattened state dict by replacing tensor ref with actual tensor + tensor_blob_key = f"{state_dict_key}{DELIM}{TENSOR_BLOB}" + flattened_state_dict_key = f"{state_dict_key}{DELIM}{FLATTENED_STATE_DICT}" + derefed_flattened_state_dict = deref_flattened_state_dict( + # TODO consolidate "obj" + self.kv[flattened_state_dict_key]["obj"], + self.kv[tensor_blob_key], + ) + + # 4. clean up blob and state dict with tensor refs + del self.kv[tensor_blob_key] + del self.kv[flattened_state_dict_key] + + # 5. put every flattened key entry into the kv store + subkeys = list(derefed_flattened_state_dict.keys()) + for subkey in subkeys: + value = derefed_flattened_state_dict.pop(subkey) + subkey = f"{state_dict_key}{DELIM}{subkey}" + self.kv[subkey] = value + async def put( self, key: str, transport_buffer: TransportBuffer, request: Request ) -> None: + print(f"putting {key} {request=}") + + await self.put_impl(key=key, transport_buffer=transport_buffer, request=request) + + # handle state dict put with tensor blob. 
+        if tssd_enabled() and is_tssd_key(key):
+            await self._handle_put_tssd(key, transport_buffer, request)
+
+    async def put_impl(
+        self, key: str, transport_buffer: TransportBuffer, request: Request
+    ) -> None:
+        """
+        Put an object / tensor / DTensor into the kv store.
+        """
         if request.is_object:
             self.kv[key] = {"obj": request.objects}
             return
@@ -196,7 +256,7 @@
         # we allocate on the fly
         tensor = await transport_buffer.read_into(tensor=None)
         if request.tensor_slice is not None:
-            self._handle_dtensor(key, request.tensor_slice, tensor)
+            self._handle_put_dtensor(key, request.tensor_slice, tensor)
             return
 
         self.kv[key] = tensor
diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py
index f0d94fd..1e0f5ef 100644
--- a/torchstore/transport/pipe.py
+++ b/torchstore/transport/pipe.py
@@ -11,8 +11,8 @@
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
 
+from torchstore.dtensor_utils import create_tensor_slice_from_dtensor
 from torchstore.transport.buffers import (
     MonarchTransportBuffer,
     rdma_available,
@@ -84,21 +84,7 @@ def from_any(cls, value: torch.Tensor | DTensor | None) -> "Request":
 
     @classmethod
     def from_dtensor(cls, dtensor: DTensor) -> "Request":
-        coordinates = dtensor.device_mesh.get_coordinate()
-        _, offsets = _compute_local_shape_and_global_offset(
-            dtensor.shape,
-            mesh_shape=dtensor.device_mesh.shape,
-            my_coordinate=coordinates,
-            placements=dtensor.placements,
-        )
-
-        tensor_slice = TensorSlice(
-            offsets,
-            coordinates,
-            dtensor.shape,
-            dtensor._local_tensor.shape,
-            dtensor.device_mesh.shape,
-        )
+        tensor_slice = create_tensor_slice_from_dtensor(dtensor)
         return cls(
             tensor_val=dtensor._local_tensor,
             tensor_slice=tensor_slice,
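
Reviewer note - not part of the patch: a minimal usage sketch of the TSSD round trip
introduced above. It assumes a torchstore handle exposing the async put/get used by
put_state_dict/get_state_dict (the `store` argument); the key "trainer_v0" is a
hypothetical example.

    import os

    import torch

    from torchstore.state_dict_utils import (
        get_state_dict,
        put_state_dict,
        TORCHSTORE_TSSD_ENABLED_FLAG,
    )

    async def tssd_roundtrip(store):
        # Opt in to the batched tensor-blob path; the default remains per-tensor puts.
        os.environ[TORCHSTORE_TSSD_ENABLED_FLAG] = "1"

        state_dict = {"layer": {"weight": torch.randn(4, 4), "bias": torch.randn(4)}}

        # Three puts total (tensor blob, flattened references, mapping) instead of
        # one put per tensor. The mapping is written last, so its presence marks
        # the put as complete.
        await put_state_dict(store, state_dict, "trainer_v0")

        fetched = await get_state_dict(store, "trainer_v0")
        assert torch.equal(fetched["layer"]["weight"], state_dict["layer"]["weight"])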
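
And a small self-check sketch of the reserved key scheme, assuming the rsplit-based
get_state_dict_key above; it shows why only the trailing reserved segment may be
stripped when a user key itself contains the '/' delimiter.

    from torchstore.state_dict_utils import (
        DELIM,
        get_state_dict_key,
        is_tssd_key,
        tssd_keys,
    )

    # A user key that itself contains the delimiter.
    user_key = f"run_7{DELIM}trainer"

    for internal_key in tssd_keys(user_key):
        # Each derived key ends in one of the reserved segments...
        assert is_tssd_key(internal_key)
        # ...and stripping only that trailing segment recovers the full user key.
        assert get_state_dict_key(internal_key) == user_key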