-
Notifications
You must be signed in to change notification settings - Fork 5
State dict serialization #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
9567b1f
43427f6
c1da899
dce528a
1b35e01
f68ce3f
25ca59b
24d46ac
298eb5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,20 +14,23 @@ | |
import pytest | ||
|
||
import torch | ||
|
||
import torch.distributed as dist | ||
import torch.distributed.checkpoint as dcp | ||
import torch.nn as nn | ||
|
||
import torchstore as ts | ||
|
||
from monarch.actor import Actor, current_rank, endpoint | ||
|
||
from torch.distributed.checkpoint._nested_dict import flatten_state_dict | ||
from torch.distributed.checkpoint.state_dict import ( | ||
get_model_state_dict, | ||
get_optimizer_state_dict, | ||
) | ||
from torch.distributed.device_mesh import init_device_mesh | ||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh | ||
from torch.distributed.fsdp import fully_shard | ||
from torch.distributed.tensor import DTensor | ||
from torch.distributed.tensor import DTensor, Replicate | ||
from torchstore.state_dict_utils import TensorReference, TorchStoreStateDict | ||
from torchstore.utils import spawn_actors | ||
|
||
from .utils import main, transport_plus_strategy_params | ||
|
@@ -38,6 +41,100 @@ | |
MODEL_LINER_LENGTH = 10 | ||
|
||
|
||
def _setup_process_group():
    """Ensure a single-process ``gloo`` process group exists for DTensor tests.

    Does nothing when ``torch.distributed`` is already initialized. Otherwise
    it fills in the rendezvous environment variables (only the ones that are
    not already set) and initializes a group with ``rank=0`` and
    ``world_size=1``.

    Returns:
        ``True``.
    """
    if dist.is_initialized():
        return True

    import os

    # setdefault leaves any value already present in the environment untouched.
    rendezvous_defaults = {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29501",  # Different port to avoid conflicts
        "RANK": "0",
        "WORLD_SIZE": "1",
    }
    for variable, fallback in rendezvous_defaults.items():
        os.environ.setdefault(variable, fallback)

    # Single-process group on the CPU (gloo) backend.
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    return True
|
||
|
||
def _verify_tensor_references(torchstore_state_dict, flattened_original): | ||
"""Utility function to verify TensorReference objects in flattened state dict.""" | ||
for key, original_value in flattened_original.items(): | ||
torchstore_value = torchstore_state_dict.flattened_state_dict[key] | ||
|
||
if isinstance(original_value, torch.Tensor): | ||
if hasattr(original_value, "_local_tensor"): # DTensor check | ||
# DTensor should be converted to TensorReference with tensor_slice | ||
assert isinstance(torchstore_value, TensorReference) | ||
assert ( | ||
torchstore_value.tensor_slice is not None | ||
), f"DTensor at {key} should have tensor_slice" | ||
assert ( | ||
torchstore_value.device_mesh is not None | ||
), f"DTensor at {key} should have device_mesh" | ||
assert ( | ||
torchstore_value.placements is not None | ||
), f"DTensor at {key} should have placements" | ||
|
||
# Verify local tensor metadata | ||
local_tensor = original_value._local_tensor | ||
assert torchstore_value.shape == tuple(local_tensor.shape) | ||
assert torchstore_value.dtype == local_tensor.dtype | ||
else: | ||
# Regular tensor should not have tensor_slice | ||
assert isinstance(torchstore_value, TensorReference) | ||
assert ( | ||
torchstore_value.tensor_slice is None | ||
), f"Regular tensor at {key} should not have tensor_slice" | ||
assert torchstore_value.shape == tuple(original_value.shape) | ||
assert torchstore_value.dtype == original_value.dtype | ||
|
||
|
||
def _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed): | ||
"""Utility function to verify reconstructed state dict matches original.""" | ||
for key, original_value in flattened_original.items(): | ||
reconstructed_value = flattened_reconstructed[key] | ||
|
||
if hasattr(original_value, "_local_tensor"): # DTensor check | ||
# Should be reconstructed as DTensor | ||
assert hasattr( | ||
reconstructed_value, "_local_tensor" | ||
), f"Expected DTensor for {key}" | ||
|
||
# Verify local tensor data matches | ||
assert torch.equal( | ||
original_value._local_tensor, reconstructed_value._local_tensor | ||
), f"Local tensor data mismatch for {key}" | ||
|
||
# Verify global shape matches | ||
assert ( | ||
original_value.shape == reconstructed_value.shape | ||
), f"Global shape mismatch for {key}" | ||
|
||
# Verify placements match | ||
assert ( | ||
original_value.placements == reconstructed_value.placements | ||
), f"Placements mismatch for {key}" | ||
|
||
elif isinstance(original_value, torch.Tensor): | ||
# Regular tensors should remain the same | ||
assert torch.equal( | ||
original_value, reconstructed_value | ||
), f"Regular tensor mismatch for {key}" | ||
else: | ||
# Non-tensor values should be preserved | ||
assert ( | ||
original_value == reconstructed_value | ||
), f"Non-tensor value mismatch for {key}" | ||
|
||
|
||
class UnitModule(nn.Module): | ||
def __init__(self, device: torch.device): | ||
super().__init__() | ||
|
@@ -296,5 +393,189 @@ def _assert_equal_state_dict(state_dict1, state_dict2): | |
), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}" | ||
|
||
|
||
def test_torchstore_state_dict():
    """Round-trip a nested state dict of mixed tensors through TorchStoreStateDict.

    Checks, in order: the blob is a 1-D ``uint8`` tensor; the flattened keys
    match the original's; every tensor became a matching ``TensorReference``;
    the blob length equals the summed byte size of all tensors; and
    ``to_state_dict()`` rebuilds the same structure and values.
    """
    # Create a state dict with various tensor types and shapes
    original_state_dict = {
        # Scalar tensor (0D)
        "scalar": torch.tensor(42.5, dtype=torch.float32),
        # 1D tensors with different dtypes
        "vector_float": torch.randn(10, dtype=torch.float32),
        "vector_int": torch.randint(0, 100, (5,), dtype=torch.int64),
        "vector_half": torch.randn(8, dtype=torch.float16),
        # 2D tensors with different dtypes
        "matrix_float": torch.randn(3, 4, dtype=torch.float32),
        "matrix_double": torch.randn(2, 3, dtype=torch.float64),
        "matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32),
        # Nested structure
        "model": {
            "layer1": {
                "weight": torch.randn(5, 3, dtype=torch.float32),
                "bias": torch.randn(5, dtype=torch.float32),
            },
            "layer2": {
                "weight": torch.randn(2, 5, dtype=torch.float32),
                "bias": torch.randn(2, dtype=torch.float32),
            },
        },
        # Mixed with non-tensor data
        "metadata": {
            "epoch": 10,
            "learning_rate": 0.001,
            "optimizer_state": torch.randn(3, 3, dtype=torch.float32),
        },
        # List with tensors (note: flattened state dict doesn't preserve list structure)
        "layer_weights": [
            torch.randn(2, 2, dtype=torch.float32),
            torch.tensor(123, dtype=torch.int32),
        ],
    }

    # Create TorchStoreStateDict
    torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

    # Verify blob properties
    blob = torchstore_state_dict.tensor_blob
    assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}"
    assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D"

    # 1. Flatten the original once; both the flat values and the structure
    #    mapping are reused by every check below.
    original_flattened, original_mapping = flatten_state_dict(original_state_dict)

    # 2. Verify keys match between original flattened and torchstore flattened state dict
    assert set(original_flattened.keys()) == set(
        torchstore_state_dict.flattened_state_dict.keys()
    ), "Keys don't match between original and torchstore flattened state dicts"

    # 3. Verify tensor references
    _verify_tensor_references(torchstore_state_dict, original_flattened)

    # 4. The blob must hold exactly the bytes of every tensor (for DTensors,
    #    the bytes of the local tensor).
    total_size = 0
    for original_value in original_flattened.values():
        if isinstance(original_value, torch.Tensor):
            tensor_to_size = getattr(original_value, "_local_tensor", original_value)
            total_size += tensor_to_size.numel() * tensor_to_size.element_size()

    assert (
        len(blob) == total_size
    ), f"Tensor blob size {len(blob)} doesn't match expected total size {total_size}"

    # 5. Reconstruct the state dict and compare flattened versions - simpler
    #    than a recursive comparison.
    reconstructed_state_dict = torchstore_state_dict.to_state_dict()
    reconstructed_flattened, reconstructed_mapping = flatten_state_dict(
        reconstructed_state_dict
    )

    # Verify mappings are identical (structure preserved)
    assert (
        original_mapping == reconstructed_mapping
    ), "State dict structure mappings don't match"

    # Verify keys match
    assert set(original_flattened.keys()) == set(
        reconstructed_flattened.keys()
    ), "Flattened keys don't match"

    # Verify reconstruction using utility function
    _verify_reconstructed_state_dict(original_flattened, reconstructed_flattened)
