[WIP] PTD / Gloo transport implementation #44
First changed file (storage volume actor and storage implementations):

@@ -12,10 +12,11 @@
from monarch.actor import Actor, endpoint

from torchstore.transport.buffers import TransportBuffer

from torchstore.transport.pipe import Request, TensorSlice

from torchstore.utils import assemble_global_tensor, spawn_actors


logger = getLogger(__name__)

@@ -31,8 +32,10 @@ def __init__(
        self,
        id_func,
    ) -> None:
        init_logging()
        self.store: StorageImpl = InMemoryStore()
        self.volume_id: str = id_func()
        self.transport_context = {}

    @classmethod
    async def spawn(
@@ -56,12 +59,19 @@ async def get_id(self) -> str:
    async def put(
        self, key: str, transport_buffer: TransportBuffer, request: Request
    ) -> None:
        # something like
        # transport_buffer.set_context(self.transport_context)
        transport_buffer.transport_context = self.transport_context
        transport_buffer.remote_rank = 0
        await self.store.put(key, transport_buffer, request)

    @endpoint
    async def get(
        self, key: str, transport_buffer: TransportBuffer, request: Request
    ) -> TransportBuffer:
        # transport_buffer.set_context(self.transport_context)
        transport_buffer.transport_context = self.transport_context
        transport_buffer.remote_rank = 0
        return await self.store.get(key, transport_buffer, request)

    @endpoint
@@ -72,9 +82,16 @@ async def get_meta(
    ) -> Union[Tuple[torch.Size, torch.dtype], str]:
        return await self.store.get_meta(key, request)

    @endpoint
    async def delete(self, key: str) -> None:
        await self.store.delete(key)

    @endpoint
    async def setup_comms(self, transport_buffer) -> None:
        logger.info("Initiating handshake on volume side")
        await transport_buffer.storage_volume_setup_comms(self.transport_context)
        logger.info("Finished initiating handshake on volume side")


class StorageImpl:
    """Abstract base class for storage implementations."""
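The new setup_comms endpoint above only logs and delegates the volume-side half of the handshake to the buffer's storage_volume_setup_comms. As a rough sketch of what a Gloo-backed buffer might do there (an assumption based on the PR title, not this PR's actual code), the example below rendezvouses over a TCPStore and creates a two-rank gloo process group, stashing it in transport_context for later transfers. The class name, host/port attributes, and the "gloo_pg" key are invented for illustration.

    # Hypothetical sketch only; not part of this PR.
    import datetime
    from typing import Any, Dict

    import torch.distributed as dist


    class GlooTransportBufferSketch:
        def __init__(self, master_host: str, master_port: int) -> None:
            self.master_host = master_host
            self.master_port = master_port

        async def storage_volume_setup_comms(self, transport_context: Dict[str, Any]) -> None:
            """Rendezvous with the client over TCP and create a 2-rank Gloo group
            (volume = rank 0, client = rank 1), stashing it in transport_context."""
            if "gloo_pg" in transport_context:
                return  # already set up; see the idempotency discussion further down

            store = dist.TCPStore(
                self.master_host,
                self.master_port,
                world_size=2,
                is_master=True,  # the storage volume hosts the rendezvous
            )
            # With an explicit store there is no env:// configuration to manage, but
            # init_process_group can only run once per process, which is one reason a
            # handshake like this is not naturally idempotent.
            dist.init_process_group(
                backend="gloo",
                store=store,
                rank=0,
                world_size=2,
                timeout=datetime.timedelta(seconds=60),
            )
            transport_context["gloo_pg"] = dist.group.WORLD  # default group created above
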
@@ -194,6 +211,7 @@ async def put(
        # since we pass tensor=None to the transport buffer,
        # we allocate on the fly
        tensor = await transport_buffer.read_into(tensor=None)
        transport_buffer.finish()

        if request.tensor_slice is not None:
            self._handle_dtensor(key, request.tensor_slice, tensor)
            return
@@ -216,6 +234,7 @@ async def get(

        if request.tensor_slice is None:
            await transport_buffer.write_from(self.kv[key])
            transport_buffer.finish()

            return transport_buffer

        # TODO:
@@ -227,6 +246,7 @@ async def get(
        for shard in self.kv[key].values():
            if shard["slice"] == request.tensor_slice:
                await transport_buffer.write_from(shard["tensor"])
                transport_buffer.finish()
Suggested change from review on the finish() call above:
    transport_buffer.finish()
becomes:
    await transport_buffer.finish()
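The suggestion implies finish() should be a coroutine so call sites can await it, for example if finalization has to wait on in-flight transfers. A minimal sketch of that shape, assuming nothing about the real buffer beyond the names used here:

    # Hypothetical sketch of the suggested shape; the staging logic is a stand-in.
    class TransportBufferSketch:
        async def write_from(self, tensor) -> None:
            """Stage `tensor` for the remote side (stand-in for the real transport)."""
            self._staged = tensor

        async def finish(self) -> None:
            """Finalize the transport buffer once reads/writes are complete."""
            pass  # a real transport could wait on outstanding work here


    async def get_path(buffer: TransportBufferSketch, tensor) -> None:
        await buffer.write_from(tensor)
        await buffer.finish()  # awaited, per the suggested change above
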
Second changed file (TransportBuffer definitions):
@@ -58,6 +58,16 @@ def update(self, other_buffer: "TransportBuffer") -> None:
        self.objects = other_buffer.objects
        self.requires_meta = other_buffer.requires_meta

    async def setup_comms(self, storage_volume) -> None:
        """Initiate comms handshake with storage_volume"""
        pass

    async def storage_volume_setup_comms(
        self, transport_context: Dict[str, Any]
    ) -> None:
        """Mirror of setup_comms, but run on the storage volume side"""
        raise NotImplementedError("Must implement storage_volume_setup_comms")

    def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
        """Allocates internal buffers based on either an existing tensor
        or a Tuple of (shape, dtype)

Review discussion on setup_comms:

    Comment: Is this safe to be called concurrently, and is it idempotent? Based on how you use it in the create transport buffer code, I assume it is more like ...

    Reply: Ah, actually it's not safe or idempotent. It's also not safe to call concurrently from the same client/volume combo. We may need a lock based on the client, wdyt?

    Reply: Not in the scope of this PR, but we can add a TODO and open an issue just to keep track of this.
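Following up on the lock idea raised in the discussion above, here is a minimal sketch (not part of this PR) of how a buffer could make the handshake idempotent and safe under concurrent callers. The class name, the volume_id argument, and the _do_handshake helper are all hypothetical.

    # Hypothetical guard for the handshake discussed above; not the PR's code.
    import asyncio
    from typing import Any, Set


    class LockedSetupCommsSketch:
        def __init__(self) -> None:
            self._comms_lock = asyncio.Lock()   # serializes concurrent handshakes
            self._handshaken: Set[str] = set()  # volumes already set up

        async def setup_comms(self, storage_volume: Any, volume_id: str) -> None:
            # Only the first caller for a given volume performs the handshake;
            # later or concurrent callers return once it has completed.
            async with self._comms_lock:
                if volume_id in self._handshaken:
                    return
                await self._do_handshake(storage_volume)  # hypothetical helper
                self._handshaken.add(volume_id)

        async def _do_handshake(self, storage_volume: Any) -> None:
            raise NotImplementedError
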
@@ -70,6 +80,10 @@ async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor:
    async def write_from(self, tensor: Optional[torch.Tensor]) -> None:
        raise NotImplementedError()

    def finish(self) -> None:
        """Finalize the transport buffer"""
        pass


class RDMATransportBuffer(TransportBuffer):
    # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor