9
9
import functools
10
10
import logging
11
11
import warnings
12
- from typing import cast , Optional
12
+ from collections import defaultdict
13
+ from typing import cast , List , Optional , Tuple
13
14
14
15
import torch
15
16
from monarch ._rust_bindings .monarch_hyperactor .pytokio import PythonTask , Shared
17
+ from typing_extensions import Self
16
18
17
19
try :
18
20
from monarch ._rust_bindings .rdma import _RdmaBuffer , _RdmaManager
19
21
except ImportError as e :
20
22
logging .error ("RDMA is not available: {}" .format (e ))
21
23
raise e
24
+ from enum import Enum
22
25
from typing import Dict
23
26
24
27
from monarch ._src .actor .actor_mesh import Actor , context
@@ -335,7 +338,7 @@ async def write_from_nonblocking() -> None:
335
338
336
339
def drop (self ) -> Future [None ]:
337
340
"""
338
- Release the handle on the memory that the remote holds to this memory.
341
+ Release the handle on the memory that the src holds to this memory.
339
342
"""
340
343
local_proc_id = context ().actor_instance .proc_id
341
344
client = context ().actor_instance
@@ -351,10 +354,221 @@ async def drop_nonblocking() -> None:
351
354
return Future (coro = drop_nonblocking ())
352
355
353
356
@property
def owner(self) -> str:
    """
    The owner reference (str).

    Returns the owning actor id as reported by the underlying
    `_RdmaBuffer` (presumably the actor holding the remote memory —
    confirm against the Rust binding).
    """
    buffer = self._buffer
    return buffer.owner_actor_id()
