Skip to content

Commit 2b9ef30

Browse files
zdevito authored and facebook-github-bot committed
CanSend knows the sender, ActorMesh does not (meta-pytorch#808)
Summary: Pull Request resolved: meta-pytorch#808. Previously, the CanSend capability did not know which actor was sending, whereas casting to an actor mesh required the sending actor. This change moves actor_id into CanSend, since all of our implementations of CanSend already know the actor_id they are sending from. This change also lets us take an explicit mailbox for sending from PythonActorMesh. This is important for de-futuring actor spawns, because the mailbox we are sending from is now the client mailbox, not the proc mesh mailbox; we have to be consistent about using it, otherwise messages will go out of order. This also refactors out of the ActorMeshProtocol all the functionality, such as shape/proc_mesh, that does not vary across the implementations. ghstack-source-id: 302126566. exported-using-ghexport. Reviewed By: mariusae. Differential Revision: D79925866. fbshipit-source-id: 73b5ed81d01ac2c6d0fcc0ac776f301558cee836
1 parent 5c0302d commit 2b9ef30

File tree

22 files changed

+165
-124
lines changed

22 files changed

+165
-124
lines changed

hyperactor/src/cap.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ impl<T: sealed::CanResolveActorRef> CanResolveActorRef for T {}
3737
pub(crate) mod sealed {
3838
use async_trait::async_trait;
3939

40+
use crate::ActorId;
4041
use crate::ActorRef;
4142
use crate::PortId;
4243
use crate::accum::ReducerSpec;
@@ -49,6 +50,7 @@ pub(crate) mod sealed {
4950

5051
pub trait CanSend: Send + Sync {
5152
fn post(&self, dest: PortId, headers: Attrs, data: Serialized);
53+
fn actor_id(&self) -> &ActorId;
5254
}
5355

5456
pub trait CanOpenPort: Send + Sync {

hyperactor/src/mailbox.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,11 +1439,17 @@ impl cap::sealed::CanSend for Mailbox {
14391439
let envelope = MessageEnvelope::new(self.actor_id().clone(), dest, data, headers);
14401440
MailboxSender::post(self, envelope, return_handle);
14411441
}
1442+
fn actor_id(&self) -> &ActorId {
1443+
self.actor_id()
1444+
}
14421445
}
14431446
impl cap::sealed::CanSend for &Mailbox {
14441447
fn post(&self, dest: PortId, headers: Attrs, data: Serialized) {
14451448
cap::sealed::CanSend::post(*self, dest, headers, data)
14461449
}
1450+
fn actor_id(&self) -> &ActorId {
1451+
(**self).actor_id()
1452+
}
14471453
}
14481454

14491455
impl cap::sealed::CanOpenPort for &Mailbox {

hyperactor/src/proc.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,12 +1286,18 @@ impl<A: Actor> cap::sealed::CanSend for Instance<A> {
12861286
let envelope = MessageEnvelope::new(self.self_id().clone(), dest, data, headers);
12871287
self.proc.post(envelope, self.ports.get());
12881288
}
1289+
fn actor_id(&self) -> &ActorId {
1290+
self.self_id()
1291+
}
12891292
}
12901293

12911294
impl<A: Actor> cap::sealed::CanSend for &Instance<A> {
12921295
fn post(&self, dest: PortId, headers: Attrs, data: Serialized) {
12931296
(*self).post(dest, headers, data)
12941297
}
1298+
fn actor_id(&self) -> &ActorId {
1299+
self.self_id()
1300+
}
12951301
}
12961302

12971303
impl<A: Actor> cap::sealed::CanOpenPort for Instance<A> {
@@ -1337,12 +1343,18 @@ impl<A: Actor> cap::sealed::CanSend for Context<'_, A> {
13371343
fn post(&self, dest: PortId, headers: Attrs, data: Serialized) {
13381344
<Instance<A> as cap::sealed::CanSend>::post(self, dest, headers, data)
13391345
}
1346+
fn actor_id(&self) -> &ActorId {
1347+
self.self_id()
1348+
}
13401349
}
13411350

13421351
impl<A: Actor> cap::sealed::CanSend for &Context<'_, A> {
13431352
fn post(&self, dest: PortId, headers: Attrs, data: Serialized) {
13441353
<Instance<A> as cap::sealed::CanSend>::post(self, dest, headers, data)
13451354
}
1355+
fn actor_id(&self) -> &ActorId {
1356+
self.self_id()
1357+
}
13461358
}
13471359

13481360
impl<A: Actor> cap::sealed::CanOpenPort for Context<'_, A> {

hyperactor_mesh/benches/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ fn bench_actor_scaling(c: &mut Criterion) {
6363

6464
actor_mesh
6565
.cast(
66+
client,
6667
all(true_()),
6768
BenchMessage {
6869
step: i as usize,
@@ -154,6 +155,7 @@ fn bench_actor_mesh_message_sizes(c: &mut Criterion) {
154155

155156
actor_mesh
156157
.cast(
158+
client,
157159
all(true_()),
158160
BenchMessage {
159161
step: i as usize,

hyperactor_mesh/examples/dining_philosophers.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ where
204204
if self.is_chopstick_available(chopstick) {
205205
self.grant_chopstick(chopstick, rank);
206206
self.philosophers.cast(
207+
self.philosophers.proc_mesh().client(),
207208
selection_from(self.philosophers.shape(), &[("replica", rank..rank + 1)])?,
208209
PhilosopherMessage::GrantChopstick(chopstick),
209210
)?
@@ -241,6 +242,7 @@ async fn main() -> Result<ExitCode> {
241242
let (dining_message_handle, mut dining_message_rx) = proc_mesh.client().open_port();
242243
actor_mesh
243244
.cast(
245+
proc_mesh.client(),
244246
all(true_()),
245247
PhilosopherMessage::Start(dining_message_handle.bind()),
246248
)

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use std::ops::Deref;
1313

1414
use async_trait::async_trait;
1515
use hyperactor::Actor;
16-
use hyperactor::ActorId;
1716
use hyperactor::ActorRef;
1817
use hyperactor::Bind;
1918
use hyperactor::GangId;
@@ -29,6 +28,7 @@ use hyperactor::actor::RemoteActor;
2928
use hyperactor::attrs::Attrs;
3029
use hyperactor::attrs::declare_attrs;
3130
use hyperactor::cap;
31+
use hyperactor::cap::CanSend;
3232
use hyperactor::mailbox::MailboxSenderError;
3333
use hyperactor::mailbox::PortReceiver;
3434
use hyperactor::message::Castable;
@@ -71,7 +71,6 @@ declare_attrs! {
7171
pub(crate) fn actor_mesh_cast<A, M>(
7272
caps: &impl cap::CanSend,
7373
actor_mesh_id: ActorMeshId,
74-
sender: &ActorId,
7574
comm_actor_ref: &ActorRef<CommActor>,
7675
selection_of_root: Selection,
7776
root_mesh_shape: &Shape,
@@ -89,7 +88,7 @@ where
8988

9089
let message = CastMessageEnvelope::new::<A, M>(
9190
actor_mesh_id.clone(),
92-
sender.clone(),
91+
caps.actor_id().clone(),
9392
cast_mesh_shape.clone(),
9493
message,
9594
)?;
@@ -118,7 +117,6 @@ where
118117
pub(crate) fn cast_to_sliced_mesh<A, M>(
119118
caps: &impl cap::CanSend,
120119
actor_mesh_id: ActorMeshId,
121-
sender: &ActorId,
122120
comm_actor_ref: &ActorRef<CommActor>,
123121
sel_of_sliced: &Selection,
124122
message: M,
@@ -147,7 +145,6 @@ where
147145
actor_mesh_cast::<A, M>(
148146
caps,
149147
actor_mesh_id,
150-
sender,
151148
comm_actor_ref,
152149
sel_of_root,
153150
root_mesh_shape,
@@ -165,20 +162,24 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
165162
/// Cast an `M`-typed message to the ranks selected by `sel` in
166163
/// this ActorMesh.
167164
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
168-
fn cast<M>(&self, selection: Selection, message: M) -> Result<(), CastError>
165+
fn cast<M>(
166+
&self,
167+
sender: &impl CanSend,
168+
selection: Selection,
169+
message: M,
170+
) -> Result<(), CastError>
169171
where
170172
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
171173
M: Castable + RemoteMessage,
172174
{
173175
actor_mesh_cast::<Self::Actor, M>(
174-
self.proc_mesh().client(), // send capability
175-
self.id(), // actor mesh id (destination mesh)
176-
self.proc_mesh().client().actor_id(), // sender
177-
self.proc_mesh().comm_actor(), // comm actor
178-
selection, // the selected actors
179-
self.shape(), // root mesh shape
180-
self.shape(), // cast mesh shape
181-
message, // the message
176+
sender, // send capability
177+
self.id(), // actor mesh id (destination mesh)
178+
self.proc_mesh().comm_actor(), // comm actor
179+
selection, // the selected actors
180+
self.shape(), // root mesh shape
181+
self.shape(), // cast mesh shape
182+
message, // the message
182183
)
183184
}
184185

@@ -412,15 +413,14 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
412413
}
413414

414415
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
415-
fn cast<M>(&self, sel: Selection, message: M) -> Result<(), CastError>
416+
fn cast<M>(&self, sender: &impl CanSend, sel: Selection, message: M) -> Result<(), CastError>
416417
where
417418
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
418419
M: Castable + RemoteMessage,
419420
{
420421
cast_to_sliced_mesh::<A, M>(
421-
/*caps=*/ self.proc_mesh().client(),
422+
/*caps=*/ sender,
422423
/*actor_mesh_id=*/ self.id(),
423-
/*sender=*/ self.proc_mesh().client().actor_id(),
424424
/*comm_actor_ref*/ self.proc_mesh().comm_actor(),
425425
/*sel_of_sliced=*/ &sel,
426426
/*message=*/ message,
@@ -736,7 +736,7 @@ mod tests {
736736
let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
737737
let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
738738
actor_mesh
739-
.cast(sel!(*), Echo("Hello".to_string(), reply_handle.bind()))
739+
.cast(proc_mesh.client(), sel!(*), Echo("Hello".to_string(), reply_handle.bind()))
740740
.unwrap();
741741
for _ in 0..4 {
742742
assert_eq!(&reply_receiver.recv().await.unwrap(), "Hello");
@@ -840,7 +840,7 @@ mod tests {
840840
let dont_simulate_error = true;
841841
let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
842842
actor_mesh
843-
.cast(sel!(*), GetRank(dont_simulate_error, reply_handle.bind()))
843+
.cast(proc_mesh.client(), sel!(*), GetRank(dont_simulate_error, reply_handle.bind()))
844844
.unwrap();
845845
let mut ranks = Ranks::new(actor_mesh.shape().slice().len());
846846
while !ranks.is_full() {
@@ -851,6 +851,7 @@ mod tests {
851851
let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
852852
actor_mesh
853853
.cast(
854+
proc_mesh.client(),
854855
sel_from_shape!(actor_mesh.shape(), replica = 0, host = 0),
855856
GetRank(dont_simulate_error, reply_handle.bind()),
856857
)
@@ -958,7 +959,7 @@ mod tests {
958959
let params = CastTestActorParams{ forward_port: tx.bind() };
959960
let actor_mesh: RootActorMesh<CastTestActor> = proc_mesh.spawn("actor", &params).await.unwrap();
960961

961-
actor_mesh.cast(sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap();
962+
actor_mesh.cast(proc_mesh.client(), sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap();
962963

963964
for _ in 0..num_actors {
964965
assert_eq!(rx.recv().await.unwrap(), CastTestMessage::Forward("abc".to_string()));
@@ -1148,7 +1149,7 @@ mod tests {
11481149
// replying with rank.
11491150
let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
11501151
actor_mesh
1151-
.cast(sel!(*), GetRank(false, reply_handle.bind()))
1152+
.cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
11521153
.unwrap();
11531154
let rank = reply_receiver.recv().await.unwrap();
11541155
assert_eq!(rank, 0);
@@ -1162,7 +1163,7 @@ mod tests {
11621163
// Cast the message.
11631164
let (reply_handle, _) = actor_mesh.open_port();
11641165
actor_mesh
1165-
.cast(sel!(*), GetRank(false, reply_handle.bind()))
1166+
.cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
11661167
.unwrap();
11671168

11681169
// The message will be returned!
@@ -1297,7 +1298,11 @@ mod tests {
12971298
// the message will fail to send.
12981299
assert!(payload.len() > max_frame_length);
12991300
actor_mesh
1300-
.cast(sel!(*), Echo(payload, reply_handle.bind()))
1301+
.cast(
1302+
proc_mesh.client(),
1303+
sel!(*),
1304+
Echo(payload, reply_handle.bind()),
1305+
)
13011306
.unwrap();
13021307

13031308
// The undeliverable message will be turned into a proc event.

hyperactor_mesh/src/comm.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,9 @@ mod tests {
772772
};
773773

774774
let selection = sel!(*);
775-
actor_mesh.cast(selection.clone(), message).unwrap();
775+
actor_mesh
776+
.cast(proc_mesh.client(), selection.clone(), message)
777+
.unwrap();
776778

777779
let mut reply_tos = vec![];
778780
for _ in extent.points() {

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ mod tests {
859859

860860
actors
861861
.cast(
862+
mesh.client(),
862863
sel_from_shape!(actors.shape(), replica = 0),
863864
Error("failmonkey".to_string()),
864865
)

hyperactor_mesh/src/reference.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ impl<A: RemoteActor> ActorMeshRef<A> {
134134
Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
135135
caps,
136136
self.mesh_id.clone(),
137-
caps.mailbox().actor_id(),
138137
&self.comm_actor_ref,
139138
&selection,
140139
message,
@@ -144,7 +143,6 @@ impl<A: RemoteActor> ActorMeshRef<A> {
144143
None => actor_mesh_cast::<A, M>(
145144
caps,
146145
self.mesh_id.clone(),
147-
caps.mailbox().actor_id(),
148146
&self.comm_actor_ref,
149147
selection,
150148
&self.root,

monarch_extension/src/logging.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,23 @@ impl LoggingMeshClient {
7070
}
7171

7272
let forwarder_inner_mesh = self.forwarder_mesh.borrow().map_err(anyhow::Error::msg)?;
73+
74+
let mailbox = forwarder_inner_mesh.proc_mesh().client();
7375
forwarder_inner_mesh
7476
.cast(
77+
mailbox,
7578
Selection::True,
7679
LogForwardMessage::SetMode { stream_to_client },
7780
)
7881
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
7982

8083
let logger_inner_mesh = self.logger_mesh.borrow().map_err(anyhow::Error::msg)?;
8184
logger_inner_mesh
82-
.cast(Selection::True, LoggerRuntimeMessage::SetLogging { level })
85+
.cast(
86+
mailbox,
87+
Selection::True,
88+
LoggerRuntimeMessage::SetLogging { level },
89+
)
8390
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
8491

8592
self.client_actor

0 commit comments

Comments
 (0)