@@ -15,13 +15,14 @@ use anyhow::bail;
15
15
use anyhow:: ensure;
16
16
use async_trait:: async_trait;
17
17
use cxx:: CxxVector ;
18
+ use derivative:: Derivative ;
18
19
use hyperactor:: Actor ;
19
20
use hyperactor:: HandleClient ;
20
21
use hyperactor:: Handler ;
22
+ use hyperactor:: Instance ;
21
23
use hyperactor:: Named ;
22
24
use hyperactor:: actor:: ActorHandle ;
23
25
use hyperactor:: forward;
24
- use hyperactor:: mailbox:: Mailbox ;
25
26
use hyperactor:: mailbox:: OncePortHandle ;
26
27
use hyperactor:: mailbox:: OncePortReceiver ;
27
28
use parking_lot:: Mutex ;
@@ -463,10 +464,12 @@ impl Work for CommWork {
463
464
}
464
465
}
465
466
466
- #[ derive( Debug ) ]
467
+ #[ derive( Derivative ) ]
468
+ #[ derivative( Debug ) ]
467
469
pub struct CommBackend {
470
+ #[ derivative( Debug = "ignore" ) ]
471
+ instance : Instance < ( ) > , // The actor that represents this object.
468
472
comm : Arc < ActorHandle < NcclCommActor > > ,
469
- mailbox : Mailbox ,
470
473
rank : usize ,
471
474
// Size of group. This is less than or equal to world_size.
472
475
group_size : usize ,
@@ -477,8 +480,8 @@ pub struct CommBackend {
477
480
478
481
impl CommBackend {
479
482
pub fn new (
483
+ instance : Instance < ( ) > ,
480
484
comm : Arc < ActorHandle < NcclCommActor > > ,
481
- mailbox : Mailbox ,
482
485
rank : usize ,
483
486
group_size : usize ,
484
487
world_size : usize ,
@@ -488,8 +491,8 @@ impl CommBackend {
488
491
"Group must be smaller or equal to the world size"
489
492
) ;
490
493
Self {
494
+ instance,
491
495
comm,
492
- mailbox,
493
496
rank,
494
497
group_size,
495
498
world_size,
@@ -570,7 +573,7 @@ impl Backend for CommBackend {
570
573
let cell = TensorCell :: new ( unsafe { as_singleton ( tensors. as_slice ( ) ) ?. clone_unsafe ( ) } ) ;
571
574
572
575
// Call into `NcclCommActor`.
573
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
576
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
574
577
self . comm . send ( CommMessage :: AllReduce (
575
578
cell. clone ( ) ,
576
579
convert_reduce_op ( opts. reduce_op ) ?,
@@ -603,7 +606,7 @@ impl Backend for CommBackend {
603
606
}
604
607
605
608
// Call into `NcclCommActor`.
606
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
609
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
607
610
// This is not implemented in this function because the broadcasts we need
608
611
// to create will change their behavior based on rank.
609
612
self . comm . send ( CommMessage :: AllGather (
@@ -631,7 +634,7 @@ impl Backend for CommBackend {
631
634
let input_cell = TensorCell :: new ( unsafe { input. clone_unsafe ( ) } ) ;
632
635
633
636
// Call into `NcclCommActor`.
634
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
637
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
635
638
self . comm . send ( CommMessage :: AllGatherIntoTensor (
636
639
output_cell. clone ( ) ,
637
640
input_cell. clone ( ) ,
@@ -645,7 +648,7 @@ impl Backend for CommBackend {
645
648
646
649
async fn barrier ( & self , _opts : BarrierOptions ) -> Result < Box < dyn Work < Error = anyhow:: Error > > > {
647
650
// Call into `NcclCommActor`.
648
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
651
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
649
652
self . comm
650
653
// There's no native barrier op in nccl, so impl via all-reduce.
651
654
. send ( CommMessage :: Barrier ( Stream :: get_current_stream ( ) , tx) ) ?;
@@ -663,7 +666,7 @@ impl Backend for CommBackend {
663
666
let input_cell = TensorCell :: new ( unsafe { input. clone_unsafe ( ) } ) ;
664
667
665
668
// Call into `NcclCommActor`.
666
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
669
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
667
670
self . comm . send ( CommMessage :: Reduce (
668
671
input_cell. clone ( ) ,
669
672
convert_reduce_op ( opts. reduce_op ) ?,
@@ -697,7 +700,7 @@ impl Backend for CommBackend {
697
700
}
698
701
699
702
// Call into `NcclCommActor`.
700
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
703
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
701
704
self . comm . send ( CommMessage :: ReduceScatterTensor (
702
705
output_cell. clone ( ) ,
703
706
input_cell. clone ( ) ,
@@ -726,7 +729,7 @@ impl Backend for CommBackend {
726
729
let cell = TensorCell :: new ( unsafe { as_singleton ( tensors. as_slice ( ) ) ?. clone_unsafe ( ) } ) ;
727
730
728
731
// Call into `NcclCommActor`.
729
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
732
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
730
733
self . comm . send ( CommMessage :: Send (
731
734
cell. clone ( ) ,
732
735
dst_rank,
@@ -752,7 +755,7 @@ impl Backend for CommBackend {
752
755
let cell = TensorCell :: new ( unsafe { as_singleton ( tensors. as_slice ( ) ) ?. clone_unsafe ( ) } ) ;
753
756
754
757
// Call into `NcclCommActor`.
755
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
758
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
756
759
self . comm . send ( CommMessage :: Recv (
757
760
cell. clone ( ) ,
758
761
src_rank,
@@ -782,7 +785,7 @@ impl Backend for CommBackend {
782
785
assert_type_and_sizes_match ( outputs. as_slice ( ) , input. scalar_type ( ) , & input. sizes ( ) ) ?;
783
786
784
787
// Call into `NcclCommActor`.
785
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
788
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
786
789
let mut messages = vec ! [ ] ;
787
790
// All ranks other than the root Recv, and the root rank calls Send.
788
791
if self . rank == root {
@@ -795,7 +798,7 @@ impl Backend for CommBackend {
795
798
}
796
799
for ( r, output) in output_cells. clone ( ) . into_iter ( ) . enumerate ( ) {
797
800
if r != root {
798
- let ( tx_recv, _rx_recv) = self . mailbox . open_once_port ( ) ;
801
+ let ( tx_recv, _rx_recv) = self . instance . open_once_port ( ) ;
799
802
messages. push ( CommMessage :: Recv (
800
803
output,
801
804
r as i32 ,
@@ -814,7 +817,7 @@ impl Backend for CommBackend {
814
817
output_cells. len( )
815
818
) ) ;
816
819
}
817
- let ( tx_send, _rx_send) = self . mailbox . open_once_port ( ) ;
820
+ let ( tx_send, _rx_send) = self . instance . open_once_port ( ) ;
818
821
messages. push ( CommMessage :: Send (
819
822
input_cell. clone ( ) ,
820
823
root as i32 ,
@@ -853,7 +856,7 @@ impl Backend for CommBackend {
853
856
assert_type_and_sizes_match ( inputs. as_slice ( ) , output. scalar_type ( ) , & output. sizes ( ) ) ?;
854
857
855
858
// Call into `NcclCommActor`.
856
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
859
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
857
860
let mut messages = vec ! [ ] ;
858
861
// Implementation is the inverse set of messages from gather, where all ranks
859
862
// other than the root Send, and the root rank calls Recv.
@@ -867,7 +870,7 @@ impl Backend for CommBackend {
867
870
}
868
871
for ( r, input) in input_cells. clone ( ) . into_iter ( ) . enumerate ( ) {
869
872
if r != root {
870
- let ( tx_send, _rx_send) = self . mailbox . open_once_port ( ) ;
873
+ let ( tx_send, _rx_send) = self . instance . open_once_port ( ) ;
871
874
messages. push ( CommMessage :: Send (
872
875
input,
873
876
r as i32 ,
@@ -886,7 +889,7 @@ impl Backend for CommBackend {
886
889
input_cells. len( )
887
890
) ) ;
888
891
}
889
- let ( tx_recv, _rx_recv) = self . mailbox . open_once_port ( ) ;
892
+ let ( tx_recv, _rx_recv) = self . instance . open_once_port ( ) ;
890
893
messages. push ( CommMessage :: Recv (
891
894
output_cell. clone ( ) ,
892
895
root as i32 ,
@@ -916,7 +919,7 @@ impl Backend for CommBackend {
916
919
let cell = TensorCell :: new ( unsafe { as_singleton ( tensors. as_slice ( ) ) ?. clone_unsafe ( ) } ) ;
917
920
918
921
// Call into `NcclCommActor`.
919
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
922
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
920
923
self . comm . send ( CommMessage :: Broadcast (
921
924
cell. clone ( ) ,
922
925
opts. root_rank ,
@@ -940,7 +943,7 @@ impl Backend for CommBackend {
940
943
let input_cell = TensorCell :: new ( unsafe { input_buffer. clone_unsafe ( ) } ) ;
941
944
942
945
// Call into `NcclCommActor`.
943
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
946
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
944
947
self . comm . send ( CommMessage :: AllToAllSingle (
945
948
output_cell. clone ( ) ,
946
949
input_cell. clone ( ) ,
@@ -995,8 +998,8 @@ impl Backend for CommBackend {
995
998
for r in 0 ..output_tensors. len ( ) {
996
999
let output_cell = & output_cells[ r] ;
997
1000
let input_cell = & input_cells[ r] ;
998
- let ( tx_send, _rx_send) = self . mailbox . open_once_port ( ) ;
999
- let ( tx_recv, _rx_recv) = self . mailbox . open_once_port ( ) ;
1001
+ let ( tx_send, _rx_send) = self . instance . open_once_port ( ) ;
1002
+ let ( tx_recv, _rx_recv) = self . instance . open_once_port ( ) ;
1000
1003
messages. push ( CommMessage :: Send (
1001
1004
input_cell. clone ( ) ,
1002
1005
r as i32 ,
@@ -1010,7 +1013,7 @@ impl Backend for CommBackend {
1010
1013
tx_recv,
1011
1014
) ) ;
1012
1015
}
1013
- let ( tx, rx) = self . mailbox . open_once_port ( ) ;
1016
+ let ( tx, rx) = self . instance . open_once_port ( ) ;
1014
1017
self . comm . send ( CommMessage :: Group ( messages, stream, tx) ) ?;
1015
1018
let mut all_cells = vec ! [ ] ;
1016
1019
all_cells. extend ( output_cells) ;
0 commit comments