Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions tests/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torchstore.state_dict_utils import (
generate_tensor_blob,
reconstruct_state_dict_from_tensor_blob,
TensorReference,
)
from torchstore.utils import spawn_actors

from .utils import main, transport_plus_strategy_params
Expand Down Expand Up @@ -296,5 +301,188 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}"


def test_generate_tensor_blob():
"""Test generate_tensor_blob with various tensor types and reconstruction."""

# Create a state dict with various tensor types and shapes
original_state_dict = {
# Scalar tensor (0D)
"scalar": torch.tensor(42.5, dtype=torch.float32),
# 1D tensors with different dtypes
"vector_float": torch.randn(10, dtype=torch.float32),
"vector_int": torch.randint(0, 100, (5,), dtype=torch.int64),
"vector_half": torch.randn(8, dtype=torch.float16),
# 2D tensors with different dtypes
"matrix_float": torch.randn(3, 4, dtype=torch.float32),
"matrix_double": torch.randn(2, 3, dtype=torch.float64),
"matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32),
# Nested structure
"model": {
"layer1": {
"weight": torch.randn(5, 3, dtype=torch.float32),
"bias": torch.randn(5, dtype=torch.float32),
},
"layer2": {
"weight": torch.randn(2, 5, dtype=torch.float32),
"bias": torch.randn(2, dtype=torch.float32),
},
},
# Mixed with non-tensor data
"metadata": {
"epoch": 10,
"learning_rate": 0.001,
"optimizer_state": torch.randn(3, 3, dtype=torch.float32),
},
# List with tensors
"tensor_list": [
torch.randn(2, 2, dtype=torch.float32),
torch.tensor(123, dtype=torch.int32),
],
}

# Generate tensor blob
modified_state_dict, blob = generate_tensor_blob(original_state_dict)

# Verify blob properties
assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}"
assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D"

# Calculate expected blob size
expected_size = 0

def calculate_expected_size(obj):
nonlocal expected_size
if isinstance(obj, torch.Tensor):
expected_size += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
calculate_expected_size(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
calculate_expected_size(item)

calculate_expected_size(original_state_dict)
assert (
len(blob) == expected_size
), f"Expected blob size {expected_size}, got {len(blob)}"

# Verify that tensors are replaced with TensorReference objects
def verify_tensor_references(obj, path=""):
if isinstance(obj, TensorReference):
assert obj.shape is not None, f"TensorReference at {path} missing shape"
assert obj.dtype is not None, f"TensorReference at {path} missing dtype"
assert (
obj.offset >= 0
), f"TensorReference at {path} has invalid offset {obj.offset}"
assert (
obj.size > 0
), f"TensorReference at {path} has invalid size {obj.size}"
elif isinstance(obj, dict):
for k, v in obj.items():
verify_tensor_references(v, f"{path}.{k}" if path else k)
elif isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
verify_tensor_references(item, f"{path}[{i}]")
elif isinstance(obj, torch.Tensor):
raise AssertionError(f"Found unreplaced tensor at {path}")

verify_tensor_references(modified_state_dict)

# Verify that non-tensor data is preserved
assert modified_state_dict["metadata"]["epoch"] == 10
assert modified_state_dict["metadata"]["learning_rate"] == 0.001

# Reconstruct the state dict
reconstructed_state_dict = reconstruct_state_dict_from_tensor_blob(
modified_state_dict, blob
)

# Verify reconstruction matches original
def compare_state_dicts(original, reconstructed, path=""):
if isinstance(original, torch.Tensor):
assert isinstance(reconstructed, torch.Tensor), f"Expected tensor at {path}"
assert (
original.shape == reconstructed.shape
), f"Shape mismatch at {path}: {original.shape} vs {reconstructed.shape}"
assert (
original.dtype == reconstructed.dtype
), f"Dtype mismatch at {path}: {original.dtype} vs {reconstructed.dtype}"
assert torch.equal(original, reconstructed), f"Values mismatch at {path}"
elif isinstance(original, dict):
assert isinstance(reconstructed, dict), f"Expected dict at {path}"
assert set(original.keys()) == set(
reconstructed.keys()
), f"Key mismatch at {path}"
for k in original.keys():
compare_state_dicts(
original[k], reconstructed[k], f"{path}.{k}" if path else k
)
elif isinstance(original, (list, tuple)):
assert type(original) == type(reconstructed), f"Type mismatch at {path}"
assert len(original) == len(reconstructed), f"Length mismatch at {path}"
for i, (orig_item, recon_item) in enumerate(zip(original, reconstructed)):
compare_state_dicts(orig_item, recon_item, f"{path}[{i}]")
else:
assert (
original == reconstructed
), f"Value mismatch at {path}: {original} vs {reconstructed}"

compare_state_dicts(original_state_dict, reconstructed_state_dict)

print("✅ test_generate_tensor_blob passed!")
print(
f" Processed {len([x for x in str(modified_state_dict) if 'TensorReference' in str(x)])} tensors"
)
print(f" Blob size: {len(blob)} bytes ({len(blob) / 1024:.1f} KB)")


def test_generate_tensor_blob_edge_cases():
"""Test edge cases for generate_tensor_blob."""

# Test empty state dict
empty_dict = {}
modified, blob = generate_tensor_blob(empty_dict)
assert modified == {}
assert len(blob) == 0
reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob)
assert reconstructed == {}

# Test state dict with no tensors
no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}}
modified, blob = generate_tensor_blob(no_tensors)
assert modified == no_tensors
assert len(blob) == 0
reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob)
assert reconstructed == no_tensors

