287 changes: 283 additions & 4 deletions tests/test_state_dict.py
@@ -14,20 +14,27 @@
import pytest

import torch

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn

import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint

from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
)
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
-from torch.distributed.tensor import DTensor
+from torch.distributed.tensor import DTensor, Replicate
from torchstore.state_dict_utils import (
TensorReference,
TORCHSTORE_TSSD_ENABLED_FLAG,
TorchStoreStateDict,
)
from torchstore.utils import spawn_actors

from .utils import main, transport_plus_strategy_params
@@ -38,6 +45,100 @@
MODEL_LINER_LENGTH = 10


def _setup_process_group():
"""Set up minimal distributed environment for DTensor testing."""

if not dist.is_initialized():
        # Set minimal environment variables for a single-process group
        # ("os" is already imported at module level)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault(
"MASTER_PORT", "29501"
) # Different port to avoid conflicts
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Initialize single-process group
dist.init_process_group(
backend="gloo", # CPU backend
rank=0,
world_size=1,
)
return True


def _verify_tensor_references(torchstore_state_dict, flattened_original):
"""Utility function to verify TensorReference objects in flattened state dict."""
for key, original_value in flattened_original.items():
torchstore_value = torchstore_state_dict.flattened_state_dict[key]

if isinstance(original_value, torch.Tensor):
if hasattr(original_value, "_local_tensor"): # DTensor check
# DTensor should be converted to TensorReference with tensor_slice
assert isinstance(torchstore_value, TensorReference)
assert (
torchstore_value.tensor_slice is not None
), f"DTensor at {key} should have tensor_slice"
assert (
torchstore_value.device_mesh is not None
), f"DTensor at {key} should have device_mesh"
assert (
torchstore_value.placements is not None
), f"DTensor at {key} should have placements"

# Verify local tensor metadata
local_tensor = original_value._local_tensor
assert torchstore_value.shape == tuple(local_tensor.shape)
assert torchstore_value.dtype == local_tensor.dtype
else:
# Regular tensor should not have tensor_slice
assert isinstance(torchstore_value, TensorReference)
assert (
torchstore_value.tensor_slice is None
), f"Regular tensor at {key} should not have tensor_slice"
assert torchstore_value.shape == tuple(original_value.shape)
assert torchstore_value.dtype == original_value.dtype


def _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed):
"""Utility function to verify reconstructed state dict matches original."""
for key, original_value in flattened_original.items():
reconstructed_value = flattened_reconstructed[key]

if hasattr(original_value, "_local_tensor"): # DTensor check
# Should be reconstructed as DTensor
assert hasattr(
reconstructed_value, "_local_tensor"
), f"Expected DTensor for {key}"

# Verify local tensor data matches
assert torch.equal(
original_value._local_tensor, reconstructed_value._local_tensor
), f"Local tensor data mismatch for {key}"

# Verify global shape matches
assert (
original_value.shape == reconstructed_value.shape
), f"Global shape mismatch for {key}"

# Verify placements match
assert (
original_value.placements == reconstructed_value.placements
), f"Placements mismatch for {key}"

elif isinstance(original_value, torch.Tensor):
# Regular tensors should remain the same
assert torch.equal(
original_value, reconstructed_value
), f"Regular tensor mismatch for {key}"
else:
# Non-tensor values should be preserved
assert (
original_value == reconstructed_value
), f"Non-tensor value mismatch for {key}"


class UnitModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@@ -167,8 +268,9 @@ async def do_get(self):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
-async def test_state_dict(strategy_params, use_rdma):
+async def test_state_dict_lky(strategy_params, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
os.environ[TORCHSTORE_TSSD_ENABLED_FLAG] = "1"

class Trainer(Actor):
# Monarch RDMA does not work outside of an actor, so we need
@@ -296,5 +398,182 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}"


def test_torchstore_state_dict():
"""Test TorchStoreStateDict class with various tensor types and reconstruction."""

# Create a state dict with various tensor types and shapes
original_state_dict = {
# Scalar tensor (0D)
"scalar": torch.tensor(42.5, dtype=torch.float32),
# 1D tensors with different dtypes
"vector_float": torch.randn(10, dtype=torch.float32),
"vector_int": torch.randint(0, 100, (5,), dtype=torch.int64),
"vector_half": torch.randn(8, dtype=torch.float16),
# 2D tensors with different dtypes
"matrix_float": torch.randn(3, 4, dtype=torch.float32),
"matrix_double": torch.randn(2, 3, dtype=torch.float64),
"matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32),
# Nested structure
"model": {
"layer1": {
"weight": torch.randn(5, 3, dtype=torch.float32),
"bias": torch.randn(5, dtype=torch.float32),
},
"layer2": {
"weight": torch.randn(2, 5, dtype=torch.float32),
"bias": torch.randn(2, dtype=torch.float32),
},
},
# Mixed with non-tensor data
"metadata": {
"epoch": 10,
"learning_rate": 0.001,
"optimizer_state": torch.randn(3, 3, dtype=torch.float32),
},
# List with tensors (note: flattened state dict doesn't preserve list structure)
"layer_weights": [
torch.randn(2, 2, dtype=torch.float32),
torch.tensor(123, dtype=torch.int32),
],
}

# Create TorchStoreStateDict
torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

# Verify blob properties
blob = torchstore_state_dict.tensor_blob
assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}"
assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D"

# 1. Flatten original state dict
original_flattened, _ = flatten_state_dict(original_state_dict)

# 2. Verify keys match between original flattened and torchstore flattened state dict
assert set(original_flattened.keys()) == set(
torchstore_state_dict.flattened_state_dict.keys()
), "Keys don't match between original and torchstore flattened state dicts"

    # 3. Verify tensor references
_verify_tensor_references(torchstore_state_dict, original_flattened)

# Calculate total size for blob verification
total_size = 0
    for original_value in original_flattened.values():
if isinstance(original_value, torch.Tensor):
tensor_to_size = (
original_value._local_tensor
if hasattr(original_value, "_local_tensor")
else original_value
)
total_size += tensor_to_size.numel() * tensor_to_size.element_size()

# Verify tensor blob size matches total size
assert (
len(blob) == total_size
), f"Tensor blob size {len(blob)} doesn't match expected total size {total_size}"

# Reconstruct the state dict
reconstructed_state_dict = torchstore_state_dict.to_state_dict()

# Compare flattened versions - simpler than recursive comparison
original_flattened, original_mapping = flatten_state_dict(original_state_dict)
reconstructed_flattened, reconstructed_mapping = flatten_state_dict(
reconstructed_state_dict
)

# Verify mappings are identical (structure preserved)
assert (
original_mapping == reconstructed_mapping
), "State dict structure mappings don't match"

# Verify keys match
assert set(original_flattened.keys()) == set(
reconstructed_flattened.keys()
), "Flattened keys don't match"

# Verify reconstruction using utility function
_verify_reconstructed_state_dict(original_flattened, reconstructed_flattened)
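
A brief sketch of the byte-packing contract the blob assertions above rely on:
every tensor's bytes are laid out in one 1-D uint8 buffer, so the blob length
equals the sum of numel() * element_size() over all tensors. The function below
illustrates that invariant under a simple concatenation layout; it is not
TorchStoreStateDict's actual implementation.

# Hypothetical packing that satisfies the blob-size invariant tested above.
def pack_tensors(tensors):
    chunks = [t.contiguous().view(torch.uint8).flatten() for t in tensors]
    return torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.uint8)

# pack_tensors([torch.randn(3, 4), torch.randint(0, 10, (5,))]) has length
# 3 * 4 * 4 + 5 * 8 (float32 is 4 bytes per element, int64 is 8).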


def test_torchstore_state_dict_edge_cases():
"""Test edge cases for TorchStoreStateDict."""

# Test empty state dict
empty_dict = {}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(empty_dict)
reconstructed = torchstore_state_dict.to_state_dict()
assert reconstructed == {}

# Test state dict with no tensors
no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(no_tensors)
reconstructed = torchstore_state_dict.to_state_dict()
assert reconstructed == no_tensors

# Test scalar tensor edge case
scalar_dict = {"scalar": torch.tensor(3.14159)}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict)
    # Round-trip the 0-D tensor and compare values
reconstructed = torchstore_state_dict.to_state_dict()
assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"])

# Test different dtypes
dtype_dict = {
"bool": torch.tensor([True, False, True]),
"uint8": torch.randint(0, 255, (5,), dtype=torch.uint8),
"int8": torch.randint(-128, 127, (3,), dtype=torch.int8),
"int16": torch.randint(-1000, 1000, (4,), dtype=torch.int16),
"bfloat16": torch.randn(3, dtype=torch.bfloat16),
}

torchstore_state_dict = TorchStoreStateDict.from_state_dict(dtype_dict)
reconstructed = torchstore_state_dict.to_state_dict()

for key in dtype_dict:
assert torch.equal(
dtype_dict[key], reconstructed[key]
), f"Mismatch for dtype {key}"


def test_torchstore_state_dict_with_dtensor():
"""Test TorchStoreStateDict with DTensor support."""
_setup_process_group()

# Create single-device mesh (CPU only)
device_mesh = DeviceMesh("cpu", [0])

# Create DTensor from local tensor
local_tensor = torch.randn(4, 6, dtype=torch.float32)
dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])

# Create state dict with DTensor and regular tensor
original_state_dict = {
"regular_tensor": torch.randn(3, 3),
"dtensor": dtensor,
"nested": {
"another_dtensor": DTensor.from_local(
torch.ones(2, 3), device_mesh, [Replicate()]
),
"metadata": {"test": "value"},
},
}

# Test serialization
torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

# Verify DTensor metadata is preserved using utility function
flattened_original, _ = flatten_state_dict(original_state_dict)
_verify_tensor_references(torchstore_state_dict, flattened_original)

# Test deserialization
reconstructed_state_dict = torchstore_state_dict.to_state_dict()

# Verify reconstruction using utility function
flattened_reconstructed, _ = flatten_state_dict(reconstructed_state_dict)
_verify_reconstructed_state_dict(flattened_original, flattened_reconstructed)

dist.destroy_process_group()


if __name__ == "__main__":
main(__file__)
14 changes: 13 additions & 1 deletion torchstore/client.py
@@ -13,6 +13,7 @@

from torchstore.controller import ObjectType
from torchstore.logging import LatencyTracker
from torchstore.state_dict_utils import DELIM, FLATTENED_STATE_DICT, get_state_dict_key
from torchstore.transport import Pipe, Request, TensorSlice
from torchstore.utils import assemble_global_tensor, get_local_tensor

@@ -54,7 +55,18 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):
await pipe.put_to_storage_volume(key, request)
latency_tracker.track_step("put_to_storage_volume")

-        await self._controller.notify_put.call(key, request.meta_only(), volume_id)
+        if key.endswith(FLATTENED_STATE_DICT):
+            state_dict_key = get_state_dict_key(key)
+            for flattened_key in value.keys():
+                full_key = f"{state_dict_key}{DELIM}{flattened_key}"
+                await self._controller.notify_put.call(
+                    full_key,
+                    request.meta_only(),
+                    volume_id,
+                )
+        else:
+            await self._controller.notify_put.call(key, request.meta_only(), volume_id)

latency_tracker.track_step("notify_put")
latency_tracker.track_e2e()
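
To make the branch above concrete, a hypothetical sketch of the key convention
it implies: get_state_dict_key recovers the state-dict prefix from a key ending
in FLATTENED_STATE_DICT, and each tensor is then announced under the prefix
plus DELIM plus its flattened name. The DELIM value and get_state_dict_key body
below are illustrative guesses, not the actual definitions in
torchstore.state_dict_utils.

# Hypothetical key layout assumed by the per-tensor notify_put loop above.
DELIM = "/"
FLATTENED_STATE_DICT = "__flattened_state_dict__"

def get_state_dict_key(key: str) -> str:
    # Strip the trailing delimiter and suffix to recover the user-facing key.
    return key[: -(len(DELIM) + len(FLATTENED_STATE_DICT))]

# put("trainer_sd" + DELIM + FLATTENED_STATE_DICT, flat_sd) would then notify
# "trainer_sd" + DELIM + name for every name in flat_sd.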
