Skip to content

Commit 9700834

Browse files
pzhan9 authored and meta-codesync[bot] committed
Use cx: Instance<()> to replace mailbox: Mailbox in CommBackend (#1374)
Summary: Pull Request resolved: #1374. This diff is part of the effort to add sequencing logic to the sender actor. See D83371710 for details. This diff specifically changes `CommBackend`'s `mailbox: Mailbox` field to `cx: Instance<()>`, so it can later be plumbed to the `context::Actor` changes. Reviewed By: pablorfb-meta Differential Revision: D83530618 fbshipit-source-id: 716133f4c63cc51cea247b204debc82cac467539
1 parent e15c632 commit 9700834

File tree

3 files changed

+30
-26
lines changed

3 files changed

+30
-26
lines changed

monarch_tensor_worker/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ async-trait = "0.1.86"
1313
bincode = "1.3.3"
1414
clap = { version = "4.5.42", features = ["derive", "env", "string", "unicode", "wrap_help"] }
1515
cxx = "1.0.119"
16+
derivative = "2.2"
1617
derive_more = { version = "1.0.0", features = ["full"] }
1718
futures = { version = "0.3.31", features = ["async-await", "compat"] }
1819
hyperactor = { version = "0.0.0", path = "../hyperactor" }

monarch_tensor_worker/src/comm.rs

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ use anyhow::bail;
1515
use anyhow::ensure;
1616
use async_trait::async_trait;
1717
use cxx::CxxVector;
18+
use derivative::Derivative;
1819
use hyperactor::Actor;
1920
use hyperactor::HandleClient;
2021
use hyperactor::Handler;
22+
use hyperactor::Instance;
2123
use hyperactor::Named;
2224
use hyperactor::actor::ActorHandle;
2325
use hyperactor::forward;
24-
use hyperactor::mailbox::Mailbox;
2526
use hyperactor::mailbox::OncePortHandle;
2627
use hyperactor::mailbox::OncePortReceiver;
2728
use parking_lot::Mutex;
@@ -463,10 +464,12 @@ impl Work for CommWork {
463464
}
464465
}
465466