# Test scalar tensor edge case
scalar_dict = {"scalar": torch.tensor(3.14159)}
modified, blob = generate_tensor_blob(scalar_dict)
assert isinstance(modified["scalar"], TensorReference)
assert modified["scalar"].shape == () # Empty tuple for scalar
reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob)
assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"])

# Test different dtypes
dtype_dict = {
"bool": torch.tensor([True, False, True]),
"uint8": torch.randint(0, 255, (5,), dtype=torch.uint8),
"int8": torch.randint(-128, 127, (3,), dtype=torch.int8),
"int16": torch.randint(-1000, 1000, (4,), dtype=torch.int16),
"bfloat16": torch.randn(3, dtype=torch.bfloat16),
}

modified, blob = generate_tensor_blob(dtype_dict)
reconstructed = reconstruct_state_dict_from_tensor_blob(modified, blob)

for key in dtype_dict:
assert torch.equal(
dtype_dict[key], reconstructed[key]
), f"Mismatch for dtype {key}"

print("✅ test_generate_tensor_blob_edge_cases passed!")


if __name__ == "__main__":
# Run existing tests
main(__file__)
122 changes: 121 additions & 1 deletion torchstore/state_dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from logging import getLogger
from typing import Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.distributed.checkpoint._nested_dict import (
Expand Down Expand Up @@ -95,3 +96,122 @@ def _state_dict_size(state_dict):

size += tensor.numel() * tensor.element_size()
return size // (1024 * 1024)


@dataclass
class TensorReference:
"""Metadata for a tensor in a tensor blob"""

shape: Tuple[int, ...]
dtype: torch.dtype
offset: int # Byte offset in the blob
size: int # Size in bytes


def generate_tensor_blob(state_dict: Dict[str, Any]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we instead use flatten_state_dict instead of making this recursive?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about making this a class method of a "TorchStoreStateDict", or similar?

Then we can do things like:

torchstore_sd = TorchStoreStateDict.from_state_dict(original_state_dict)
torchstore_sd.to_state_dict()

and also store any necessary data as objects in the state dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"""
Extract all tensors from state_dict and create a blob. Replace the tensors
with corresponding references and returns a state_dict with only tensor references,
and the tensor blob.

Args:
state_dict: Dictionary that may contain tensors at any level

Returns:
- Modified dictionary with tensors replaced by TensorReference objects
- 1D uint8 tensor blob containing all serialized tensor data
"""

def _extract_recursive(
obj: Dict[str, Any],
tensor_list: List[Tuple[torch.Tensor, TensorReference]],
path: str = "",
):
"""Recursively extract tensors and replace with TensorReference objects"""
if isinstance(obj, torch.Tensor):
# Create placeholder reference (offset will be filled later)
ref = TensorReference(
shape=tuple(obj.shape),
dtype=obj.dtype,
offset=-1, # Will be updated when building blob
size=obj.numel() * obj.element_size(),
)
tensor_list.append((obj, ref))
return ref # Replace tensor with TensorReference
elif isinstance(obj, dict):
return {
k: _extract_recursive(v, tensor_list, f"{path}.{k}")
for k, v in obj.items()
}
elif isinstance(obj, (list, tuple)):
return type(obj)(
_extract_recursive(item, tensor_list, f"{path}[{i}]")
for i, item in enumerate(obj)
)
else:
return obj # Non-tensor data stays as-is

tensor_list: List[Tuple[torch.Tensor, TensorReference]] = []

modified_state_dict = _extract_recursive(state_dict, tensor_list)

if not tensor_list:
return modified_state_dict, torch.empty(0, dtype=torch.uint8)

# Calculate total size and update offsets
current_offset = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have a _state_dict_size function in state dict utils

Copy link
Contributor Author

@kaiyuan-li kaiyuan-li Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_state_dict_size calculates approximate size return size << 20.

for tensor, ref in tensor_list:
ref.offset = current_offset
current_offset += ref.size

blob = torch.empty(current_offset, dtype=torch.uint8)

# Copy tensor data using your efficient approach
for tensor, ref in tensor_list:
# Handle scalar tensors
tensor_cpu = tensor.detach().cpu()
if tensor_cpu.dim() == 0:
tensor_cpu = tensor_cpu.unsqueeze(0)

byte_view = tensor_cpu.view(torch.uint8).flatten()

# Copy to blob
blob[ref.offset : ref.offset + ref.size] = byte_view

return modified_state_dict, blob


def reconstruct_state_dict_from_tensor_blob(
state_dict_with_tensor_refs: Dict[str, Any], blob: torch.Tensor
) -> Dict[str, Any]:
"""
Reconstruct a state_dict which only contains tensor references by
reconstructing the tensors using the tensor blob and the tensor references.
Returns the reconstructed state dict.
"""

def _reconstruct_recursive(obj):
if isinstance(obj, TensorReference):
# Pre-allocate tensor with correct shape and dtype (TorchStore approach)
tensor = torch.empty(obj.shape, dtype=obj.dtype)

# Get byte view of the allocated tensor
if tensor.dim() == 0:
tensor_unsqueezed = tensor.unsqueeze(0)
byte_view = tensor_unsqueezed.view(torch.uint8).flatten()
else:
byte_view = tensor.view(torch.uint8).flatten()

# Copy bytes from blob into tensor's byte view
tensor_bytes = blob[obj.offset : obj.offset + obj.size]
byte_view.copy_(tensor_bytes)

return tensor
elif isinstance(obj, dict):
return {k: _reconstruct_recursive(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_reconstruct_recursive(item) for item in obj)
else:
return obj

return _reconstruct_recursive(state_dict_with_tensor_refs)