280 changes: 277 additions & 3 deletions tests/test_state_dict.py
@@ -14,20 +14,23 @@
import pytest

import torch

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn

import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint

from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
)
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
-from torch.distributed.tensor import DTensor
+from torch.distributed.tensor import DTensor, Replicate
from torchstore.state_dict_utils import TensorReference, TorchStoreStateDict
from torchstore.utils import spawn_actors

from .utils import main, transport_plus_strategy_params
@@ -38,6 +41,100 @@
MODEL_LINER_LENGTH = 10


def _setup_process_group():
"""Set up minimal distributed environment for DTensor testing."""

if not dist.is_initialized():
# Set minimal environment variables for single process
import os

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault(
"MASTER_PORT", "29501"
) # Different port to avoid conflicts
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Initialize single-process group
dist.init_process_group(
backend="gloo", # CPU backend
rank=0,
world_size=1,
)
return True


def _verify_tensor_references(torchstore_state_dict, flattened_original):
"""Utility function to verify TensorReference objects in flattened state dict."""
for key, original_value in flattened_original.items():
torchstore_value = torchstore_state_dict.flattened_state_dict[key]

if isinstance(original_value, torch.Tensor):
if hasattr(original_value, "_local_tensor"): # DTensor check
# DTensor should be converted to TensorReference with tensor_slice
assert isinstance(torchstore_value, TensorReference)
assert (
torchstore_value.tensor_slice is not None
), f"DTensor at {key} should have tensor_slice"
assert (
torchstore_value.device_mesh is not None
), f"DTensor at {key} should have device_mesh"
assert (
torchstore_value.placements is not None
), f"DTensor at {key} should have placements"

# Verify local tensor metadata
local_tensor = original_value._local_tensor
assert torchstore_value.shape == tuple(local_tensor.shape)
assert torchstore_value.dtype == local_tensor.dtype
else:
# Regular tensor should not have tensor_slice
assert isinstance(torchstore_value, TensorReference)
assert (
torchstore_value.tensor_slice is None
), f"Regular tensor at {key} should not have tensor_slice"
assert torchstore_value.shape == tuple(original_value.shape)
assert torchstore_value.dtype == original_value.dtype


def _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed):
"""Utility function to verify reconstructed state dict matches original."""
for key, original_value in flattened_original.items():
reconstructed_value = flattened_reconstructed[key]

if hasattr(original_value, "_local_tensor"): # DTensor check
# Should be reconstructed as DTensor
assert hasattr(
reconstructed_value, "_local_tensor"
), f"Expected DTensor for {key}"

# Verify local tensor data matches
assert torch.equal(
original_value._local_tensor, reconstructed_value._local_tensor
), f"Local tensor data mismatch for {key}"

# Verify global shape matches
assert (
original_value.shape == reconstructed_value.shape
), f"Global shape mismatch for {key}"

# Verify placements match
assert (
original_value.placements == reconstructed_value.placements
), f"Placements mismatch for {key}"

elif isinstance(original_value, torch.Tensor):
# Regular tensors should remain the same
assert torch.equal(
original_value, reconstructed_value
), f"Regular tensor mismatch for {key}"
else:
# Non-tensor values should be preserved
assert (
original_value == reconstructed_value
), f"Non-tensor value mismatch for {key}"


class UnitModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@@ -296,5 +393,182 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}"


def test_torchstore_state_dict():
"""Test TorchStoreStateDict class with various tensor types and reconstruction."""

# Create a state dict with various tensor types and shapes
original_state_dict = {
# Scalar tensor (0D)
"scalar": torch.tensor(42.5, dtype=torch.float32),
# 1D tensors with different dtypes
"vector_float": torch.randn(10, dtype=torch.float32),
"vector_int": torch.randint(0, 100, (5,), dtype=torch.int64),
"vector_half": torch.randn(8, dtype=torch.float16),
# 2D tensors with different dtypes
"matrix_float": torch.randn(3, 4, dtype=torch.float32),
"matrix_double": torch.randn(2, 3, dtype=torch.float64),
"matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32),
# Nested structure
"model": {
"layer1": {
"weight": torch.randn(5, 3, dtype=torch.float32),
"bias": torch.randn(5, dtype=torch.float32),
},
"layer2": {
"weight": torch.randn(2, 5, dtype=torch.float32),
"bias": torch.randn(2, dtype=torch.float32),
},
},
# Mixed with non-tensor data
"metadata": {
"epoch": 10,
"learning_rate": 0.001,
"optimizer_state": torch.randn(3, 3, dtype=torch.float32),
},
# List with tensors (note: flattened state dict doesn't preserve list structure)
"layer_weights": [
torch.randn(2, 2, dtype=torch.float32),
torch.tensor(123, dtype=torch.int32),
],
}

# Create TorchStoreStateDict
torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

# Verify blob properties
blob = torchstore_state_dict.tensor_blob
assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}"
assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D"

# 1. Flatten original state dict
original_flattened, _ = flatten_state_dict(original_state_dict)

# 2. Verify keys match between original flattened and torchstore flattened state dict
assert set(original_flattened.keys()) == set(
torchstore_state_dict.flattened_state_dict.keys()
), "Keys don't match between original and torchstore flattened state dicts"

    # 3. Verify tensor references
_verify_tensor_references(torchstore_state_dict, original_flattened)

# Calculate total size for blob verification
total_size = 0
for key, original_value in original_flattened.items():
if isinstance(original_value, torch.Tensor):
tensor_to_size = (
original_value._local_tensor
if hasattr(original_value, "_local_tensor")
else original_value
)
total_size += tensor_to_size.numel() * tensor_to_size.element_size()

# Verify tensor blob size matches total size
assert (
len(blob) == total_size
), f"Tensor blob size {len(blob)} doesn't match expected total size {total_size}"

# Reconstruct the state dict
reconstructed_state_dict = torchstore_state_dict.to_state_dict()

# Compare flattened versions - simpler than recursive comparison
original_flattened, original_mapping = flatten_state_dict(original_state_dict)
reconstructed_flattened, reconstructed_mapping = flatten_state_dict(
reconstructed_state_dict
)

# Verify mappings are identical (structure preserved)
assert (
original_mapping == reconstructed_mapping
), "State dict structure mappings don't match"

# Verify keys match
assert set(original_flattened.keys()) == set(
reconstructed_flattened.keys()
), "Flattened keys don't match"

# Verify reconstruction using utility function
_verify_reconstructed_state_dict(original_flattened, reconstructed_flattened)


def test_torchstore_state_dict_edge_cases():
"""Test edge cases for TorchStoreStateDict."""

# Test empty state dict
empty_dict = {}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(empty_dict)
reconstructed = torchstore_state_dict.to_state_dict()
assert reconstructed == {}

# Test state dict with no tensors
no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(no_tensors)
reconstructed = torchstore_state_dict.to_state_dict()
assert reconstructed == no_tensors

# Test scalar tensor edge case
scalar_dict = {"scalar": torch.tensor(3.14159)}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict)
    # Round-trip the scalar tensor and verify it survives reconstruction
reconstructed = torchstore_state_dict.to_state_dict()
assert torch.equal(scalar_dict["scalar"], reconstructed["scalar"])

# Test different dtypes
dtype_dict = {
"bool": torch.tensor([True, False, True]),
"uint8": torch.randint(0, 255, (5,), dtype=torch.uint8),
"int8": torch.randint(-128, 127, (3,), dtype=torch.int8),
"int16": torch.randint(-1000, 1000, (4,), dtype=torch.int16),
"bfloat16": torch.randn(3, dtype=torch.bfloat16),
}

torchstore_state_dict = TorchStoreStateDict.from_state_dict(dtype_dict)
reconstructed = torchstore_state_dict.to_state_dict()

for key in dtype_dict:
assert torch.equal(
dtype_dict[key], reconstructed[key]
), f"Mismatch for dtype {key}"


def test_torchstore_state_dict_with_dtensor():
"""Test TorchStoreStateDict with DTensor support."""
_setup_process_group()

# Create single-device mesh (CPU only)
device_mesh = DeviceMesh("cpu", [0])

# Create DTensor from local tensor
local_tensor = torch.randn(4, 6, dtype=torch.float32)
dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])
Contributor: Can you add a test for a sharded DTensor (with world size > 1)? I am actually also confused about the expected behavior in this case.

Contributor (Author): That DTensor put-then-get functionality will be added in the next PR, where we integrate the state_dict functionality into torchstore; this PR only covers the serialization and deserialization part. (See the sketch after this file's diff.)


# Create state dict with DTensor and regular tensor
original_state_dict = {
"regular_tensor": torch.randn(3, 3),
"dtensor": dtensor,
"nested": {
"another_dtensor": DTensor.from_local(
torch.ones(2, 3), device_mesh, [Replicate()]
),
"metadata": {"test": "value"},
},
}

# Test serialization
torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

# Verify DTensor metadata is preserved using utility function
flattened_original, _ = flatten_state_dict(original_state_dict)
_verify_tensor_references(torchstore_state_dict, flattened_original)

# Test deserialization
reconstructed_state_dict = torchstore_state_dict.to_state_dict()

# Verify reconstruction using utility function
flattened_reconstructed, _ = flatten_state_dict(reconstructed_state_dict)
_verify_reconstructed_state_dict(flattened_original, flattened_reconstructed)

dist.destroy_process_group()


if __name__ == "__main__":
main(__file__)
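
A minimal sketch of the sharded-DTensor test the reviewer asked for, assuming a 2-rank gloo process group launched via torchrun; the Shard(0) placement, the tensor sizes, and the expectation that from_state_dict/to_state_dict round-trips each rank's local shard are assumptions here, since the thread above leaves the expected behavior open:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

from torchstore.state_dict_utils import TorchStoreStateDict


def _sharded_dtensor_roundtrip():
    # Assumes launch via `torchrun --nproc-per-node=2`; gloo keeps it CPU-only.
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")

    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # Shard a 4x6 tensor along dim 0: with 2 ranks, each holds a 2x6 shard.
    global_tensor = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])

    ts_sd = TorchStoreStateDict.from_state_dict({"sharded": dtensor})
    reconstructed = ts_sd.to_state_dict()

    # Assumed expectation: each rank gets back a DTensor whose local shard
    # matches the one it started with.
    assert torch.equal(
        dtensor._local_tensor, reconstructed["sharded"]._local_tensor
    )

    dist.destroy_process_group()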
65 changes: 65 additions & 0 deletions torchstore/dtensor_utils.py
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torch.distributed.tensor import DTensor, Placement
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset


def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
Contributor (suggested change):
-def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
+from torchstore.transport.pipe import TensorSlice
+def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:
Contributor (Author): See the response to the other comment about import ordering.

"""
Create a TensorSlice from a DTensor.
Args:
dtensor: The DTensor to extract metadata from
Returns:
TensorSlice containing the distributed tensor metadata
"""
from torchstore.transport.pipe import TensorSlice
Contributor: Is there a particular reason to avoid importing this at the file level? If not, move the import to the top of the file.

Contributor (Author): So there's a circular dependency:

pipe.TensorSlice
      ^
      |
dtensor_utils.create_tensor_slice_from_dtensor
      ^
      |
pipe.Request.from_dtensor

Maybe we should move the TensorSlice definition into the dtensor_utils.py module?

Contributor: Would `from __future__ import annotations` fix this? If not, then just leave it as is.
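
A minimal sketch of why PEP 563 alone likely does not break the cycle (a reading not confirmed in the thread): `from __future__ import annotations` turns the return annotation into a lazily evaluated string, so a module-level import would only be needed for type checkers, but the function body still constructs a TensorSlice at runtime:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so no import cycle.
    from torchstore.transport.pipe import TensorSlice


def create_tensor_slice_from_dtensor(dtensor) -> TensorSlice:
    # The lazy (string) annotation above no longer needs a module-level
    # import, but TensorSlice(...) is still called at runtime below, so this
    # deferred import cannot be removed by PEP 563 alone.
    from torchstore.transport.pipe import TensorSlice

    ...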


coordinates = dtensor.device_mesh.get_coordinate()
_, offsets = _compute_local_shape_and_global_offset(
dtensor.shape,
mesh_shape=dtensor.device_mesh.shape,
my_coordinate=coordinates,
placements=dtensor.placements,
)

return TensorSlice(
offsets=offsets,
coordinates=coordinates,
global_shape=dtensor.shape,
local_shape=dtensor._local_tensor.shape,
mesh_shape=dtensor.device_mesh.shape,
)


def reconstruct_dtensor_from_local_tensor(
local_tensor: torch.Tensor,
tensor_slice: "TensorSlice",
device_mesh: torch.distributed.DeviceMesh,
placements: Tuple[Placement, ...],
) -> DTensor:
"""
Reconstruct a DTensor from local tensor data and TensorSlice metadata.
Args:
local_tensor: The local tensor shard
tensor_slice: TensorSlice containing distributed metadata
device_mesh: The device mesh for the DTensor
placements: The placements for the DTensor
Returns:
Reconstructed DTensor
"""
return DTensor.from_local(
local_tensor=local_tensor,
device_mesh=device_mesh,
placements=placements,
)
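
A minimal usage sketch of the two helpers above, run as a single gloo rank; the Replicate placement mirrors the DTensor test earlier in this PR, and the tensor_slice.global_shape field access follows the constructor call in create_tensor_slice_from_dtensor:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate

from torchstore.dtensor_utils import (
    create_tensor_slice_from_dtensor,
    reconstruct_dtensor_from_local_tensor,
)

# Single-process group, mirroring _setup_process_group() in the tests above.
dist.init_process_group(
    backend="gloo", rank=0, world_size=1, init_method="tcp://127.0.0.1:29502"
)

device_mesh = DeviceMesh("cpu", [0])
dtensor = DTensor.from_local(torch.randn(4, 6), device_mesh, [Replicate()])

# Extract the slice metadata, then rebuild the DTensor from its local shard.
tensor_slice = create_tensor_slice_from_dtensor(dtensor)
assert tuple(tensor_slice.global_shape) == (4, 6)

rebuilt = reconstruct_dtensor_from_local_tensor(
    local_tensor=dtensor._local_tensor,
    tensor_slice=tensor_slice,
    device_mesh=device_mesh,
    placements=dtensor.placements,
)
assert torch.equal(dtensor._local_tensor, rebuilt._local_tensor)

dist.destroy_process_group()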