
Commit 2c5907b

dstaay-fb authored and meta-codesync[bot] committed
RDMA Action API / basic functionality (#1410)
Summary:
Pull Request resolved: #1410

Introducing `RDMAAction`, an initial API:

```python
class RDMAAction:
    def __init__(self) -> None: ...

    def read_into(self, src: RDMABuffer, dst: LocalMemory | List[LocalMemory]) -> Self:
        """Read from src RDMA buffer into dst memory."""

    def write_from(self, dst: RDMABuffer, src: LocalMemory | List[LocalMemory]) -> Self:
        """Write from src memory to dst RDMA buffer."""

    def fetch_add(self, dst: RDMABuffer, src: LocalMemory, add: int) -> Self:
        """Perform atomic fetch-and-add operation on dst RDMA buffer."""

    def compare_and_swap(self, dst: RDMABuffer, src: LocalMemory, compare: int, swap: int) -> Self:
        """Perform atomic compare-and-swap operation on dst RDMA buffer."""

    def submit(self) -> Future[None]:
        """Schedule and execute all batched RDMA operations."""
```

Key API Patterns
----------------
* **Fluent Interface**: All operation methods (`read_into`, `write_from`, etc.) return `Self` for method chaining
* **Consistent Naming**: `src` = source, `dst` = destination
* **Batch Execution**: Operations are queued via method calls, then executed together via `submit()`
* **Async Support**: `submit()` returns a `Future[None]` for async/await usage; `submit()` can be called multiple times to execute the work multiple times

Implementation
--------------
Currently this just leverages the Python RdmaBuffer API, with some additional guards against data races and concurrency issues at the QP (queue pair) level (see image), which yields faster execution across multiple actors than the existing API can support. Lots of future improvements are possible, but we want the API out and will iterate.

{F1982453335}

Reviewed By: zdevito

Differential Revision: D83782000

fbshipit-source-id: 7bd38c0f999d762ca62db2e36a79f08ee5b20e53
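To illustrate the batched, fluent pattern described above, here is a minimal usage sketch. It is not part of this diff: `weights_buf` and `grads_buf` are hypothetical `RDMABuffer` handles exported by remote actors, with sizes assumed to match the local tensors so the per-op size checks pass.

```python
import torch

from monarch.rdma import RDMAAction

# Hypothetical local memory; distinct allocations, so no range conflicts.
local_weights = torch.empty(1024, dtype=torch.uint8)
local_grads = torch.empty(1024, dtype=torch.uint8)


async def sync_once() -> None:
    # Queue two transfers; nothing executes yet.
    action = (
        RDMAAction()
        .read_into(weights_buf, local_weights)  # remote buffer -> local memory
        .write_from(grads_buf, local_grads)     # local memory -> remote buffer
    )
    # submit() schedules the whole batch; the Future completes when every
    # queued transfer has finished.
    await action.submit()
    # The same batched work can be scheduled again.
    await action.submit()
```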
1 parent c175f21 · commit 2c5907b

File tree

5 files changed: +566 −8 lines changed


monarch_rdma/extension/lib.rs

Lines changed: 4 additions & 0 deletions

```diff
@@ -260,6 +260,10 @@ impl PyRdmaBuffer {
             Ok(())
         })
     }
+
+    fn owner_actor_id(&self) -> String {
+        self.owner_ref.actor_id().to_string()
+    }
 }
 
 #[pyclass(name = "_RdmaManager", module = "monarch._rust_bindings.rdma")]
```

python/monarch/_rust_bindings/rdma.pyi

Lines changed: 1 addition & 0 deletions

```diff
@@ -53,6 +53,7 @@ class _RdmaBuffer:
         timeout: int,
     ) -> PythonTask[Any]: ...
     def size(self) -> int: ...
+    def owner_actor_id(self) -> str: ...
     def __reduce__(self) -> tuple[Any, ...]: ...
     def __repr__(self) -> str: ...
     @staticmethod
```

python/monarch/_src/rdma/rdma.py

Lines changed: 221 additions & 7 deletions

```diff
@@ -9,16 +9,19 @@
 import functools
 import logging
 import warnings
-from typing import cast, Optional
+from collections import defaultdict
+from typing import cast, List, Optional, Tuple
 
 import torch
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
+from typing_extensions import Self
 
 try:
     from monarch._rust_bindings.rdma import _RdmaBuffer, _RdmaManager
 except ImportError as e:
     logging.error("RDMA is not available: {}".format(e))
     raise e
+from enum import Enum
 from typing import Dict
 
 from monarch._src.actor.actor_mesh import Actor, context
@@ -335,7 +338,7 @@ async def write_from_nonblocking() -> None:
 
     def drop(self) -> Future[None]:
         """
-        Release the handle on the memory that the remote holds to this memory.
+        Release the handle on the memory that the src holds to this memory.
         """
         local_proc_id = context().actor_instance.proc_id
         client = context().actor_instance
@@ -351,10 +354,221 @@ async def drop_nonblocking() -> None:
         return Future(coro=drop_nonblocking())
 
     @property
-    def owner(self) -> ProcMesh:
+    def owner(self) -> str:
         """
-        The proc that owns this buffer
+        The owner reference (str)
         """
-        # FIXME(slurye): Fix this once controller API is working properly
-        # for v1.
-        return cast(ProcMesh, context().actor_instance.proc)
+        return self._buffer.owner_actor_id()
+
+
+LocalMemory = torch.Tensor | memoryview
+
+
+class RDMAAction:
+    """
+    Schedule a bunch of actions at once. This provides an opportunity to
+    optimize bulk RDMA transactions without exposing complexity to users.
+    """
+
+    class RDMAOp(Enum):
+        """Enumeration of RDMA operation types."""
+
+        READ_INTO = "read_into"
+        WRITE_FROM = "write_from"
+        FETCH_ADD = "fetch_add"
+        COMPARE_AND_SWAP = "compare_and_swap"
+
+    def __init__(self) -> None:
+        self._instructs: List[Tuple[RDMAAction.RDMAOp, RDMABuffer, LocalMemory]] = []
+        self._memory_dependencies: Dict[Tuple[int, int], RDMAAction.RDMAOp] = {}
+
+    def _check_and_merge_overlapping_range(
+        self, addr: int, size: int, op: "RDMAAction.RDMAOp"
+    ) -> None:
+        """
+        Check for overlapping ranges and merge if found.
+
+        Updates self._memory_dependencies in place: the new range is recorded,
+        and any range it overlaps is (recursively) merged into an expanded range.
+        """
+        new_start, new_end = addr, addr + size
+
+        # Find an overlapping range
+        overlapping_range = None
+        for existing_start, existing_end in self._memory_dependencies:
+            # Check if the ranges overlap
+            if not (new_end <= existing_start or existing_end <= new_start):
+                overlapping_range = (existing_start, existing_end)
+                break
+
+        # No overlap found - good to go
+        if overlapping_range is None:
+            self._memory_dependencies[(new_start, new_end)] = op
+            return
+
+        # Overlap found - merge ranges
+        existing_op = self._memory_dependencies[overlapping_range]
+
+        # Merging ops is only safe if neither is write_from at the moment
+        if existing_op == self.RDMAOp.WRITE_FROM or op == self.RDMAOp.WRITE_FROM:
+            raise ValueError(
+                f"Same data range already has a write_from within RDMAAction: {existing_op} vs {op}"
+            )
+
+        # Create an expanded range that covers both
+        expanded_range = (
+            min(overlapping_range[0], new_start),
+            max(overlapping_range[1], new_end),
+        )
+
+        # Range is unchanged - no need to update
+        if expanded_range == (new_start, new_end):
+            return
+
+        # Update the dictionary: remove the old range, add the expanded range
+        del self._memory_dependencies[overlapping_range]
+        self._memory_dependencies[expanded_range] = op
+
+        # Now that ranges have merged, we may need to merge again
+        return self._check_and_merge_overlapping_range(
+            expanded_range[0], expanded_range[1] - expanded_range[0], op
+        )
+
+    def read_into(self, src: RDMABuffer, dst: LocalMemory | List[LocalMemory]) -> Self:
+        """
+        Read from the src RDMA buffer into dst memory.
+
+        Args:
+            src: Source RDMA buffer to read from
+            dst: Destination local memory to read into.
+                If dst is a list, it is the concatenation of the data in the list.
+        """
+        # Raise NotImplementedError for lists to simplify logic
+        if isinstance(dst, list):
+            raise NotImplementedError("List destinations not yet supported")
+
+        addr, size = _get_addr_and_size(dst)
+
+        if size < src.size():
+            raise ValueError(
+                f"dst memory size ({size}) must be >= src buffer size ({src.size()})"
+            )
+
+        self._check_and_merge_overlapping_range(addr, size, self.RDMAOp.READ_INTO)
+
+        self._instructs.append((self.RDMAOp.READ_INTO, src, dst))
+
+        return self
+
+    def write_from(self, src: RDMABuffer, dst: LocalMemory | List[LocalMemory]) -> Self:
+        """
+        Write from dst memory to the src RDMA buffer.
+
+        Args:
+            src: Destination RDMA buffer to write to
+            dst: Source local memory to write from.
+                If dst is a list, it is the concatenation of the data in the list.
+        """
+        # Raise NotImplementedError for lists to simplify logic
+        if isinstance(dst, list):
+            raise NotImplementedError("List sources not yet supported")
+
+        addr, size = _get_addr_and_size(dst)
+
+        if size > src.size():
+            raise ValueError(
+                f"Local memory size ({size}) must be <= src buffer size ({src.size()})"
+            )
+
+        self._check_and_merge_overlapping_range(addr, size, self.RDMAOp.WRITE_FROM)
+
+        self._instructs.append((self.RDMAOp.WRITE_FROM, src, dst))
+
+        return self
+
+    def fetch_add(self, src: RDMABuffer, dst: LocalMemory, add: int) -> Self:
+        """
+        Perform an atomic fetch-and-add operation on the src RDMA buffer.
+
+        Args:
+            src: src RDMA buffer to perform the operation on
+            dst: Local memory to store the original value
+            add: Value to add to the src buffer
+
+        Atomically:
+            *dst = *src
+            *src = *src + add
+
+        Note: src/dst are 8 bytes
+        """
+        raise NotImplementedError("Not yet supported")
+
+    def compare_and_swap(
+        self, src: RDMABuffer, dst: LocalMemory, compare: int, swap: int
+    ) -> Self:
+        """
+        Perform an atomic compare-and-swap operation on the src RDMA buffer.
+
+        Args:
+            src: src RDMA buffer to perform the operation on
+            dst: Local memory to store the original value
+            compare: Value to compare against
+            swap: Value to swap in if the comparison succeeds
+
+        Atomically:
+            *dst = *src;
+            if (*src == compare) {
+                *src = swap
+            }
+
+        Note: src/dst are 8 bytes
+        """
+        raise NotImplementedError("Not yet supported")
+
+    def submit(self) -> Future[None]:
+        """
+        Schedules the work (can be called multiple times to schedule the same
+        work more than once). The Future completes when all the work is done.
+
+        Executes futures for each src actor independently and concurrently for
+        optimal performance.
+        """
+
+        async def submit_all_work() -> None:
+            if not self._instructs:
+                return
+
+            work = defaultdict(list)
+
+            # Group operations by owner for concurrent execution per owner
+            for op, src, dst in self._instructs:
+                if op == self.RDMAOp.READ_INTO:
+                    fut = src.read_into(dst)
+                elif op == self.RDMAOp.WRITE_FROM:
+                    fut = src.write_from(dst)
+                else:
+                    raise NotImplementedError(f"Unknown RDMA operation: {op}")
+                work[src.owner].append(fut)
+
+            # Create a list of tasks, one per owner, that wait for all of that
+            # owner's futures sequentially
+            owner_tasks = []
+
+            for _, futures in work.items():
+                # Create a coroutine that processes all futures for a QP sequentially
+                async def process_owner_futures(owner_futures_list=futures):
+                    """Process all futures for a single QP sequentially"""
+                    for future in owner_futures_list:
+                        await future
+
+                # Convert to PythonTask for Monarch's native concurrency
+                owner_task = PythonTask.from_coroutine(process_owner_futures())
+                owner_tasks.append(owner_task)
+
+            # Spawn all owner tasks concurrently and collect their shared handles
+            shared_tasks = [task.spawn() for task in owner_tasks]
+
+            # Wait for all owner tasks to complete concurrently
+            for shared_task in shared_tasks:
+                await shared_task
+
+        return Future(coro=submit_all_work())
```
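To make the range-conflict guard in `_check_and_merge_overlapping_range` concrete, here is an illustrative sketch of how queued operations interact when their local memory ranges overlap. The `buf_a`/`buf_b`/`buf_c` handles are hypothetical `RDMABuffer`s, with sizes assumed to satisfy each operation's size check (`buf_a` ≤ 16 bytes, `buf_b` ≤ 8 bytes, `buf_c` ≥ 16 bytes):

```python
import torch

from monarch.rdma import RDMAAction

t = torch.zeros(16, dtype=torch.uint8)
whole = t     # covers bytes [addr, addr + 16) of t's storage
head = t[:8]  # covers bytes [addr, addr + 8) -- overlaps `whole`

action = RDMAAction()
action.read_into(buf_a, whole)  # records (addr, addr + 16) -> READ_INTO
action.read_into(buf_b, head)   # two reads overlap: ranges are merged, no error

# A write touching a range that already has queued operations cannot be
# safely reordered against them, so the guard raises ValueError here.
action.write_from(buf_c, whole)
```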

python/monarch/rdma/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -12,6 +12,7 @@
 
 from monarch._src.rdma.rdma import (
     is_rdma_available,
+    RDMAAction,
     RDMABuffer,
     RDMAReadTransferWarning,
     RDMAWriteTransferWarning,
@@ -20,6 +21,7 @@
 __all__ = [
     "is_rdma_available",
     "RDMABuffer",
+    "RDMAAction",
     "RDMAReadTransferWarning",
     "RDMAWriteTransferWarning",
 ]
```
