7 changes: 7 additions & 0 deletions get_benchmark.csv
@@ -0,0 +1,7 @@
size_mbytes, delta, mbytes_per_sec
4, 0.018769782967865467, 213.10848435744523
404, 0.15033994475379586, 2687.2432383928995
804, 0.4327937951311469, 1857.6976126849709
1204, 0.9795559397898614, 1229.1283744941481
1604, 0.8066510939970613, 1988.4681393686224
2004, 0.8627498750574887, 2322.8053204487164
7 changes: 7 additions & 0 deletions put_benchmark.csv
@@ -0,0 +1,7 @@
size_mbytes, delta, mbytes_per_sec
4, 0.6578615619800985, 6.080306604265484
404, 0.16606504004448652, 2432.7817576280595
804, 0.3294775849208236, 2440.226700681955
1204, 0.4480641600675881, 2687.115166315429
1604, 0.5693950089626014, 2817.025043690456
2004, 0.6524990671314299, 3071.268758758767
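
Note on the two benchmark CSVs above: each data row carries three values, and in both files the third value equals size_mbytes / delta, i.e. throughput in MB/s. The column name mbytes_per_sec used in the headers is an editorial guess; only the arithmetic is confirmed by the data. A minimal sketch, assuming the CSVs are in the working directory, that checks the relationship:

import csv

# Verify that the third column of each benchmark CSV is size / delta.
for path in ("get_benchmark.csv", "put_benchmark.csv"):
    with open(path) as f:
        rows = list(csv.reader(f))[1:]  # skip the header row
        for size_mb, delta, throughput in rows:
            assert abs(float(size_mb) / float(delta) - float(throughput)) < 1e-6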
7 changes: 3 additions & 4 deletions tests/test_models.py
@@ -17,9 +17,8 @@
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors
from torchstore.state_dict_utils import _state_dict_size
from torchstore.utils import spawn_actors

from transformers import AutoModelForCausalLM

@@ -170,13 +169,13 @@ async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
file_store_name=os.path.join(tmpdir, "get_world"),
)

logger.info(f"do_push ")
logger.info("do_push ")
await put_world.do_push.call()


await get_world.do_get.call()
finally:
await ts.shutdown()


if __name__ == "__main__":
main([__file__])
1 change: 1 addition & 0 deletions tests/test_state_dict.py
@@ -209,6 +209,7 @@ async def do_test(self):
_assert_equal_state_dict(state_dict, fetched_state_dict)


@pytest.mark.skip("TODO(kaiyuan-li@): fix this test")
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_dcp_sharding_parity(strategy_params, use_rdma):
101 changes: 93 additions & 8 deletions tests/test_store.py
@@ -136,7 +136,7 @@ def __eq__(self, other: object) -> bool:

try:
for idx in range(volume_world_size):
actor = actor_mesh_0.slice(**{"hosts": 0, "gpus": idx})
actor = actor_mesh_0.slice(gpus=idx)
await actor.put.call(MyTestObject(idx))

for rank_offset in (0, 1):
@@ -196,7 +196,7 @@ async def exists(self, key):
# Test 2: Store tensors and check existence
tensor = torch.tensor([1, 2, 3, 4, 5])
for rank in range(volume_world_size):
actor = actor_mesh.slice(**{"hosts": 0, "gpus": rank})
actor = actor_mesh.slice(gpus=rank)
await actor.put.call(f"tensor_key_{rank}", tensor)

for rank in range(volume_world_size):
@@ -207,7 +207,7 @@ async def exists(self, key):
# Test 3: Store objects and check existence
obj = {"rank": 0, "data": [1, 2, 3]}
for rank in range(volume_world_size):
actor = actor_mesh.slice(**{"hosts": 0, "gpus": rank})
actor = actor_mesh.slice(gpus=rank)
await actor.put.call(f"object_key_{rank}", obj)

for rank in range(volume_world_size):
@@ -220,6 +220,87 @@ async def exists(self, key):
await ts.shutdown()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_delete(strategy_params, use_rdma):
"""Test the delete() API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class DeleteTestActor(Actor):
"""Actor for testing delete functionality."""

def __init__(self, world_size):
init_logging()
self.world_size = world_size
self.rank = current_rank().rank
# required by LocalRankStrategy
os.environ["LOCAL_RANK"] = str(self.rank)

@endpoint
async def put(self, key, value):
await ts.put(key, value)

@endpoint
async def delete(self, key):
await ts.delete(key)

@endpoint
async def exists(self, key):
return await ts.exists(key)

@endpoint
async def get(self, key):
return await ts.get(key)

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)

# Spawn test actors
actor_mesh = await spawn_actors(
volume_world_size,
DeleteTestActor,
"delete_test_actors",
world_size=volume_world_size,
)

try:
# Test 1: Store tensors, verify they exist, then delete them
tensor = torch.tensor([1, 2, 3, 4, 5])
for rank in range(volume_world_size):
actor = actor_mesh.slice(gpus=rank)
await actor.put.call(f"tensor_key_{rank}", tensor)

# Verify all tensors exist
for rank in range(volume_world_size):
results = await actor_mesh.exists.call(f"tensor_key_{rank}")
for _, exists_result in results:
assert exists_result

# Delete tensors one at a time and verify each deletion
for rank in range(volume_world_size):
actor = actor_mesh.slice(gpus=rank)
await actor.delete.call(f"tensor_key_{rank}")

# Verify this specific tensor no longer exists
results = await actor_mesh.exists.call(f"tensor_key_{rank}")
for _, exists_result in results:
assert not exists_result

# Verify other tensors still exist (if any remain)
for other_rank in range(rank + 1, volume_world_size):
results = await actor_mesh.exists.call(f"tensor_key_{other_rank}")
for _, exists_result in results:
assert exists_result

# Test 2: Try to get deleted tensor (should raise exception)
with pytest.raises(Exception):
await actor_mesh.get.call("tensor_key_0")

finally:
await actor_mesh._proc_mesh.stop()
await ts.shutdown()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_get_tensor_slice(strategy_params, use_rdma):
@@ -256,7 +337,7 @@ async def put(self, key, tensor):
key = "test_tensor"

# Store the tensor using put actor mesh
put_actor = put_actor_mesh.slice(**{"hosts": 0, "gpus": 0})
put_actor = put_actor_mesh.slice(gpus=0)
await put_actor.put.call(key, test_tensor)

# Test full tensor retrieval using get actor mesh
@@ -324,7 +405,7 @@ class LargeTensorActor(Actor):
step_size: int = 100 # -> 400mb
max_step: int = 600 # 4mb -> 2gb

def __init__(self, generate_benchmark=False) -> None:
def __init__(self, generate_benchmark=True) -> None:
self.generate_benchmark = generate_benchmark
init_logging()

@@ -386,9 +467,13 @@ async def get(self):
# controller code
await ts.initialize()
actor = await spawn_actors(1, LargeTensorActor, "large_tensor")
await actor.put.call_one()
await actor.get.call_one()
# TODO: assert equal tensors from put/get
try:
await actor.put.call_one()
await actor.get.call_one()
# TODO: assert equal tensors from put/get
finally:
await actor._proc_mesh.stop()
await ts.shutdown()


@pytest.mark.asyncio
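For reference, the sizes in both benchmark CSVs (4 through 2004 MB) are consistent with the step_size and max_step attributes above if the benchmark samples every step_size steps and each step adds 4 MB. This is an inference from the class attributes and the CSV data, not taken from the actor's actual loop:

# Hypothetical reconstruction of the benchmark size schedule.
step_size, max_step = 100, 600  # LargeTensorActor class attributes
sizes_mb = [4 + 4 * step for step in range(0, max_step, step_size)]
assert sizes_mb == [4, 404, 804, 1204, 1604, 2004]  # matches both CSVs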
3 changes: 3 additions & 0 deletions torchstore/__init__.py
@@ -4,6 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Importing torch first helps avoid import-order issues in the modules imported below.
import torch

import os
from logging import getLogger

39 changes: 33 additions & 6 deletions torchstore/client.py
@@ -4,18 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time
import asyncio
from logging import getLogger
from typing import Any, Union

import torch
from torch.distributed.tensor import DTensor

from torchstore.controller import ObjectType
from torchstore.transport import Pipe, Request, TensorSlice
from torchstore.controller import ObjectType
from torchstore.logging import LatencyTracker
from torchstore.transport import Pipe, Request
from torchstore.transport import Pipe, Request, TensorSlice
from torchstore.utils import assemble_global_tensor, get_local_tensor

logger = getLogger(__name__)
@@ -53,7 +51,6 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):
latency_tracker.track_step("notify_put")
latency_tracker.track_e2e()


@torch.no_grad
async def get(
self,
@@ -122,6 +119,35 @@ async def keys(self, prefix: str | None = None) -> list[str]:
# Keys are synced across all storage volumes, so we just call one.
return await self._controller.keys.call_one(prefix)

async def delete(self, key: str) -> None:
"""
Delete a key from the distributed store.

Args:
key (str): The key to delete.

Returns:
None

Raises:
KeyError: If the key does not exist in the store.
"""
latency_tracker = LatencyTracker(f"delete:{key}")
volume_map = await self._controller.locate_volumes.call_one(key)

async def delete_from_volume(volume_id: str):
volume = self.strategy.get_storage_volume(volume_id)
            # Notify the controller before performing the delete so it stops
            # advertising this key while the deletion is in flight.
await self._controller.notify_delete.call_one(key, volume_id)
await volume.delete.call(key)

await asyncio.gather(
*[delete_from_volume(volume_id) for volume_id in volume_map]
)

latency_tracker.track_e2e()

async def exists(self, key: str) -> bool:
"""Check if a key exists in the distributed store.

@@ -179,7 +205,8 @@ def _verify_get_args(
and tensor_slice_spec.local_shape != inplace_tensor.shape
):
raise ValueError(
f"Requested tensor slice shape {tensor_slice_spec.local_shape} does not match in-place tensor shape {inplace_tensor.shape}"
f"Requested tensor slice shape {tensor_slice_spec.local_shape} "
f"does not match in-place tensor shape {inplace_tensor.shape}"
)

if isinstance(inplace_tensor, DTensor):
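Putting the new client methods together, a minimal usage sketch; the demo wrapper and the bare initialize() call are illustrative, while the store calls mirror those exercised in the tests above:

import torch
import torchstore as ts

async def demo():
    await ts.initialize()
    try:
        await ts.put("my_key", torch.tensor([1, 2, 3]))
        assert await ts.exists("my_key")
        # delete() notifies the controller before each volume drops the key,
        # so lookups are not routed to a volume mid-deletion.
        await ts.delete("my_key")
        assert not await ts.exists("my_key")
    finally:
        await ts.shutdown()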
1 change: 0 additions & 1 deletion torchstore/logging.py
@@ -9,7 +9,6 @@
import os
import sys


def init_logging():
log_level = os.environ.get("TORCHSTORE_LOG_LEVEL", "INFO").upper()

17 changes: 0 additions & 17 deletions torchstore/state_dict_utils.py
@@ -65,23 +65,6 @@ async def get_state_dict(
inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None,
)

# # Prepare all the coroutines first
# coros = []
# keys = []
# for flattened_key in fetched_mapping.keys():
# inplace_tensor = user_flattened_state_dict.get(flattened_key, None)
# keys.append(flattened_key)
# coros.append(
# store.get(
# f"{key}{DELIM}{flattened_key}",
# inplace_tensor if isinstance(inplace_tensor, torch.Tensor) else None,
# )
# )
# # Run all requests concurrently
# results = await asyncio.gather(*coros)
# # Build the result dictionary
# fetched_state_dict = dict(zip(keys, results))

return unflatten_state_dict(fetched_state_dict, fetched_mapping)

def _state_dict_size(state_dict):