From 9567b1f1e1f5ca481a39f7070f5e32dddd197069 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Mon, 6 Oct 2025 08:31:26 -0700 Subject: [PATCH 1/8] create functions --- torchstore/state_dict_utils.py | 118 ++++++++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py index 563abb9..14be423 100644 --- a/torchstore/state_dict_utils.py +++ b/torchstore/state_dict_utils.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass from logging import getLogger -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple import torch from torch.distributed.checkpoint._nested_dict import ( @@ -95,3 +96,118 @@ def _state_dict_size(state_dict): size += tensor.numel() * tensor.element_size() return size // (1024 * 1024) + + +@dataclass +class TensorReference: + """Metadata for a tensor in a tensor blob""" + + shape: Tuple[int, ...] + dtype: torch.dtype + offset: int # Byte offset in the blob + size: int # Size in bytes + + +def generate_tensor_blob(state_dict: Dict[str, Any]): + """ + Extract all tensors from state_dict and create a blob. Replace the tensors + with corresponding references and returns a state_dict with only tensor references, + and the tensor blob. + + Args: + state_dict: Dictionary that may contain tensors at any level + + Returns: + - Modified dictionary with tensors replaced by TensorReference objects + - 1D uint8 tensor blob containing all serialized tensor data + """ + + def _extract_recursive( + obj: Dict[str, Any], + tensor_list: List[Tuple[torch.Tensor, TensorReference]], + path: str = "", + ): + """Recursively extract tensors and replace with TensorReference objects""" + if isinstance(obj, torch.Tensor): + # Create placeholder reference (offset will be filled later) + ref = TensorReference( + shape=tuple(obj.shape), + dtype=obj.dtype, + offset=-1, # Will be updated when building blob + size=obj.numel() * obj.element_size(), + ) + tensor_list.append((obj, ref)) + return ref # Replace tensor with TensorReference + elif isinstance(obj, dict): + return { + k: _extract_recursive(v, tensor_list, f"{path}.{k}") + for k, v in obj.items() + } + elif isinstance(obj, (list, tuple)): + return type(obj)( + _extract_recursive(item, tensor_list, f"{path}[{i}]") + for i, item in enumerate(obj) + ) + else: + return obj # Non-tensor data stays as-is + + tensor_list: List[Tuple[torch.Tensor, TensorReference]] = [] + + modified_state_dict = _extract_recursive(state_dict, tensor_list) + + if not tensor_list: + return modified_state_dict, torch.empty(0, dtype=torch.uint8) + + total_bytes = sum([ref.size for _, ref in tensor_list]) + + blob = torch.empty(total_bytes, dtype=torch.uint8) + + # Copy tensor data using your efficient approach + for tensor, ref in tensor_list: + # Handle scalar tensors + tensor_cpu = tensor.detach().cpu() + if tensor_cpu.dim() == 0: + tensor_cpu = tensor_cpu.unsqueeze(0) + + byte_view = tensor_cpu.view(torch.uint8).flatten() + + # Copy to blob + blob[ref.offset : ref.offset + ref.size] = byte_view + + return modified_state_dict, blob + + +def reconstruct_state_dict_from_tensor_blob( + state_dict_with_tensor_refs: Dict[str, Any], blob: torch.Tensor +) -> Dict[str, Any]: + """ + Reconstruct a state_dict which only contains tensor references by + reconstructing the tensors using the tensor blob and the tensor references. + Returns the reconstructed state dict. + """ + + def _reconstruct_recursive(obj): + if isinstance(obj, TensorReference): + # Pre-allocate tensor with correct shape and dtype (TorchStore approach) + tensor = torch.empty(obj.shape, dtype=obj.dtype) + + # Get byte view of the allocated tensor + if tensor.dim() == 0: + tensor_unsqueezed = tensor.unsqueeze(0) + byte_view = tensor_unsqueezed.view(torch.uint8).flatten() + else: + byte_view = tensor.view(torch.uint8).flatten() + + # Copy bytes from blob into tensor's byte view + tensor_bytes = blob[obj.offset : obj.offset + obj.size] + byte_view.copy_(tensor_bytes) + + return tensor + elif isinstance(obj, dict): + return {k: _reconstruct_recursive(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(_reconstruct_recursive(item) for item in obj) + else: + return obj + + return _reconstruct_recursive(state_dict_with_tensor_refs) From 43427f6c3b6b0a8b7812e54e53825104afbdf6e9 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Mon, 6 Oct 2025 08:45:11 -0700 Subject: [PATCH 2/8] add tests --- tests/test_state_dict.py | 192 +++++++++++++++++++++++++++++++++ torchstore/state_dict_utils.py | 8 +- 2 files changed, 198 insertions(+), 2 deletions(-) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index caf3c53..f1351db 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -28,6 +28,11 @@ from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import DTensor +from torchstore.state_dict_utils import ( + generate_tensor_blob, + reconstruct_state_dict_from_tensor_blob, + TensorReference, +) from torchstore.utils import spawn_actors from .utils import main, transport_plus_strategy_params @@ -296,5 +301,192 @@ def _assert_equal_state_dict(state_dict1, state_dict2): ), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}" +def test_generate_tensor_blob(): + """Test generate_tensor_blob with various tensor types and reconstruction.""" + + # Create a state dict with various tensor types and shapes + original_state_dict = { + # Scalar tensor (0D) + "scalar": torch.tensor(42.5, dtype=torch.float32), + # 1D tensors with different dtypes + "vector_float": torch.randn(10, dtype=torch.float32), + "vector_int": torch.randint(0, 100, (5,), dtype=torch.int64), + "vector_half": torch.randn(8, dtype=torch.float16), + # 2D tensors with different dtypes + "matrix_float": torch.randn(3, 4, dtype=torch.float32), + "matrix_double": torch.randn(2, 3, dtype=torch.float64), + "matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32), + # Nested structure + "model": { + "layer1": { + "weight": torch.randn(5, 3, dtype=torch.float32), + "bias": torch.randn(5, dtype=torch.float32), + }, + "layer2": { + "weight": torch.randn(2, 5, dtype=torch.float32), + "bias": torch.randn(2, dtype=torch.float32), + }, + }, + # Mixed with non-tensor data + "metadata": { + "epoch": 10, + "learning_rate": 0.001, + "optimizer_state": torch.randn(3, 3, dtype=torch.float32), + }, + # List with tensors + "tensor_list": [ + torch.randn(2, 2, dtype=torch.float32), + torch.tensor(123, dtype=torch.int32), + ], + } + + # Generate tensor blob + modified_state_dict, blob = generate_tensor_blob(original_state_dict) + + # Verify blob properties + assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}" + assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D" + + # Calculate expected blob size + expected_size = 0 + + def calculate_expected_size(obj): + nonlocal expected_size + if isinstance(obj, torch.Tensor): + expected_size += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + calculate_expected_size(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + calculate_expected_size(item) + + calculate_expected_size(original_state_dict) + assert ( + len(blob) == expected_size + ), f"Expected blob size {expected_size}, got {len(blob)}" + + # Verify that tensors are replaced with TensorReference objects + def verify_tensor_references(obj, path=""): + if isinstance(obj, TensorReference): + assert obj.shape is not None, f"TensorReference at {path} missing shape" + assert obj.dtype is not None, f"TensorReference at {path} missing dtype" + assert ( + obj.offset >= 0 + ), f"TensorReference at {path} has invalid offset {obj.offset}" + assert ( + obj.size > 0 + ), f"TensorReference at {path} has invalid size {obj.size}" + elif isinstance(obj, dict): + for k, v in obj.items(): + verify_tensor_references(v, f"{path}.{k}" if path else k) + elif isinstance(obj, (list, tuple)): + for i, item in enumerate(obj): + verify_tensor_references(item, f"{path}[{i}]") + elif isinstance(obj, torch.Tensor): + raise AssertionError(f"Found unreplaced tensor at {path}") + + verify_tensor_references(modified_state_dict) + + # Verify that non-tensor data is preserved + assert modified_state_dict["metadata"]["epoch"] == 10 + assert modified_state_dict["metadata"]["learning_rate"] == 0.001 + + # Reconstruct the state dict + reconstructed_state_dict = reconstruct_state_dict_from_tensor_blob( + modified_state_dict, blob + ) + + # Verify reconstruction matches original + def compare_state_dicts(original, reconstructed, path=""): + if isinstance(original, torch.Tensor): + assert isinstance(reconstructed, torch.Tensor), f"Expected tensor at {path}" + assert ( + original.shape == reconstructed.shape + ), f"Shape mismatch at {path}: {original.shape} vs {reconstructed.shape}" + assert ( + original.dtype == reconstructed.dtype + ), f"Dtype mismatch at {path}: {original.dtype} vs {reconstructed.dtype}" + assert torch.equal(original, reconstructed), f"Values mismatch at {path}" + elif isinstance(original, dict): + assert isinstance(reconstructed, dict), f"Expected dict at {path}" + assert set(original.keys()) == set( + reconstructed.keys() + ), f"Key mismatch at {path}" + for k in original.keys(): + compare_state_dicts( + original[k], reconstructed[k], f"{path}.{k}" if path else k + ) + elif isinstance(original, (list, tuple)): + assert type(original) == type(reconstructed), f"Type mismatch at {path}" + assert len(original) == len(reconstructed), f"Length mismatch at {path}" + for i, (orig_item, recon_item) in enumerate(zip(original, reconstructed)): + compare_state_dicts(orig_item, recon_item, f"{path}[{i}]") + else: + assert ( + original == reconstructed + ), f"Value mismatch at {path}: {original} vs {reconstructed}" + + compare_state_dicts(original_state_dict, reconstructed_state_dict) + + print("✅ test_generate_tensor_blob passed!") + print( + f" Processed {len([x for x in str(modified_state_dict) if 'TensorReference' in str(x)])} tensors" + ) + print(f" Blob size: {len(blob)} bytes ({len(blob) / 1024:.1f} KB)") + + +def test_generate_tensor_blob_edge_cases(): + """Test edge cases for generate_tensor_blob.""" + + # Test empty state dict + empty_dict = {} + modified, blob = generate_tensor_blob(empty_dict) + assert modified == {} + assert len(blob) == 0 + reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + assert reconstructed == {} + + # Test state dict with no tensors + no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}} + modified, blob = generate_tensor_blob(no_tensors) + assert modified == no_tensors + assert len(blob) == 0 + reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + assert reconstructed == no_tensors + + # Test scalar tensor edge case + scalar_dict = {"scalar": torch.tensor(3.14159)} + modified, blob = generate_tensor_blob(scalar_dict) + assert isinstance(modified["scalar"], TensorReference) + assert modified["scalar"].shape == () # Empty tuple for scalar + reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"]) + + # Test different dtypes + dtype_dict = { + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (5,), dtype=torch.uint8), + "int8": torch.randint(-128, 127, (3,), dtype=torch.int8), + "int16": torch.randint(-1000, 1000, (4,), dtype=torch.int16), + "bfloat16": torch.randn(3, dtype=torch.bfloat16), + } + + modified, blob = generate_tensor_blob(dtype_dict) + reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + + for key in dtype_dict: + assert torch.equal( + dtype_dict[key], reconstructed[key] + ), f"Mismatch for dtype {key}" + + print("✅ test_generate_tensor_blob_edge_cases passed!") + + if __name__ == "__main__": + # Run our new tests + test_generate_tensor_blob() + test_generate_tensor_blob_edge_cases() + + # Run existing tests main(__file__) diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py index 14be423..7fcbb56 100644 --- a/torchstore/state_dict_utils.py +++ b/torchstore/state_dict_utils.py @@ -158,9 +158,13 @@ def _extract_recursive( if not tensor_list: return modified_state_dict, torch.empty(0, dtype=torch.uint8) - total_bytes = sum([ref.size for _, ref in tensor_list]) + # Calculate total size and update offsets + current_offset = 0 + for tensor, ref in tensor_list: + ref.offset = current_offset + current_offset += ref.size - blob = torch.empty(total_bytes, dtype=torch.uint8) + blob = torch.empty(current_offset, dtype=torch.uint8) # Copy tensor data using your efficient approach for tensor, ref in tensor_list: From c1da899344f32e63b9507e7d72755eaedf6c2246 Mon Sep 17 00:00:00 2001 From: Kaiyuan Li Date: Mon, 6 Oct 2025 13:36:12 -0400 Subject: [PATCH 3/8] sync --- tests/test_state_dict.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index f1351db..edfb3cb 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -484,9 +484,5 @@ def test_generate_tensor_blob_edge_cases(): if __name__ == "__main__": - # Run our new tests - test_generate_tensor_blob() - test_generate_tensor_blob_edge_cases() - # Run existing tests main(__file__) From 1b35e01d7593015f62f64b8f5f48e3dbf8e2cecf Mon Sep 17 00:00:00 2001 From: Kai Li Date: Tue, 7 Oct 2025 11:43:53 -0700 Subject: [PATCH 4/8] TorchstoreStateDict --- tests/test_state_dict.py | 222 ++++++++++++++++++--------------- torchstore/state_dict_utils.py | 198 ++++++++++++++--------------- 2 files changed, 217 insertions(+), 203 deletions(-) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index edfb3cb..f1fcba2 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -16,10 +16,10 @@ import torch import torch.distributed.checkpoint as dcp import torch.nn as nn - import torchstore as ts from monarch.actor import Actor, current_rank, endpoint + from torch.distributed.checkpoint._nested_dict import flatten_state_dict from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, @@ -28,11 +28,7 @@ from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import DTensor -from torchstore.state_dict_utils import ( - generate_tensor_blob, - reconstruct_state_dict_from_tensor_blob, - TensorReference, -) +from torchstore.state_dict_utils import TensorReference, TorchStoreStateDict from torchstore.utils import spawn_actors from .utils import main, transport_plus_strategy_params @@ -301,8 +297,8 @@ def _assert_equal_state_dict(state_dict1, state_dict2): ), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}" -def test_generate_tensor_blob(): - """Test generate_tensor_blob with various tensor types and reconstruction.""" +def test_torchstore_state_dict(): + """Test TorchStoreStateDict class with various tensor types and reconstruction.""" # Create a state dict with various tensor types and shapes original_state_dict = { @@ -333,134 +329,152 @@ def test_generate_tensor_blob(): "learning_rate": 0.001, "optimizer_state": torch.randn(3, 3, dtype=torch.float32), }, - # List with tensors - "tensor_list": [ + # List with tensors (note: flattened state dict doesn't preserve list structure) + "layer_weights": [ torch.randn(2, 2, dtype=torch.float32), torch.tensor(123, dtype=torch.int32), ], } - # Generate tensor blob - modified_state_dict, blob = generate_tensor_blob(original_state_dict) + # Create TorchStoreStateDict + torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict) # Verify blob properties + blob = torchstore_state_dict.tensor_blob assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}" assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D" - # Calculate expected blob size - expected_size = 0 - - def calculate_expected_size(obj): - nonlocal expected_size - if isinstance(obj, torch.Tensor): - expected_size += obj.numel() * obj.element_size() - elif isinstance(obj, dict): - for v in obj.values(): - calculate_expected_size(v) - elif isinstance(obj, (list, tuple)): - for item in obj: - calculate_expected_size(item) - - calculate_expected_size(original_state_dict) - assert ( - len(blob) == expected_size - ), f"Expected blob size {expected_size}, got {len(blob)}" - - # Verify that tensors are replaced with TensorReference objects - def verify_tensor_references(obj, path=""): - if isinstance(obj, TensorReference): - assert obj.shape is not None, f"TensorReference at {path} missing shape" - assert obj.dtype is not None, f"TensorReference at {path} missing dtype" + # 1. Flatten original state dict + original_flattened, _ = flatten_state_dict(original_state_dict) + + # 2. Verify keys match between original flattened and torchstore flattened state dict + assert set(original_flattened.keys()) == set( + torchstore_state_dict.flattened_state_dict.keys() + ), "Keys don't match between original and torchstore flattened state dicts" + + # 3. For each key, verify tensor conversion and aggregate total size + total_size = 0 + for key in original_flattened.keys(): + original_value = original_flattened[key] + torchstore_value = torchstore_state_dict.flattened_state_dict[key] + + if isinstance(original_value, torch.Tensor): + # Should be converted to TensorReference + assert isinstance( + torchstore_value, TensorReference + ), f"Expected TensorReference for key {key}, got {type(torchstore_value)}" + + # Verify TensorReference properties + assert torchstore_value.shape == tuple( + original_value.shape + ), f"Shape mismatch for key {key}: {torchstore_value.shape} vs {tuple(original_value.shape)}" + assert ( + torchstore_value.dtype == original_value.dtype + ), f"Dtype mismatch for key {key}: {torchstore_value.dtype} vs {original_value.dtype}" assert ( - obj.offset >= 0 - ), f"TensorReference at {path} has invalid offset {obj.offset}" + torchstore_value.offset >= 0 + ), f"Invalid offset for key {key}: {torchstore_value.offset}" assert ( - obj.size > 0 - ), f"TensorReference at {path} has invalid size {obj.size}" - elif isinstance(obj, dict): - for k, v in obj.items(): - verify_tensor_references(v, f"{path}.{k}" if path else k) - elif isinstance(obj, (list, tuple)): - for i, item in enumerate(obj): - verify_tensor_references(item, f"{path}[{i}]") - elif isinstance(obj, torch.Tensor): - raise AssertionError(f"Found unreplaced tensor at {path}") - - verify_tensor_references(modified_state_dict) - - # Verify that non-tensor data is preserved - assert modified_state_dict["metadata"]["epoch"] == 10 - assert modified_state_dict["metadata"]["learning_rate"] == 0.001 + torchstore_value.size > 0 + ), f"Invalid size for key {key}: {torchstore_value.size}" + + # Aggregate total size + expected_tensor_size = ( + original_value.numel() * original_value.element_size() + ) + assert ( + torchstore_value.size == expected_tensor_size + ), f"Size mismatch for key {key}: {torchstore_value.size} vs {expected_tensor_size}" + total_size += torchstore_value.size + else: + # Non-tensor values should be preserved as-is + assert ( + torchstore_value == original_value + ), f"Non-tensor value mismatch for key {key}: {torchstore_value} vs {original_value}" + + # Verify tensor blob size matches total size + assert ( + len(blob) == total_size + ), f"Tensor blob size {len(blob)} doesn't match expected total size {total_size}" # Reconstruct the state dict - reconstructed_state_dict = reconstruct_state_dict_from_tensor_blob( - modified_state_dict, blob + reconstructed_state_dict = torchstore_state_dict.to_state_dict() + + # Compare flattened versions - simpler than recursive comparison + original_flattened, original_mapping = flatten_state_dict(original_state_dict) + reconstructed_flattened, reconstructed_mapping = flatten_state_dict( + reconstructed_state_dict ) - # Verify reconstruction matches original - def compare_state_dicts(original, reconstructed, path=""): - if isinstance(original, torch.Tensor): - assert isinstance(reconstructed, torch.Tensor), f"Expected tensor at {path}" + # Verify mappings are identical (structure preserved) + assert ( + original_mapping == reconstructed_mapping + ), "State dict structure mappings don't match" + + # Verify keys match + assert set(original_flattened.keys()) == set( + reconstructed_flattened.keys() + ), "Flattened keys don't match" + + # Compare each tensor/value + for key in original_flattened.keys(): + original_value = original_flattened[key] + reconstructed_value = reconstructed_flattened[key] + + if isinstance(original_value, torch.Tensor): + assert isinstance( + reconstructed_value, torch.Tensor + ), f"Expected tensor for key {key}" assert ( - original.shape == reconstructed.shape - ), f"Shape mismatch at {path}: {original.shape} vs {reconstructed.shape}" + original_value.shape == reconstructed_value.shape + ), f"Shape mismatch for key {key}" assert ( - original.dtype == reconstructed.dtype - ), f"Dtype mismatch at {path}: {original.dtype} vs {reconstructed.dtype}" - assert torch.equal(original, reconstructed), f"Values mismatch at {path}" - elif isinstance(original, dict): - assert isinstance(reconstructed, dict), f"Expected dict at {path}" - assert set(original.keys()) == set( - reconstructed.keys() - ), f"Key mismatch at {path}" - for k in original.keys(): - compare_state_dicts( - original[k], reconstructed[k], f"{path}.{k}" if path else k - ) - elif isinstance(original, (list, tuple)): - assert type(original) == type(reconstructed), f"Type mismatch at {path}" - assert len(original) == len(reconstructed), f"Length mismatch at {path}" - for i, (orig_item, recon_item) in enumerate(zip(original, reconstructed)): - compare_state_dicts(orig_item, recon_item, f"{path}[{i}]") + original_value.dtype == reconstructed_value.dtype + ), f"Dtype mismatch for key {key}" + assert torch.equal( + original_value, reconstructed_value + ), f"Values mismatch for key {key}" else: assert ( - original == reconstructed - ), f"Value mismatch at {path}: {original} vs {reconstructed}" - - compare_state_dicts(original_state_dict, reconstructed_state_dict) - - print("✅ test_generate_tensor_blob passed!") - print( - f" Processed {len([x for x in str(modified_state_dict) if 'TensorReference' in str(x)])} tensors" + original_value == reconstructed_value + ), f"Non-tensor value mismatch for key {key}" + + print("✅ test_torchstore_state_dict passed!") + tensor_count = sum( + 1 + for v in torchstore_state_dict.flattened_state_dict.values() + if isinstance(v, TensorReference) ) + print(f" Processed {tensor_count} tensors") print(f" Blob size: {len(blob)} bytes ({len(blob) / 1024:.1f} KB)") -def test_generate_tensor_blob_edge_cases(): - """Test edge cases for generate_tensor_blob.""" +def test_torchstore_state_dict_edge_cases(): + """Test edge cases for TorchStoreStateDict.""" # Test empty state dict empty_dict = {} - modified, blob = generate_tensor_blob(empty_dict) - assert modified == {} - assert len(blob) == 0 - reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + torchstore_state_dict = TorchStoreStateDict.from_state_dict(empty_dict) + assert torchstore_state_dict.flattened_state_dict == {} + assert len(torchstore_state_dict.tensor_blob) == 0 + reconstructed = torchstore_state_dict.to_state_dict() assert reconstructed == {} # Test state dict with no tensors no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}} - modified, blob = generate_tensor_blob(no_tensors) - assert modified == no_tensors - assert len(blob) == 0 - reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + torchstore_state_dict = TorchStoreStateDict.from_state_dict(no_tensors) + assert len(torchstore_state_dict.tensor_blob) == 0 + reconstructed = torchstore_state_dict.to_state_dict() assert reconstructed == no_tensors # Test scalar tensor edge case scalar_dict = {"scalar": torch.tensor(3.14159)} - modified, blob = generate_tensor_blob(scalar_dict) - assert isinstance(modified["scalar"], TensorReference) - assert modified["scalar"].shape == () # Empty tuple for scalar - reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict) + # Check flattened state dict has TensorReference + scalar_ref = torchstore_state_dict.flattened_state_dict["scalar"] + assert isinstance(scalar_ref, TensorReference) + assert scalar_ref.shape == () # Empty tuple for scalar + reconstructed = torchstore_state_dict.to_state_dict() assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"]) # Test different dtypes @@ -472,15 +486,15 @@ def test_generate_tensor_blob_edge_cases(): "bfloat16": torch.randn(3, dtype=torch.bfloat16), } - modified, blob = generate_tensor_blob(dtype_dict) - reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob) + torchstore_state_dict = TorchStoreStateDict.from_state_dict(dtype_dict) + reconstructed = torchstore_state_dict.to_state_dict() for key in dtype_dict: assert torch.equal( dtype_dict[key], reconstructed[key] ), f"Mismatch for dtype {key}" - print("✅ test_generate_tensor_blob_edge_cases passed!") + print("✅ test_torchstore_state_dict_edge_cases passed!") if __name__ == "__main__": diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py index 7fcbb56..2403ea5 100644 --- a/torchstore/state_dict_utils.py +++ b/torchstore/state_dict_utils.py @@ -108,110 +108,110 @@ class TensorReference: size: int # Size in bytes -def generate_tensor_blob(state_dict: Dict[str, Any]): +class TorchStoreStateDict: """ - Extract all tensors from state_dict and create a blob. Replace the tensors - with corresponding references and returns a state_dict with only tensor references, - and the tensor blob. - - Args: - state_dict: Dictionary that may contain tensors at any level - - Returns: - - Modified dictionary with tensors replaced by TensorReference objects - - 1D uint8 tensor blob containing all serialized tensor data + A torchstore representation of a state dict. It contains a flattened state dict and a tensor blob. + All of the tensors in the flattened state dict are replaced with TensorReference objects. """ - def _extract_recursive( - obj: Dict[str, Any], - tensor_list: List[Tuple[torch.Tensor, TensorReference]], - path: str = "", + def __init__( + self, + tensor_blob: torch.Tensor, + flattened_state_dict: Dict[str, Any], + mapping: Dict[str, Any], ): - """Recursively extract tensors and replace with TensorReference objects""" - if isinstance(obj, torch.Tensor): - # Create placeholder reference (offset will be filled later) - ref = TensorReference( - shape=tuple(obj.shape), - dtype=obj.dtype, - offset=-1, # Will be updated when building blob - size=obj.numel() * obj.element_size(), - ) - tensor_list.append((obj, ref)) - return ref # Replace tensor with TensorReference - elif isinstance(obj, dict): - return { - k: _extract_recursive(v, tensor_list, f"{path}.{k}") - for k, v in obj.items() - } - elif isinstance(obj, (list, tuple)): - return type(obj)( - _extract_recursive(item, tensor_list, f"{path}[{i}]") - for i, item in enumerate(obj) - ) - else: - return obj # Non-tensor data stays as-is - - tensor_list: List[Tuple[torch.Tensor, TensorReference]] = [] - - modified_state_dict = _extract_recursive(state_dict, tensor_list) - - if not tensor_list: - return modified_state_dict, torch.empty(0, dtype=torch.uint8) - - # Calculate total size and update offsets - current_offset = 0 - for tensor, ref in tensor_list: - ref.offset = current_offset - current_offset += ref.size - - blob = torch.empty(current_offset, dtype=torch.uint8) - - # Copy tensor data using your efficient approach - for tensor, ref in tensor_list: - # Handle scalar tensors - tensor_cpu = tensor.detach().cpu() - if tensor_cpu.dim() == 0: - tensor_cpu = tensor_cpu.unsqueeze(0) - - byte_view = tensor_cpu.view(torch.uint8).flatten() - - # Copy to blob - blob[ref.offset : ref.offset + ref.size] = byte_view - - return modified_state_dict, blob - - -def reconstruct_state_dict_from_tensor_blob( - state_dict_with_tensor_refs: Dict[str, Any], blob: torch.Tensor -) -> Dict[str, Any]: - """ - Reconstruct a state_dict which only contains tensor references by - reconstructing the tensors using the tensor blob and the tensor references. - Returns the reconstructed state dict. - """ - - def _reconstruct_recursive(obj): - if isinstance(obj, TensorReference): - # Pre-allocate tensor with correct shape and dtype (TorchStore approach) - tensor = torch.empty(obj.shape, dtype=obj.dtype) - - # Get byte view of the allocated tensor - if tensor.dim() == 0: - tensor_unsqueezed = tensor.unsqueeze(0) - byte_view = tensor_unsqueezed.view(torch.uint8).flatten() + """ + Create a TorchStoreStateDict from a tensor blob, flattened state_dict, and mapping. + """ + self.tensor_blob = tensor_blob + self.flattened_state_dict = flattened_state_dict + self.mapping = mapping + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict": + """ + Create a TorchStoreStateDict from a state_dict. All tensors in the state_dict are replaced with + TensorReference objects. The tensor blob is created by concatenating all tensors in the state_dict. + """ + # 1. flatten the state dict + flattened_state_dict, mapping = flatten_state_dict(state_dict) + + # 2. iterate through the flattened state dict, collect all tensors and replace them with TensorReference objects + tensor_list: List[Tuple[torch.Tensor, TensorReference]] = [] + modified_flattened_state_dict = {} + current_offset = 0 + + for key, value in flattened_state_dict.items(): + if isinstance(value, torch.Tensor): + # Calculate size and create reference with correct offset + tensor_size = value.numel() * value.element_size() + ref = TensorReference( + shape=tuple(value.shape), + dtype=value.dtype, + offset=current_offset, + size=tensor_size, + ) + tensor_list.append((value, ref)) + modified_flattened_state_dict[key] = ref + current_offset += tensor_size else: - byte_view = tensor.view(torch.uint8).flatten() - - # Copy bytes from blob into tensor's byte view - tensor_bytes = blob[obj.offset : obj.offset + obj.size] - byte_view.copy_(tensor_bytes) + modified_flattened_state_dict[key] = value - return tensor - elif isinstance(obj, dict): - return {k: _reconstruct_recursive(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): - return type(obj)(_reconstruct_recursive(item) for item in obj) + # 3. create the tensor blob by concatenating all tensors + if not tensor_list: + blob = torch.empty(0, dtype=torch.uint8) else: - return obj + blob = torch.empty(current_offset, dtype=torch.uint8) + + # Copy tensor data + for tensor, ref in tensor_list: + # Handle scalar tensors + tensor_cpu = tensor.detach().cpu() + if tensor_cpu.dim() == 0: + tensor_cpu = tensor_cpu.unsqueeze(0) + + byte_view = tensor_cpu.view(torch.uint8).flatten() + + # Copy to blob + blob[ref.offset : ref.offset + ref.size] = byte_view + + # 4. return the TorchStoreStateDict object + return cls(blob, modified_flattened_state_dict, mapping) + + def to_state_dict(self) -> Dict[str, Any]: + """ + Convert the TorchStoreStateDict back to a state_dict. All TensorReference objects are replaced with + the corresponding tensors from the tensor blob. + """ + # 1. iterate through the flattened state dict, replace TensorReference objects with tensors from the tensor blob + reconstructed_flattened_state_dict = {} + + for key, value in self.flattened_state_dict.items(): + if isinstance(value, TensorReference): + # Pre-allocate tensor with correct shape and dtype (TorchStore approach) + tensor = torch.empty(value.shape, dtype=value.dtype) + + # Get byte view of the allocated tensor + if tensor.dim() == 0: + tensor_unsqueezed = tensor.unsqueeze(0) + byte_view = tensor_unsqueezed.view(torch.uint8).flatten() + else: + byte_view = tensor.view(torch.uint8).flatten() + + # Copy bytes from blob into tensor's byte view + tensor_bytes = self.tensor_blob[ + value.offset : value.offset + value.size + ] + byte_view.copy_(tensor_bytes) + + reconstructed_flattened_state_dict[key] = tensor + else: + reconstructed_flattened_state_dict[key] = value + + # 2. unflatten the state dict + state_dict = unflatten_state_dict( + reconstructed_flattened_state_dict, self.mapping + ) - return _reconstruct_recursive(state_dict_with_tensor_refs) + # 3. return the state dict + return state_dict From f68ce3ffff82b75e4bedc2995d275336d1bc9538 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Tue, 7 Oct 2025 12:30:02 -0700 Subject: [PATCH 5/8] dtensor support --- tests/test_state_dict.py | 186 ++++++++++++++++++++++++++------- torchstore/dtensor_utils.py | 65 ++++++++++++ torchstore/state_dict_utils.py | 48 ++++++++- torchstore/transport/pipe.py | 18 +--- 4 files changed, 259 insertions(+), 58 deletions(-) create mode 100644 torchstore/dtensor_utils.py diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index f1fcba2..e3cb6ef 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -14,6 +14,8 @@ import pytest import torch + +import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn import torchstore as ts @@ -25,9 +27,9 @@ get_model_state_dict, get_optimizer_state_dict, ) -from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import fully_shard -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, Replicate, Shard from torchstore.state_dict_utils import TensorReference, TorchStoreStateDict from torchstore.utils import spawn_actors @@ -39,6 +41,100 @@ MODEL_LINER_LENGTH = 10 +def _setup_process_group(): + """Set up minimal distributed environment for DTensor testing.""" + + if not dist.is_initialized(): + # Set minimal environment variables for single process + import os + + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault( + "MASTER_PORT", "29501" + ) # Different port to avoid conflicts + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + + # Initialize single-process group + dist.init_process_group( + backend="gloo", # CPU backend + rank=0, + world_size=1, + ) + return True + + +def _verify_tensor_references(torchstore_state_dict, flattened_original): + """Utility function to verify TensorReference objects in flattened state dict.""" + for key, original_value in flattened_original.items(): + torchstore_value = torchstore_state_dict.flattened_state_dict[key] + + if isinstance(original_value, torch.Tensor): + if hasattr(original_value, "_local_tensor"): # DTensor check + # DTensor should be converted to TensorReference with tensor_slice + assert isinstance(torchstore_value, TensorReference) + assert ( + torchstore_value.tensor_slice is not None + ), f"DTensor at {key} should have tensor_slice" + assert ( + torchstore_value.device_mesh is not None + ), f"DTensor at {key} should have device_mesh" + assert ( + torchstore_value.placements is not None + ), f"DTensor at {key} should have placements" + + # Verify local tensor metadata + local_tensor = original_value._local_tensor + assert torchstore_value.shape == tuple(local_tensor.shape) + assert torchstore_value.dtype == local_tensor.dtype + else: + # Regular tensor should not have tensor_slice + assert isinstance(torchstore_value, TensorReference) + assert ( + torchstore_value.tensor_slice is None + ), f"Regular tensor at {key} should not have tensor_slice" + assert torchstore_value.shape == tuple(original_value.shape) + assert torchstore_value.dtype == original_value.dtype + + +def _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed): + """Utility function to verify reconstructed state dict matches original.""" + for key, original_value in flattened_original.items(): + reconstructed_value = flattened_reconstructed[key] + + if hasattr(original_value, "_local_tensor"): # DTensor check + # Should be reconstructed as DTensor + assert hasattr( + reconstructed_value, "_local_tensor" + ), f"Expected DTensor for {key}" + + # Verify local tensor data matches + assert torch.equal( + original_value._local_tensor, reconstructed_value._local_tensor + ), f"Local tensor data mismatch for {key}" + + # Verify global shape matches + assert ( + original_value.shape == reconstructed_value.shape + ), f"Global shape mismatch for {key}" + + # Verify placements match + assert ( + original_value.placements == reconstructed_value.placements + ), f"Placements mismatch for {key}" + + elif isinstance(original_value, torch.Tensor): + # Regular tensors should remain the same + assert torch.equal( + original_value, reconstructed_value + ), f"Regular tensor mismatch for {key}" + else: + # Non-tensor values should be preserved + assert ( + original_value == reconstructed_value + ), f"Non-tensor value mismatch for {key}" + + class UnitModule(nn.Module): def __init__(self, device: torch.device): super().__init__() @@ -352,45 +448,19 @@ def test_torchstore_state_dict(): torchstore_state_dict.flattened_state_dict.keys() ), "Keys don't match between original and torchstore flattened state dicts" - # 3. For each key, verify tensor conversion and aggregate total size - total_size = 0 - for key in original_flattened.keys(): - original_value = original_flattened[key] - torchstore_value = torchstore_state_dict.flattened_state_dict[key] + # 3. Verify tensor references and calculate total size + _verify_tensor_references(torchstore_state_dict, original_flattened) + # Calculate total size for blob verification + total_size = 0 + for key, original_value in original_flattened.items(): if isinstance(original_value, torch.Tensor): - # Should be converted to TensorReference - assert isinstance( - torchstore_value, TensorReference - ), f"Expected TensorReference for key {key}, got {type(torchstore_value)}" - - # Verify TensorReference properties - assert torchstore_value.shape == tuple( - original_value.shape - ), f"Shape mismatch for key {key}: {torchstore_value.shape} vs {tuple(original_value.shape)}" - assert ( - torchstore_value.dtype == original_value.dtype - ), f"Dtype mismatch for key {key}: {torchstore_value.dtype} vs {original_value.dtype}" - assert ( - torchstore_value.offset >= 0 - ), f"Invalid offset for key {key}: {torchstore_value.offset}" - assert ( - torchstore_value.size > 0 - ), f"Invalid size for key {key}: {torchstore_value.size}" - - # Aggregate total size - expected_tensor_size = ( - original_value.numel() * original_value.element_size() + tensor_to_size = ( + original_value._local_tensor + if hasattr(original_value, "_local_tensor") + else original_value ) - assert ( - torchstore_value.size == expected_tensor_size - ), f"Size mismatch for key {key}: {torchstore_value.size} vs {expected_tensor_size}" - total_size += torchstore_value.size - else: - # Non-tensor values should be preserved as-is - assert ( - torchstore_value == original_value - ), f"Non-tensor value mismatch for key {key}: {torchstore_value} vs {original_value}" + total_size += tensor_to_size.numel() * tensor_to_size.element_size() # Verify tensor blob size matches total size assert ( @@ -494,7 +564,45 @@ def test_torchstore_state_dict_edge_cases(): dtype_dict[key], reconstructed[key] ), f"Mismatch for dtype {key}" - print("✅ test_torchstore_state_dict_edge_cases passed!") + +def test_torchstore_state_dict_with_dtensor(): + """Test TorchStoreStateDict with DTensor support.""" + _setup_process_group() + + # Create single-device mesh (CPU only) + device_mesh = DeviceMesh("cpu", [0]) + + # Create DTensor from local tensor + local_tensor = torch.randn(4, 6, dtype=torch.float32) + dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()]) + + # Create state dict with DTensor and regular tensor + original_state_dict = { + "regular_tensor": torch.randn(3, 3), + "dtensor": dtensor, + "nested": { + "another_dtensor": DTensor.from_local( + torch.ones(2, 3), device_mesh, [Replicate()] + ), + "metadata": {"test": "value"}, + }, + } + + # Test serialization + torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict) + + # Verify DTensor metadata is preserved using utility function + flattened_original, _ = flatten_state_dict(original_state_dict) + _verify_tensor_references(torchstore_state_dict, flattened_original) + + # Test deserialization + reconstructed_state_dict = torchstore_state_dict.to_state_dict() + + # Verify reconstruction using utility function + flattened_reconstructed, _ = flatten_state_dict(reconstructed_state_dict) + _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed) + + dist.destroy_process_group() if __name__ == "__main__": diff --git a/torchstore/dtensor_utils.py b/torchstore/dtensor_utils.py new file mode 100644 index 0000000..64100d3 --- /dev/null +++ b/torchstore/dtensor_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from torch.distributed.tensor import DTensor, Placement +from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset + + +def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice": + """ + Create a TensorSlice from a DTensor. + + Args: + dtensor: The DTensor to extract metadata from + + Returns: + TensorSlice containing the distributed tensor metadata + """ + from torchstore.transport.pipe import TensorSlice + + coordinates = dtensor.device_mesh.get_coordinate() + _, offsets = _compute_local_shape_and_global_offset( + dtensor.shape, + mesh_shape=dtensor.device_mesh.shape, + my_coordinate=coordinates, + placements=dtensor.placements, + ) + + return TensorSlice( + offsets=offsets, + coordinates=coordinates, + global_shape=dtensor.shape, + local_shape=dtensor._local_tensor.shape, + mesh_shape=dtensor.device_mesh.shape, + ) + + +def reconstruct_dtensor_from_local_tensor( + local_tensor: torch.Tensor, + tensor_slice: "TensorSlice", + device_mesh: torch.distributed.DeviceMesh, + placements: Tuple[Placement, ...], +) -> DTensor: + """ + Reconstruct a DTensor from local tensor data and TensorSlice metadata. + + Args: + local_tensor: The local tensor shard + tensor_slice: TensorSlice containing distributed metadata + device_mesh: The device mesh for the DTensor + placements: The placements for the DTensor + + Returns: + Reconstructed DTensor + """ + return DTensor.from_local( + local_tensor=local_tensor, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py index 2403ea5..f1f2141 100644 --- a/torchstore/state_dict_utils.py +++ b/torchstore/state_dict_utils.py @@ -13,6 +13,10 @@ flatten_state_dict, unflatten_state_dict, ) +from torch.distributed.tensor import DTensor + +from torchstore.dtensor_utils import create_tensor_slice_from_dtensor +from torchstore.transport.pipe import TensorSlice DELIM = "/" MAPPING = "MAPPING" @@ -106,6 +110,11 @@ class TensorReference: dtype: torch.dtype offset: int # Byte offset in the blob size: int # Size in bytes + tensor_slice: Optional[TensorSlice] = None # TensorSlice for DTensor reconstruction + device_mesh: Optional[Any] = None # DeviceMesh for DTensor reconstruction + placements: Optional[Tuple[Any, ...]] = ( + None # Placements for DTensor reconstruction + ) class TorchStoreStateDict: @@ -142,8 +151,26 @@ def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict": current_offset = 0 for key, value in flattened_state_dict.items(): - if isinstance(value, torch.Tensor): - # Calculate size and create reference with correct offset + if isinstance(value, DTensor): + # Handle DTensor: store local tensor and add TensorSlice metadata + local_tensor = value._local_tensor + tensor_size = local_tensor.numel() * local_tensor.element_size() + tensor_slice = create_tensor_slice_from_dtensor(value) + + ref = TensorReference( + shape=tuple(local_tensor.shape), + dtype=local_tensor.dtype, + offset=current_offset, + size=tensor_size, + tensor_slice=tensor_slice, + device_mesh=value.device_mesh, + placements=value.placements, + ) + tensor_list.append((local_tensor, ref)) + modified_flattened_state_dict[key] = ref + current_offset += tensor_size + elif isinstance(value, torch.Tensor): + # Handle regular tensor tensor_size = value.numel() * value.element_size() ref = TensorReference( shape=tuple(value.shape), @@ -181,8 +208,10 @@ def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict": def to_state_dict(self) -> Dict[str, Any]: """ Convert the TorchStoreStateDict back to a state_dict. All TensorReference objects are replaced with - the corresponding tensors from the tensor blob. + the corresponding tensors from the tensor blob. DTensors are reconstructed using stored metadata. """ + from torchstore.dtensor_utils import reconstruct_dtensor_from_local_tensor + # 1. iterate through the flattened state dict, replace TensorReference objects with tensors from the tensor blob reconstructed_flattened_state_dict = {} @@ -204,6 +233,19 @@ def to_state_dict(self) -> Dict[str, Any]: ] byte_view.copy_(tensor_bytes) + # Check if this should be reconstructed as a DTensor + if ( + value.tensor_slice is not None + and value.device_mesh is not None + and value.placements is not None + ): + tensor = reconstruct_dtensor_from_local_tensor( + local_tensor=tensor, + tensor_slice=value.tensor_slice, + device_mesh=value.device_mesh, + placements=value.placements, + ) + reconstructed_flattened_state_dict[key] = tensor else: reconstructed_flattened_state_dict[key] = value diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index f0d94fd..1e0f5ef 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -11,8 +11,8 @@ import torch from torch.distributed.tensor import DTensor -from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset +from torchstore.dtensor_utils import create_tensor_slice_from_dtensor from torchstore.transport.buffers import ( MonarchTransportBuffer, rdma_available, @@ -84,21 +84,7 @@ def from_any(cls, value: torch.Tensor | DTensor | None) -> "Request": @classmethod def from_dtensor(cls, dtensor: DTensor) -> "Request": - coordinates = dtensor.device_mesh.get_coordinate() - _, offsets = _compute_local_shape_and_global_offset( - dtensor.shape, - mesh_shape=dtensor.device_mesh.shape, - my_coordinate=coordinates, - placements=dtensor.placements, - ) - - tensor_slice = TensorSlice( - offsets, - coordinates, - dtensor.shape, - dtensor._local_tensor.shape, - dtensor.device_mesh.shape, - ) + tensor_slice = create_tensor_slice_from_dtensor(dtensor) return cls( tensor_val=dtensor._local_tensor, tensor_slice=tensor_slice, From 25ca59b413d0b7c5a25aea1589558d7c662bb8d7 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Tue, 7 Oct 2025 12:37:43 -0700 Subject: [PATCH 6/8] sync --- tests/test_state_dict.py | 35 +++-------------------------------- 1 file changed, 3 insertions(+), 32 deletions(-) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index e3cb6ef..a4de47f 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -29,7 +29,7 @@ ) from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import fully_shard -from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor import DTensor, Replicate from torchstore.state_dict_utils import TensorReference, TorchStoreStateDict from torchstore.utils import spawn_actors @@ -486,37 +486,8 @@ def test_torchstore_state_dict(): reconstructed_flattened.keys() ), "Flattened keys don't match" - # Compare each tensor/value - for key in original_flattened.keys(): - original_value = original_flattened[key] - reconstructed_value = reconstructed_flattened[key] - - if isinstance(original_value, torch.Tensor): - assert isinstance( - reconstructed_value, torch.Tensor - ), f"Expected tensor for key {key}" - assert ( - original_value.shape == reconstructed_value.shape - ), f"Shape mismatch for key {key}" - assert ( - original_value.dtype == reconstructed_value.dtype - ), f"Dtype mismatch for key {key}" - assert torch.equal( - original_value, reconstructed_value - ), f"Values mismatch for key {key}" - else: - assert ( - original_value == reconstructed_value - ), f"Non-tensor value mismatch for key {key}" - - print("✅ test_torchstore_state_dict passed!") - tensor_count = sum( - 1 - for v in torchstore_state_dict.flattened_state_dict.values() - if isinstance(v, TensorReference) - ) - print(f" Processed {tensor_count} tensors") - print(f" Blob size: {len(blob)} bytes ({len(blob) / 1024:.1f} KB)") + # Verify reconstruction using utility function + _verify_reconstructed_state_dict(original_flattened, reconstructed_flattened) def test_torchstore_state_dict_edge_cases(): From 24d46ac5019f1e5af18d227bd093ac5867175971 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Tue, 7 Oct 2025 12:43:36 -0700 Subject: [PATCH 7/8] sync --- torchstore/state_dict_utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/torchstore/state_dict_utils.py b/torchstore/state_dict_utils.py index f1f2141..0fdde47 100644 --- a/torchstore/state_dict_utils.py +++ b/torchstore/state_dict_utils.py @@ -110,11 +110,9 @@ class TensorReference: dtype: torch.dtype offset: int # Byte offset in the blob size: int # Size in bytes - tensor_slice: Optional[TensorSlice] = None # TensorSlice for DTensor reconstruction - device_mesh: Optional[Any] = None # DeviceMesh for DTensor reconstruction - placements: Optional[Tuple[Any, ...]] = ( - None # Placements for DTensor reconstruction - ) + tensor_slice: TensorSlice | None = None # TensorSlice for DTensor reconstruction + device_mesh: Any | None = None # DeviceMesh for DTensor reconstruction + placements: Tuple[Any, ...] | None = None # Placements for DTensor reconstruction class TorchStoreStateDict: @@ -129,9 +127,6 @@ def __init__( flattened_state_dict: Dict[str, Any], mapping: Dict[str, Any], ): - """ - Create a TorchStoreStateDict from a tensor blob, flattened state_dict, and mapping. - """ self.tensor_blob = tensor_blob self.flattened_state_dict = flattened_state_dict self.mapping = mapping From 298eb5d76d054f8f9bc4a1dfd962354f25c31660 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Thu, 9 Oct 2025 12:27:18 -0700 Subject: [PATCH 8/8] sync --- tests/test_state_dict.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index a4de47f..0a0d1fe 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -496,15 +496,12 @@ def test_torchstore_state_dict_edge_cases(): # Test empty state dict empty_dict = {} torchstore_state_dict = TorchStoreStateDict.from_state_dict(empty_dict) - assert torchstore_state_dict.flattened_state_dict == {} - assert len(torchstore_state_dict.tensor_blob) == 0 reconstructed = torchstore_state_dict.to_state_dict() assert reconstructed == {} # Test state dict with no tensors no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}} torchstore_state_dict = TorchStoreStateDict.from_state_dict(no_tensors) - assert len(torchstore_state_dict.tensor_blob) == 0 reconstructed = torchstore_state_dict.to_state_dict() assert reconstructed == no_tensors @@ -512,9 +509,6 @@ def test_torchstore_state_dict_edge_cases(): scalar_dict = {"scalar": torch.tensor(3.14159)} torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict) # Check flattened state dict has TensorReference - scalar_ref = torchstore_state_dict.flattened_state_dict["scalar"] - assert isinstance(scalar_ref, TensorReference) - assert scalar_ref.shape == () # Empty tuple for scalar reconstructed = torchstore_state_dict.to_state_dict() assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"]) @@ -577,5 +571,4 @@ def test_torchstore_state_dict_with_dtensor(): if __name__ == "__main__": - # Run existing tests main(__file__)