362
# Local memory that can act as the host-side endpoint of an RDMA
# transfer: either a torch.Tensor or a memoryview.
LocalMemory = torch.Tensor | memoryview
367
class RDMAAction:
    """
    Schedule a bunch of actions at once. This provides an opportunity to
    optimize bulk RDMA transactions without exposing complexity to users.
    """

    class RDMAOp(Enum):
        """Enumeration of RDMA operation types."""

        READ_INTO = "read_into"
        WRITE_FROM = "write_from"
        FETCH_ADD = "fetch_add"
        COMPARE_AND_SWAP = "compare_and_swap"

    def __init__(self) -> None:
        # Ordered list of (op, remote buffer, local memory) scheduled so far;
        # replayed by submit().
        self._instructs: List[Tuple["RDMAAction.RDMAOp", "RDMABuffer", "LocalMemory"]] = []
        # Disjoint local-memory ranges (start, end) -> op touching them, used
        # to detect unsafe overlapping accesses within a single action.
        self._memory_dependencies: Dict[Tuple[int, int], "RDMAAction.RDMAOp"] = {}

    def _check_and_merge_overlapping_range(
        self, addr: int, size: int, op: "RDMAAction.RDMAOp"
    ) -> None:
        """
        Record the local-memory range [addr, addr + size) for ``op``, merging
        it with every already-tracked range it overlaps.

        Updates ``self._memory_dependencies`` in place; returns None.

        Raises:
            ValueError: if the new range overlaps an existing one and either
                operation is WRITE_FROM (overlapping local writes are unsafe).
        """
        new_start, new_end = addr, addr + size

        # Absorb every overlapping tracked range into [new_start, new_end).
        # Each merge can expose a further overlap, so repeat until none remain.
        # NOTE(fix): the previous recursive version early-returned when the new
        # range fully covered an existing one, leaving the stale smaller range
        # tracked and never registering the larger one; this loop always
        # removes absorbed ranges and inserts the merged range exactly once.
        while True:
            overlapping = None
            for existing_start, existing_end in self._memory_dependencies:
                # Half-open ranges overlap unless one ends at/before the other starts.
                if not (new_end <= existing_start or existing_end <= new_start):
                    overlapping = (existing_start, existing_end)
                    break

            if overlapping is None:
                break

            existing_op = self._memory_dependencies[overlapping]

            # Merging ops is only safe if neither is write_from at the moment.
            if existing_op == self.RDMAOp.WRITE_FROM or op == self.RDMAOp.WRITE_FROM:
                raise ValueError(
                    f"Same data range already has a write_from within RDMAAction: {existing_op} vs {op}"
                )

            # Remove the absorbed range and grow the new one to cover it.
            del self._memory_dependencies[overlapping]
            new_start = min(overlapping[0], new_start)
            new_end = max(overlapping[1], new_end)

        self._memory_dependencies[(new_start, new_end)] = op

    def read_into(self, src: "RDMABuffer", dst: "LocalMemory | List[LocalMemory]") -> "Self":
        """
        Read from src RDMA buffer into dst memory.

        Args:
            src: Source RDMA buffer to read from
            dst: Destination local memory to read into
                If dst is a list, it is the concatenation of the data in the list

        Returns:
            self, so scheduling calls can be chained before submit().

        Raises:
            NotImplementedError: if dst is a list (not yet supported).
            ValueError: if dst is smaller than src, or dst's memory range
                overlaps a pending write_from.
        """
        # Throw NotImplementedError for lists to simplify logic
        if isinstance(dst, list):
            raise NotImplementedError("List destinations not yet supported")

        addr, size = _get_addr_and_size(dst)

        if size < src.size():
            raise ValueError(
                f"dst memory size ({size}) must be >= src buffer size ({src.size()})"
            )

        self._check_and_merge_overlapping_range(addr, size, self.RDMAOp.READ_INTO)
        self._instructs.append((self.RDMAOp.READ_INTO, src, dst))
        return self

    def write_from(self, src: "RDMABuffer", dst: "LocalMemory | List[LocalMemory]") -> "Self":
        """
        Write from dst memory to src RDMA buffer.

        Args:
            src: Destination RDMA buffer to write to
            dst: Source local memory to write from
                If local is a list, it is the concatenation of the data in the list

        Returns:
            self, so scheduling calls can be chained before submit().

        Raises:
            NotImplementedError: if dst is a list (not yet supported).
            ValueError: if dst is larger than src, or dst's memory range
                overlaps any other pending operation.
        """
        # Throw NotImplementedError for lists to simplify logic
        if isinstance(dst, list):
            raise NotImplementedError("List sources not yet supported")

        addr, size = _get_addr_and_size(dst)

        if size > src.size():
            raise ValueError(
                f"Local memory size ({size}) must be <= src buffer size ({src.size()})"
            )

        self._check_and_merge_overlapping_range(addr, size, self.RDMAOp.WRITE_FROM)
        self._instructs.append((self.RDMAOp.WRITE_FROM, src, dst))
        return self

    def fetch_add(self, src: "RDMABuffer", dst: "LocalMemory", add: int) -> "Self":
        """
        Perform atomic fetch-and-add operation on src RDMA buffer.

        Args:
            src: src RDMA buffer to perform operation on
            dst: Local memory to store the original value
            add: Value to add to the src buffer

        Atomically:
            *dst = *src
            *src = *src + add

        Note: src/dst are 8 bytes
        """
        raise NotImplementedError("Not yet supported")

    def compare_and_swap(
        self, src: "RDMABuffer", dst: "LocalMemory", compare: int, swap: int
    ) -> "Self":
        """
        Perform atomic compare-and-swap operation on src RDMA buffer.

        Args:
            src: src RDMA buffer to perform operation on
            dst: Local memory to store the original value
            compare: Value to compare against
            swap: Value to swap in if comparison succeeds

        Atomically:
            *dst = *src;
            if (*src == compare) {
                *src = swap
            }

        Note: src/dst are 8 bytes
        """
        raise NotImplementedError("Not yet supported")

    def submit(self) -> "Future[None]":
        """
        Schedules the work (can be called multiple times to schedule the same work more than once).
        Future completes when all the work is done.

        Executes futures for each src actor independently and concurrently for optimal performance.
        """

        async def submit_all_work() -> None:
            if not self._instructs:
                return

            # Group operations by owner: each owner's transfers are awaited
            # sequentially, while distinct owners proceed concurrently.
            work = defaultdict(list)
            for op, src, dst in self._instructs:
                if op == self.RDMAOp.READ_INTO:
                    fut = src.read_into(dst)
                elif op == self.RDMAOp.WRITE_FROM:
                    fut = src.write_from(dst)
                else:
                    raise NotImplementedError(f"Unknown RDMA operation: {op}")
                work[src.owner].append(fut)

            # One task per owner that drains that owner's futures sequentially.
            owner_tasks = []
            for futures in work.values():
                # Bind the list as a default argument to avoid the classic
                # late-binding-closure bug.
                async def process_owner_futures(owner_futures_list=futures):
                    """Process all futures for a single qp sequentially"""
                    for future in owner_futures_list:
                        await future

                # Convert to PythonTask for Monarch's native concurrency
                owner_tasks.append(PythonTask.from_coroutine(process_owner_futures()))

            # Spawn all owner tasks first so they run concurrently...
            shared_tasks = [task.spawn() for task in owner_tasks]

            # ...then wait for every owner task to complete.
            for shared_task in shared_tasks:
                await shared_task

        return Future(coro=submit_all_work())
0 commit comments