|
||
|
||
def test_torchstore_state_dict_edge_cases():
    """Exercise TorchStoreStateDict on empty, tensor-free, scalar and odd-dtype inputs."""

    # --- Empty state dict: nothing to reference, nothing in the blob ---
    empty_store = TorchStoreStateDict.from_state_dict({})
    assert empty_store.flattened_state_dict == {}
    assert len(empty_store.tensor_blob) == 0
    assert empty_store.to_state_dict() == {}

    # --- State dict without any tensors: blob stays empty, values survive ---
    no_tensors = {"a": 1, "b": {"c": "hello", "d": [1, 2, 3]}}
    plain_store = TorchStoreStateDict.from_state_dict(no_tensors)
    assert len(plain_store.tensor_blob) == 0
    assert plain_store.to_state_dict() == no_tensors

    # --- Scalar (0-D) tensor: the reference records an empty shape tuple ---
    scalar_dict = {"scalar": torch.tensor(3.14159)}
    scalar_store = TorchStoreStateDict.from_state_dict(scalar_dict)
    scalar_ref = scalar_store.flattened_state_dict["scalar"]
    assert isinstance(scalar_ref, TensorReference)
    assert scalar_ref.shape == ()  # Empty tuple for scalar
    scalar_roundtrip = scalar_store.to_state_dict()
    assert torch.equal(scalar_dict["scalar"], scalar_roundtrip["scalar"])

    # --- Less common dtypes must round-trip exactly ---
    dtype_dict = {
        "bool": torch.tensor([True, False, True]),
        "uint8": torch.randint(0, 255, (5,), dtype=torch.uint8),
        "int8": torch.randint(-128, 127, (3,), dtype=torch.int8),
        "int16": torch.randint(-1000, 1000, (4,), dtype=torch.int16),
        "bfloat16": torch.randn(3, dtype=torch.bfloat16),
    }
    dtype_roundtrip = TorchStoreStateDict.from_state_dict(dtype_dict).to_state_dict()

    for key, expected in dtype_dict.items():
        assert torch.equal(expected, dtype_roundtrip[key]), f"Mismatch for dtype {key}"
|
||
|
||
def test_torchstore_state_dict_with_dtensor():
    """Round-trip a state dict mixing replicated DTensors and regular tensors.

    Runs on a one-rank CPU mesh. The process group created for the test is
    destroyed in a ``finally`` block so that a failing assertion cannot leak
    it into tests that run afterwards.
    """
    # TODO(review): a reviewer asked for a sharded-DTensor case (world size > 1).
    # Per the PR discussion, DTensor put/get is planned for a follow-up PR; this
    # test only covers serialization and deserialization.
    _setup_process_group()
    try:
        # Create single-device mesh (CPU only)
        device_mesh = DeviceMesh("cpu", [0])

        # Create DTensor from local tensor
        local_tensor = torch.randn(4, 6, dtype=torch.float32)
        dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])

        # Create state dict with DTensor and regular tensor
        original_state_dict = {
            "regular_tensor": torch.randn(3, 3),
            "dtensor": dtensor,
            "nested": {
                "another_dtensor": DTensor.from_local(
                    torch.ones(2, 3), device_mesh, [Replicate()]
                ),
                "metadata": {"test": "value"},
            },
        }

        # Test serialization
        torchstore_state_dict = TorchStoreStateDict.from_state_dict(
            original_state_dict
        )

        # Verify DTensor metadata is preserved using utility function
        flattened_original, _ = flatten_state_dict(original_state_dict)
        _verify_tensor_references(torchstore_state_dict, flattened_original)

        # Test deserialization
        reconstructed_state_dict = torchstore_state_dict.to_state_dict()

        # Verify reconstruction using utility function
        flattened_reconstructed, _ = flatten_state_dict(reconstructed_state_dict)
        _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed)
    finally:
        # Tear down even when an assertion above fails.
        dist.destroy_process_group()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: hand this file to the shared `main` helper imported
    # from `.utils` to run existing tests.
    main(__file__)
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,65 @@ | ||||||||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||
# All rights reserved. | ||||||||
# | ||||||||
# This source code is licensed under the BSD-style license found in the | ||||||||
# LICENSE file in the root directory of this source tree. | ||||||||
|
||||||||
from typing import Tuple | ||||||||
|
||||||||
import torch | ||||||||
from torch.distributed.tensor import DTensor, Placement | ||||||||
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset | ||||||||
|
||||||||
|
||||||||
def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
    """
    Build the TensorSlice metadata describing a DTensor's local shard.

    Args:
        dtensor: The DTensor to extract metadata from

    Returns:
        TensorSlice holding the global offsets, the mesh coordinates used to
        compute them, the global and local shapes, and the mesh shape
    """
    # NOTE(review): imported at function level on purpose; the PR discussion
    # attributes this to a circular dependency with torchstore.transport.pipe.
    from torchstore.transport.pipe import TensorSlice

    mesh = dtensor.device_mesh
    mesh_shape = mesh.shape
    coordinates = mesh.get_coordinate()
    global_shape = dtensor.shape

    # Only the offsets are kept from the helper; the local shape is read
    # directly from the local tensor below.
    _, offsets = _compute_local_shape_and_global_offset(
        global_shape,
        mesh_shape=mesh_shape,
        my_coordinate=coordinates,
        placements=dtensor.placements,
    )

    return TensorSlice(
        offsets=offsets,
        coordinates=coordinates,
        global_shape=global_shape,
        local_shape=dtensor._local_tensor.shape,
        mesh_shape=mesh_shape,
    )
|
||||||||
|
||||||||
def reconstruct_dtensor_from_local_tensor(
    local_tensor: torch.Tensor,
    tensor_slice: "TensorSlice",
    device_mesh: torch.distributed.DeviceMesh,
    placements: Tuple[Placement, ...],
) -> DTensor:
    """
    Wrap a local tensor back into a DTensor on the given mesh and placements.

    Args:
        local_tensor: The local tensor shard
        tensor_slice: TensorSlice containing distributed metadata (accepted
            as part of the interface; this function does not read it)
        device_mesh: The device mesh for the DTensor
        placements: The placements for the DTensor

    Returns:
        The DTensor produced by ``DTensor.from_local``
    """
    rebuilt = DTensor.from_local(
        local_tensor=local_tensor,
        device_mesh=device_mesh,
        placements=placements,
    )
    return rebuilt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems to be testing implementation details as opposed to behaviors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Removed.