2 changes: 1 addition & 1 deletion tests/test_store.py
@@ -405,7 +405,7 @@ class LargeTensorActor(Actor):
step_size: int = 100 # -> 400mb
max_step: int = 600 # 4mb -> 2gb

def __init__(self, generate_benchmark=False) -> None:
def __init__(self, generate_benchmark=True) -> None:
self.generate_benchmark = generate_benchmark
init_logging()

15 changes: 8 additions & 7 deletions torchstore/client.py
@@ -11,8 +11,9 @@
import torch
from torch.distributed.tensor import DTensor

from torchstore.controller import ObjectType
from torchstore.controller import Controller, ObjectType
from torchstore.logging import LatencyTracker
from torchstore.strategy import TorchStoreStrategy
from torchstore.transport import Pipe, Request, TensorSlice
from torchstore.utils import assemble_global_tensor, get_local_tensor

@@ -24,11 +25,7 @@ class LocalClient:
is handled by the client.
"""

def __init__(
self,
controller,
strategy,
):
def __init__(self, controller: Controller, strategy: TorchStoreStrategy):
self._controller = controller
self.strategy = strategy

@@ -41,7 +38,6 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):
# it will never be dynamic. e.g. it's always based on the
# TorchStoreStrategy defined during initialization
storage_volume, volume_id = self.strategy.select_storage_volume()

pipe = Pipe(storage_volume)

await pipe.put_to_storage_volume(key, request)
@@ -220,14 +216,17 @@ async def _get_stored_object_type(self, key: str) -> ObjectType | None:
volume_map = await self._controller.locate_volumes.call_one(key)
for storage_info in volume_map.values():
return storage_info.object_type

raise ValueError(f"Unable to get stored object type for key `{key}`")

async def _get_object(self, key: str):
volume_map = await self._controller.locate_volumes.call_one(key)
volume_id, _ = volume_map.popitem()

storage_volume = self.strategy.get_storage_volume(volume_id)
pipe = Pipe(storage_volume)
request = Request.from_any(None)

return await pipe.get_from_storage_volume(key, request)

async def _get_tensor(self, key: str) -> torch.Tensor:
@@ -243,6 +242,8 @@ async def _get_tensor(self, key: str) -> torch.Tensor:
request = Request.from_any(None)
return await pipe.get_from_storage_volume(key, request)

raise ValueError(f"Unable to get tensor for key `{key}`")

async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor:
"""Fetches slices from all volume storages and stitch together to return the whole tensor"""

9 changes: 5 additions & 4 deletions torchstore/controller.py
@@ -6,15 +6,16 @@

from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Dict, List, Mapping, Optional, Set
from typing import Dict, List, Mapping, Optional, Set, TYPE_CHECKING

from monarch.actor import Actor, endpoint

from torchstore.storage_utils.trie import Trie
from torchstore.storage_volume import StorageVolume
from torchstore.strategy import TorchStoreStrategy
from torchstore.transport.pipe import Request, TensorSlice

if TYPE_CHECKING:
from torchstore.storage_volume import StorageVolume

# TODO: move this into request as a field
class ObjectType(Enum):
@@ -52,7 +53,7 @@ def __init__(
self.keys_to_storage_volumes = Trie()
self.is_initialized: bool = False
self.strategy: Optional[TorchStoreStrategy] = None
self.storage_volumes: Optional[StorageVolume] = None
self.storage_volumes: Optional["StorageVolume"] = None
self.num_storage_volumes: Optional[int] = None
self.strategy: Optional[TorchStoreStrategy] = None

@@ -66,7 +67,7 @@ async def init(
self,
strategy: TorchStoreStrategy,
num_storage_volumes: int,
storage_volumes: StorageVolume,
storage_volumes: "StorageVolume",
) -> None:
if self.is_initialized:
raise RuntimeError("TorchStore is already initialized")
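Moving the StorageVolume import behind TYPE_CHECKING and quoting the annotations is the standard way to keep type checkers happy without executing the import at runtime, presumably to avoid an import cycle here. A minimal self-contained sketch of the pattern, with an illustrative Controller body rather than the real class:

```python
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime, so it cannot
    # participate in an import cycle.
    from torchstore.storage_volume import StorageVolume


class Controller:
    def __init__(self) -> None:
        # The quoted annotation keeps the name unresolved until type checking.
        self.storage_volumes: Optional["StorageVolume"] = None
```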
23 changes: 21 additions & 2 deletions torchstore/storage_volume.py
@@ -12,10 +12,11 @@
from monarch.actor import Actor, endpoint

from torchstore.transport.buffers import TransportBuffer

from torchstore.transport.pipe import Request, TensorSlice

from torchstore.utils import assemble_global_tensor, spawn_actors


logger = getLogger(__name__)


@@ -31,8 +32,10 @@ def __init__(
self,
id_func,
) -> None:
init_logging()
self.store: StorageImpl = InMemoryStore()
self.volume_id: str = id_func()
self.transport_context = {}

@classmethod
async def spawn(
@@ -56,12 +59,19 @@ async def get_id(self) -> str:
async def put(
self, key: str, transport_buffer: TransportBuffer, request: Request
) -> None:
# something like
# transport_buffer.set_context(self.transport_context)
transport_buffer.transport_context = self.transport_context
transport_buffer.remote_rank = 0
await self.store.put(key, transport_buffer, request)

@endpoint
async def get(
self, key: str, transport_buffer: TransportBuffer, request: Request
) -> TransportBuffer:
# transport_buffer.set_context(self.transport_context)
transport_buffer.transport_context = self.transport_context
transport_buffer.remote_rank = 0
return await self.store.get(key, transport_buffer, request)

@endpoint
@@ -72,9 +82,16 @@ async def get_meta(
) -> Union[Tuple[torch.Size, torch.dtype], str]:
return await self.store.get_meta(key, request)

@endpoint
async def delete(self, key: str) -> None:
await self.store.delete(key)

@endpoint
async def setup_comms(self, transport_buffer) -> None:
logger.info("Initiating handshake on volume side")
await transport_buffer.storage_volume_setup_comms(self.transport_context)
logger.info("Finished initiating handshake on volume side")


class StorageImpl:
"""Abstract base class for storage implementations."""
@@ -194,6 +211,7 @@ async def put(
# since we pass tensor=None to the transport buffer,
# we allocate on the fly
tensor = await transport_buffer.read_into(tensor=None)
transport_buffer.finish()
Contributor: see previous comments; don't block the main thread here.

Contributor: what would happen when you call finish() if the buffer is not a gloo buffer, by the way?

Contributor (author): finish is a no-op outside of the ptd buffer.
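A hypothetical sketch (not code from this PR) of the shape this thread describes: finish() is a no-op on the base buffer, and only a torch.distributed-backed buffer would have outstanding work to wait on, which is why calling it synchronously on the endpoint's main thread is the concern here.

```python
class TransportBufferSketch:
    """Stripped-down stand-in for TransportBuffer (illustration only)."""

    def finish(self) -> None:
        # Default: nothing to finalize.
        pass


class PTDBufferSketch(TransportBufferSketch):
    """Hypothetical torch.distributed-backed buffer."""

    def __init__(self) -> None:
        # e.g. Work handles returned by dist.isend / dist.irecv.
        self._pending_work = []

    def finish(self) -> None:
        # Blocking wait on outstanding transfers. The review suggests either
        # running this off the main thread or making it awaitable, as in the
        # `await transport_buffer.finish()` suggestion below.
        for work in self._pending_work:
            work.wait()
        self._pending_work.clear()
```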

if request.tensor_slice is not None:
self._handle_dtensor(key, request.tensor_slice, tensor)
return
@@ -216,6 +234,7 @@ async def get(

if request.tensor_slice is None:
await transport_buffer.write_from(self.kv[key])
transport_buffer.finish()
return transport_buffer

# TODO:
Expand All @@ -227,6 +246,7 @@ async def get(
for shard in self.kv[key].values():
if shard["slice"] == request.tensor_slice:
await transport_buffer.write_from(shard["tensor"])
transport_buffer.finish()
Contributor: Suggested change:
-        transport_buffer.finish()
+        await transport_buffer.finish()

return transport_buffer

raise RuntimeError(f"Tensor slice {request.tensor_slice} not found in {key}")
Expand Down Expand Up @@ -264,7 +284,6 @@ async def get_meta(
f"Could not find shard slice with {request.tensor_slice=} Slices:{stored_object}"
)

raise RuntimeError(f"Unknown type for {key} type={type(val)} {val=}")
raise RuntimeError(f"Unknown type for {key} type={type(val)}")

async def delete(self, key: str) -> None:
48 changes: 42 additions & 6 deletions torchstore/strategy.py
@@ -11,10 +11,34 @@
"""

import os
from typing import Any, Dict, TYPE_CHECKING

from monarch.actor import current_rank

from torchstore.storage_volume import StorageVolume
from torchstore.transport import TransportType

if TYPE_CHECKING:
from torchstore.storage_volume import StorageVolume


class StorageVolumeRef:
def __init__(
self,
volume: "StorageVolume",
volume_id: str,
transport_type: "TransportType",
transport_context: Dict,
):
self.volume = volume
self.volume_id = volume_id
self.transport_type = transport_type
# useful for caching transport objects that should survive the lifetime of the client/volume
self.transport_context = transport_context

def __getattr__(self, name: str) -> Any:
if name not in ("volume_id", "transport_type"):
return getattr(self.volume, name)
return super().__getattribute__(name)


class TorchStoreStrategy:
@@ -28,9 +52,11 @@ class TorchStoreStrategy:
Subclasses must implement get_volume_id() and get_client_id() methods.
"""

def __init__(self):
def __init__(self, transport_type: TransportType = TransportType.TorchDistributed):
self.storage_volumes = None
self.volume_id_to_coord = {}
self.transport_type = transport_type
self.transport_context = {}

def __str__(self) -> str:
storage_vol_len = (
@@ -84,7 +110,7 @@ def select_storage_volume(self):
client_id,
) # client_id == volume_id for this strategy

def get_storage_volume(self, volume_id: str) -> StorageVolume:
def get_storage_volume(self, volume_id: str) -> StorageVolumeRef:
"""Retrieves storage volume actor for a given volume ID.

Args:
@@ -94,7 +120,12 @@ def get_storage_volume(self, volume_id: str) -> StorageVolume:
StorageVolume: The storage volume actor for the given ID.
"""
volume_coord = self.volume_id_to_coord[volume_id]
return self.storage_volumes.slice(**volume_coord)
return StorageVolumeRef(
self.storage_volumes.slice(**volume_coord),
volume_id,
self.transport_type,
self.transport_context,
)


class SingletonStrategy(TorchStoreStrategy):
@@ -205,5 +236,10 @@ def select_storage_volume(self):
client_id,
) # client_id == volume_id for this strategy

def get_storage_volume(self, volume_id: str) -> StorageVolume:
return self.storage_volumes
def get_storage_volume(self, volume_id: str) -> StorageVolumeRef:
return StorageVolumeRef(
self.storage_volumes,
volume_id,
self.transport_type,
self.transport_context,
)
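StorageVolumeRef, added above, is a thin proxy: volume_id, transport_type, and transport_context live on the reference, while every other attribute falls through __getattr__ to the wrapped volume actor, so existing call sites that invoke endpoints on the volume keep working unchanged. Note that Python only calls __getattr__ when normal attribute lookup fails, so fields set in __init__ never reach it. A minimal self-contained sketch of the delegation pattern (FakeVolume and VolumeRef are illustrative names, not the real classes):

```python
from typing import Any


class FakeVolume:
    """Stand-in for the real StorageVolume actor (illustration only)."""

    def put(self, key: str, value: Any) -> None:
        print(f"put({key!r})")


class VolumeRef:
    def __init__(self, volume: Any, volume_id: str) -> None:
        self.volume = volume
        self.volume_id = volume_id

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. never for `volume`
        # or `volume_id`; everything else is proxied to the wrapped volume.
        return getattr(self.volume, name)


ref = VolumeRef(FakeVolume(), "vol-0")
print(ref.volume_id)     # resolved on the ref itself
ref.put("weights", 123)  # falls through to FakeVolume.put
```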
21 changes: 20 additions & 1 deletion torchstore/transport/__init__.py
@@ -4,6 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import auto, Enum

from torchstore.transport.buffers import MonarchTransportBuffer, RDMATransportBuffer

from torchstore.transport.pipe import Pipe, Request, TensorSlice
from torchstore.transport.torch_distributed_buffer import TorchDistributedBuffer


class TransportType(Enum):
MonarchRPC = auto()
MonarchRDMA = auto()
TorchDistributed = auto()

def buffer_cls(self):
return {
TransportType.MonarchRPC: MonarchTransportBuffer,
TransportType.MonarchRDMA: RDMATransportBuffer,
TransportType.TorchDistributed: TorchDistributedBuffer,
}[self]


__all__ = ["Pipe", "Request", "TensorSlice"]
__all__ = ["Pipe", "Request", "TensorSlice", "TransportType"]
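A small usage sketch (assumed, not taken from the PR) of how the enum is meant to be consumed: the strategy's transport_type picks the concrete TransportBuffer subclass when a Pipe needs to move data; constructor arguments of the individual buffer classes are omitted here.

```python
from torchstore.transport import TransportType

# The strategy in this PR defaults to TorchDistributed.
transport_type = TransportType.TorchDistributed

# buffer_cls() maps the enum member to the matching TransportBuffer subclass,
# TorchDistributedBuffer in this case.
buffer_cls = transport_type.buffer_cls()
print(buffer_cls.__name__)
```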
14 changes: 14 additions & 0 deletions torchstore/transport/buffers.py
@@ -58,6 +58,16 @@ def update(self, other_buffer: "TransportBuffer") -> None:
self.objects = other_buffer.objects
self.requires_meta = other_buffer.requires_meta

async def setup_comms(self, storage_volume) -> None:
Contributor: Is this safe to be called concurrently and idempotent? Based on how you use it in the create transport buffer code, I assume it is more like ensure_comms?

Contributor (author): Ah, actually it's not safe or idempotent. It's also not safe to call concurrently from the same client/volume combo. We may need a lock based on the client, wdyt?

Contributor: Not in the scope of this PR, but we can add a TODO and an issue just to keep track of this.

"""Initiate comms handshake with storage_volume"""
pass

async def storage_volume_setup_comms(
self, transport_context: Dict[str, Any]
) -> None:
"""Mirror of setup_comms, but run on the storage volume side"""
raise NotImplementedError("Must implement storage_volume_setup_comms")

def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
"""Allocates internal buffers based on either an existing tensor
or a Tuple of (shape, dtype)
@@ -70,6 +80,10 @@ async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor:
async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
raise NotImplementedError()

def finish(self) -> None:
"""Finalize the transport buffer"""
pass


class RDMATransportBuffer(TransportBuffer):
# TODO: when we try this with rdma, I should be able to write rdma directly to the tensor
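On the setup_comms thread above (not safe to call concurrently, not idempotent): one possible shape for the follow-up TODO is an ensure_comms-style guard keyed per client, sketched below. The names and the idea of caching readiness alongside transport_context are assumptions, not part of this PR.

```python
import asyncio
from typing import Awaitable, Callable, Dict


class CommsGuard:
    """Per-client guard so a comms handshake runs at most once."""

    def __init__(self) -> None:
        self._locks: Dict[str, asyncio.Lock] = {}
        self._ready: Dict[str, bool] = {}

    async def ensure_comms(
        self, client_id: str, setup: Callable[[], Awaitable[None]]
    ) -> None:
        # One lock per client/volume combo: concurrent callers for the same
        # client serialize, and the handshake itself runs only once.
        lock = self._locks.setdefault(client_id, asyncio.Lock())
        async with lock:
            if self._ready.get(client_id):
                return
            await setup()
            self._ready[client_id] = True
```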