15 changes: 8 additions & 7 deletions tests/test_models.py
@@ -122,32 +122,33 @@ async def do_get(self):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_basic(strategy_params, use_rdma):
async def test_basic(strategy_params, transport_type):
# FSDP
put_mesh_shape = (1,)
get_mesh_shape = (1,)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_resharding(strategy_params, use_rdma):
async def test_resharding(strategy_params, transport_type):
# FSDP
put_mesh_shape = (4,)
get_mesh_shape = (8,)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type)


async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
async def _do_test(put_mesh_shape, get_mesh_shape, strategy, transport_type):

ts.init_logging()
logger.info(f"Testing with strategy: {strategy}")

put_world_size = math.prod(put_mesh_shape)
await ts.initialize(
num_storage_volumes=put_world_size if strategy is not None else 1,
strategy=strategy,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)
try:
with tempfile.TemporaryDirectory() as tmpdir:
36 changes: 18 additions & 18 deletions tests/test_resharding.py
@@ -8,16 +8,16 @@
import os
import tempfile
from logging import getLogger
from typing import List, Tuple, Union
from typing import List, Tuple, Type, Union

import pytest

import torch

import torchstore as ts

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
from torchstore.transport import TransportType
from torchstore.utils import get_local_tensor, spawn_actors

from .utils import DTensorActor, main, transport_plus_strategy_params
@@ -27,7 +27,7 @@

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_1d_resharding(strategy_params, use_rdma):
async def test_1d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

for put_mesh_shape, get_mesh_shape in [
@@ -47,13 +47,13 @@ async def test_1d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(get_sharding_dim)],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_2d_resharding(strategy_params, use_rdma):
async def test_2d_to_2d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = get_mesh_shape = (2, 2)
@@ -69,13 +69,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_1d_to_2d_resharding(strategy_params, use_rdma):
async def test_1d_to_2d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = (4,)
@@ -92,13 +92,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_1d_resharding(strategy_params, use_rdma):
async def test_2d_to_1d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = (2, 2)
@@ -115,13 +115,13 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_data_parallel(strategy_params, use_rdma):
async def test_data_parallel(strategy_params, transport_type):
_, strategy = strategy_params

# # 1d
@@ -134,7 +134,7 @@ async def test_data_parallel(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=placements,
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)

# 2d -> 1d
@@ -149,7 +149,7 @@ async def test_data_parallel(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(1)],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@@ -158,8 +158,8 @@ async def _test_resharding(
put_placements: List[Union[Replicate, Shard]],
get_mesh_shape: Tuple[int],
get_placements: List[Union[Replicate, Shard]],
strategy: ts.TorchStoreStrategy,
use_rdma: bool,
strategy: Type[ts.TorchStoreStrategy],
transport_type: TransportType,
):
"""Given a "put" mesh shape and a "get" mesh shape.
1. Create separate worlds for each mesh shape, running on different devices /PGs.
@@ -183,8 +183,6 @@ async def _test_resharding(

# Rank0: dtensor._local_tensor == [0,1], Rank1: dtensor._local_tensor == [2,3]
"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

put_world_size = math.prod(put_mesh_shape)
get_world_size = math.prod(get_mesh_shape)
assert (
@@ -206,7 +204,9 @@
) # 8x8 square, with ([[0...7],[8...15],[...]])
await ts.initialize(
num_storage_volumes=put_world_size if strategy is not None else 1,
strategy=strategy,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)
with tempfile.TemporaryDirectory() as filesystem_store_dir:
# each actor mesh represents a group of processes.
14 changes: 8 additions & 6 deletions tests/test_state_dict.py
@@ -167,9 +167,7 @@ async def do_get(self):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_state_dict(strategy_params, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

async def test_state_dict(strategy_params, transport_type):
class Trainer(Actor):
# Monarch RDMA does not work outside of an actor, so we need
# to wrap this test first
@@ -200,7 +198,12 @@ async def do_test(self):
return state_dict, fetched_state_dict

_, strategy = strategy_params
await ts.initialize(num_storage_volumes=1, strategy=strategy)
await ts.initialize(
num_storage_volumes=1,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)
trainer = await spawn_actors(1, Trainer, "trainer")
try:
state_dict, fetched_state_dict = await trainer.do_test.call_one()
@@ -212,8 +215,7 @@ async def do_test(self):
@pytest.mark.skip("TODO(kaiyuan-li@): fix this test")
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_dcp_sharding_parity(strategy_params, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
async def test_dcp_sharding_parity(strategy_params, transport_type):

for save_mesh_shape, get_mesh_shape in [
((2,), (4,)),
52 changes: 36 additions & 16 deletions tests/test_store.py
@@ -31,9 +31,8 @@

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_basic(strategy_params, use_rdma):
async def test_basic(strategy_params, transport_type):
"""Test basic put/get functionality for multiple processes"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class PutGetActor(Actor):
"""Each instance of this actor represents a single process."""
@@ -60,7 +59,12 @@ async def get(self, rank_offset=0):
return await ts.get(f"key_{other_rank}")

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)
await ts.initialize(
num_storage_volumes=volume_world_size,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)
# each actor mesh represents a group of processes.
actor_mesh_0 = await spawn_actors(
volume_world_size, PutGetActor, "actor_mesh_0", world_size=volume_world_size
@@ -91,9 +95,8 @@ async def get(self, rank_offset=0):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_objects(strategy_params, use_rdma):
async def test_objects(strategy_params, transport_type):
"""Test put/get on arbitrary object"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class ObjectActor(Actor):
"""Each instance of this actor represents a single process."""
@@ -118,7 +121,12 @@ async def get(self, rank_offset=0):
return await ts.get(f"key_{other_rank}")

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)
await ts.initialize(
num_storage_volumes=volume_world_size,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)
# each actor mesh represents a group of processes.
actor_mesh_0 = await spawn_actors(
volume_world_size, ObjectActor, "actor_mesh_0", world_size=volume_world_size
@@ -154,9 +162,8 @@ def __eq__(self, other: object) -> bool:

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_exists(strategy_params, use_rdma):
async def test_exists(strategy_params, transport_type):
"""Test the exists() API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class ExistsTestActor(Actor):
"""Actor for testing exists functionality."""
@@ -177,7 +184,12 @@ async def exists(self, key):
return await ts.exists(key)

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)
await ts.initialize(
num_storage_volumes=volume_world_size,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)

# Spawn test actors
actor_mesh = await spawn_actors(
@@ -222,9 +234,8 @@ async def exists(self, key):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_delete(strategy_params, use_rdma):
async def test_delete(strategy_params, transport_type):
"""Test the delete() API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class DeleteTestActor(Actor):
"""Actor for testing delete functionality."""
@@ -253,7 +264,12 @@ async def get(self, key):
return await ts.get(key)

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)
await ts.initialize(
num_storage_volumes=volume_world_size,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)

# Spawn test actors
actor_mesh = await spawn_actors(
@@ -303,9 +319,8 @@ async def get(self, key):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_get_tensor_slice(strategy_params, use_rdma):
async def test_get_tensor_slice(strategy_params, transport_type):
"""Test tensor slice API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class TensorSlicePutActor(Actor):
"""Actor for putting tensors."""
@@ -322,7 +337,12 @@ async def put(self, key, tensor):
await ts.put(key, tensor)

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)
await ts.initialize(
num_storage_volumes=volume_world_size,
strategy=strategy(transport_type=transport_type)
if strategy is not None
else None,
)

# Spawn test actors - separate meshes for put and get to test cross-process communication
put_actor_mesh = await spawn_actors(
@@ -405,7 +425,7 @@ class LargeTensorActor(Actor):
step_size: int = 100 # -> 400mb
max_step: int = 600 # 4mb -> 2gb

def __init__(self, generate_benchmark=False) -> None:
def __init__(self, generate_benchmark=True) -> None:
self.generate_benchmark = generate_benchmark
init_logging()

18 changes: 10 additions & 8 deletions tests/utils.py
@@ -15,6 +15,8 @@
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed._tensor import distribute_tensor
from torch.distributed.device_mesh import init_device_mesh
from torchstore.transport import TransportType
from torchstore.transport.buffers import monarch_rdma_available

logger = getLogger(__name__)

@@ -26,17 +28,17 @@ def main(file):

def transport_plus_strategy_params():
strategies = [
(2, ts.LocalRankStrategy()),
(2, ts.LocalRankStrategy),
(1, None), # ts.SingletonStrategy
(1, ts.ControllerStorageVolumes()),
(1, ts.ControllerStorageVolumes),
]
rdma_options = (
[True, False]
if os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1"
else [False]
)

return "strategy_params, use_rdma", list(product(strategies, rdma_options))
transport_types = list(TransportType)
if not monarch_rdma_available():
print("Removing rdma tests since rdma is not available")
transport_types.remove(TransportType.MonarchRDMA)

return "strategy_params, transport_type", list(product(strategies, transport_types))


class DTensorActor(Actor):
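
For reference, a minimal sketch (not part of this PR) of how a test consumes the new parametrization: strategies arrive as classes or None, and each test instantiates its strategy with the TransportType under test instead of toggling the TORCHSTORE_RDMA_ENABLED environment variable. The test name is illustrative and the actual storage operations (which the real tests run inside Monarch actors) are elided.

import pytest

import torchstore as ts

from .utils import transport_plus_strategy_params


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_example(strategy_params, transport_type):
    volume_world_size, strategy = strategy_params
    # strategy is a class (e.g. ts.LocalRankStrategy) or None for the
    # singleton default; the transport is passed explicitly at construction.
    await ts.initialize(
        num_storage_volumes=volume_world_size,
        strategy=strategy(transport_type=transport_type)
        if strategy is not None
        else None,
    )
    # Real tests spawn Monarch actors to exercise put/get/exists/delete,
    # as shown in the hunks above.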