466-
#[derive(Debug)]
467+
#[derive(Derivative)]
468+
#[derivative(Debug)]
467469
pub struct CommBackend {
470+
#[derivative(Debug = "ignore")]
471+
instance: Instance<()>, // The actor that represents this object.
468472
comm: Arc<ActorHandle<NcclCommActor>>,
469-
mailbox: Mailbox,
470473
rank: usize,
471474
// Size of group. This is less than or equal to world_size.
472475
group_size: usize,
@@ -477,8 +480,8 @@ pub struct CommBackend {
477480

478481
impl CommBackend {
479482
pub fn new(
483+
instance: Instance<()>,
480484
comm: Arc<ActorHandle<NcclCommActor>>,
481-
mailbox: Mailbox,
482485
rank: usize,
483486
group_size: usize,
484487
world_size: usize,
@@ -488,8 +491,8 @@ impl CommBackend {
488491
"Group must be smaller or equal to the world size"
489492
);
490493
Self {
494+
instance,
491495
comm,
492-
mailbox,
493496
rank,
494497
group_size,
495498
world_size,
@@ -570,7 +573,7 @@ impl Backend for CommBackend {
570573
let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() });
571574

572575
// Call into `NcclCommActor`.
573-
let (tx, rx) = self.mailbox.open_once_port();
576+
let (tx, rx) = self.instance.open_once_port();
574577
self.comm.send(CommMessage::AllReduce(
575578
cell.clone(),
576579
convert_reduce_op(opts.reduce_op)?,
@@ -603,7 +606,7 @@ impl Backend for CommBackend {
603606
}
604607

605608
// Call into `NcclCommActor`.
606-
let (tx, rx) = self.mailbox.open_once_port();
609+
let (tx, rx) = self.instance.open_once_port();
607610
// This is not implemented in this function because the broadcasts we need
608611
// to create will change their behavior based on rank.
609612
self.comm.send(CommMessage::AllGather(
@@ -631,7 +634,7 @@ impl Backend for CommBackend {
631634
let input_cell = TensorCell::new(unsafe { input.clone_unsafe() });
632635

633636
// Call into `NcclCommActor`.
634-
let (tx, rx) = self.mailbox.open_once_port();
637+
let (tx, rx) = self.instance.open_once_port();
635638
self.comm.send(CommMessage::AllGatherIntoTensor(
636639
output_cell.clone(),
637640
input_cell.clone(),
@@ -645,7 +648,7 @@ impl Backend for CommBackend {
645648

646649
async fn barrier(&self, _opts: BarrierOptions) -> Result<Box<dyn Work<Error = anyhow::Error>>> {
647650
// Call into `NcclCommActor`.
648-
let (tx, rx) = self.mailbox.open_once_port();
651+
let (tx, rx) = self.instance.open_once_port();
649652
self.comm
650653
// There's no native barrier op in nccl, so impl via all-reduce.
651654
.send(CommMessage::Barrier(Stream::get_current_stream(), tx))?;
@@ -663,7 +666,7 @@ impl Backend for CommBackend {
663666
let input_cell = TensorCell::new(unsafe { input.clone_unsafe() });
664667

665668
// Call into `NcclCommActor`.
666-
let (tx, rx) = self.mailbox.open_once_port();
669+
let (tx, rx) = self.instance.open_once_port();
667670
self.comm.send(CommMessage::Reduce(
668671
input_cell.clone(),
669672
convert_reduce_op(opts.reduce_op)?,
@@ -697,7 +700,7 @@ impl Backend for CommBackend {
697700
}
698701

699702
// Call into `NcclCommActor`.
700-
let (tx, rx) = self.mailbox.open_once_port();
703+
let (tx, rx) = self.instance.open_once_port();
701704
self.comm.send(CommMessage::ReduceScatterTensor(
702705
output_cell.clone(),
703706
input_cell.clone(),
@@ -726,7 +729,7 @@ impl Backend for CommBackend {
726729
let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() });
727730

728731
// Call into `NcclCommActor`.
729-
let (tx, rx) = self.mailbox.open_once_port();
732+
let (tx, rx) = self.instance.open_once_port();
730733
self.comm.send(CommMessage::Send(
731734
cell.clone(),
732735
dst_rank,
@@ -752,7 +755,7 @@ impl Backend for CommBackend {
752755
let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() });
753756

754757
// Call into `NcclCommActor`.
755-
let (tx, rx) = self.mailbox.open_once_port();
758+
let (tx, rx) = self.instance.open_once_port();
756759
self.comm.send(CommMessage::Recv(
757760
cell.clone(),
758761
src_rank,
@@ -782,7 +785,7 @@ impl Backend for CommBackend {
782785
assert_type_and_sizes_match(outputs.as_slice(), input.scalar_type(), &input.sizes())?;
783786

784787
// Call into `NcclCommActor`.
785-
let (tx, rx) = self.mailbox.open_once_port();
788+
let (tx, rx) = self.instance.open_once_port();
786789
let mut messages = vec![];
787790
// All ranks other than the root Recv, and the root rank calls Send.
788791
if self.rank == root {
@@ -795,7 +798,7 @@ impl Backend for CommBackend {
795798
}
796799
for (r, output) in output_cells.clone().into_iter().enumerate() {
797800
if r != root {
798-
let (tx_recv, _rx_recv) = self.mailbox.open_once_port();
801+
let (tx_recv, _rx_recv) = self.instance.open_once_port();
799802
messages.push(CommMessage::Recv(
800803
output,
801804
r as i32,
@@ -814,7 +817,7 @@ impl Backend for CommBackend {
814817
output_cells.len()
815818
));
816819
}
817-
let (tx_send, _rx_send) = self.mailbox.open_once_port();
820+
let (tx_send, _rx_send) = self.instance.open_once_port();
818821
messages.push(CommMessage::Send(
819822
input_cell.clone(),
820823
root as i32,
@@ -853,7 +856,7 @@ impl Backend for CommBackend {
853856
assert_type_and_sizes_match(inputs.as_slice(), output.scalar_type(), &output.sizes())?;
854857

855858
// Call into `NcclCommActor`.
856-
let (tx, rx) = self.mailbox.open_once_port();
859+
let (tx, rx) = self.instance.open_once_port();
857860
let mut messages = vec![];
858861
// Implementation is the inverse set of messages from gather, where all ranks
859862
// other than the root Send, and the root rank calls Recv.
@@ -867,7 +870,7 @@ impl Backend for CommBackend {
867870
}
868871
for (r, input) in input_cells.clone().into_iter().enumerate() {
869872
if r != root {
870-
let (tx_send, _rx_send) = self.mailbox.open_once_port();
873+
let (tx_send, _rx_send) = self.instance.open_once_port();
871874
messages.push(CommMessage::Send(
872875
input,
873876
r as i32,
@@ -886,7 +889,7 @@ impl Backend for CommBackend {
886889
input_cells.len()
887890
));
888891
}
889-
let (tx_recv, _rx_recv) = self.mailbox.open_once_port();
892+
let (tx_recv, _rx_recv) = self.instance.open_once_port();
890893
messages.push(CommMessage::Recv(
891894
output_cell.clone(),
892895
root as i32,
@@ -916,7 +919,7 @@ impl Backend for CommBackend {
916919
let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() });
917920

918921
// Call into `NcclCommActor`.
919-
let (tx, rx) = self.mailbox.open_once_port();
922+
let (tx, rx) = self.instance.open_once_port();
920923
self.comm.send(CommMessage::Broadcast(
921924
cell.clone(),
922925
opts.root_rank,
@@ -940,7 +943,7 @@ impl Backend for CommBackend {
940943
let input_cell = TensorCell::new(unsafe { input_buffer.clone_unsafe() });
941944

942945
// Call into `NcclCommActor`.
943-
let (tx, rx) = self.mailbox.open_once_port();
946+
let (tx, rx) = self.instance.open_once_port();
944947
self.comm.send(CommMessage::AllToAllSingle(
945948
output_cell.clone(),
946949
input_cell.clone(),
@@ -995,8 +998,8 @@ impl Backend for CommBackend {
995998
for r in 0..output_tensors.len() {
996999
let output_cell = &output_cells[r];
9971000
let input_cell = &input_cells[r];
998-
let (tx_send, _rx_send) = self.mailbox.open_once_port();
999-
let (tx_recv, _rx_recv) = self.mailbox.open_once_port();
1001+
let (tx_send, _rx_send) = self.instance.open_once_port();
1002+
let (tx_recv, _rx_recv) = self.instance.open_once_port();
10001003
messages.push(CommMessage::Send(
10011004
input_cell.clone(),
10021005
r as i32,
@@ -1010,7 +1013,7 @@ impl Backend for CommBackend {
10101013
tx_recv,
10111014
));
10121015
}
1013-
let (tx, rx) = self.mailbox.open_once_port();
1016+
let (tx, rx) = self.instance.open_once_port();
10141017
self.comm.send(CommMessage::Group(messages, stream, tx))?;
10151018
let mut all_cells = vec![];
10161019
all_cells.extend(output_cells);

monarch_tensor_worker/src/stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ use hyperactor::PortHandle;
3232
use hyperactor::actor::ActorHandle;
3333
use hyperactor::data::Serialized;
3434
use hyperactor::forward;
35-
use hyperactor::mailbox::Mailbox;
3635
use hyperactor::mailbox::OncePortHandle;
3736
use hyperactor::mailbox::PortReceiver;
3837
use hyperactor::proc::Proc;
@@ -829,9 +828,10 @@ impl StreamActor {
829828
// it to create a new torch group.
830829
let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
831830
let group_size = ranks.len();
831+
let (child_instance, _) = cx.child()?;
832832
let backend = CommBackend::new(
833+
child_instance,
833834
comm,
834-
Mailbox::new_detached(cx.self_id().clone()),
835835
self.rank,
836836
group_size,
837837
self.world_size,

0 commit comments

Comments
 (0)