Skip to content

Commit 7d767cf

Browse files
pzhan9 authored and
facebook-github-bot committed
Pass cast rank to python actor (#747)
Summary: Pull Request resolved: #747 When casting to a sliced mesh, the actor's rank on the sliced mesh is different from its rank on the root mesh. Currently, the root mesh's rank is passed to the Python actor. [That is wrong](https://www.internalfb.com/diff/D78355743?dst_version_fbid=1460199675405905&transaction_fbid=1279216557057466). We need to pass the rank on the cast mesh instead: if the cast mesh is a sliced mesh, the rank passed should be the rank on the sliced mesh. This diff fixes this. Reviewed By: mariusae Differential Revision: D79530146 fbshipit-source-id: 0721964e3c532b82d91e48eb189ab97fc7d5435a
1 parent bc6f71e commit 7d767cf

File tree

6 files changed

+147
-37
lines changed

6 files changed

+147
-37
lines changed

controller/src/lib.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ use ndslice::Selection;
6565
use ndslice::Shape;
6666
use ndslice::Slice;
6767
use ndslice::reshape::Limit;
68-
use ndslice::reshape::ReshapeSliceExt;
68+
use ndslice::reshape::ReshapeShapeExt;
6969
use ndslice::selection::dsl;
7070
use ndslice::shape::Range;
7171
use serde::Deserialize;
@@ -425,6 +425,14 @@ impl ControllerMessageHandler for ControllerActor {
425425
}),
426426
};
427427

428+
let slice = Slice::new(0usize, vec![self.world_size], vec![1])?;
429+
// Use a made-up label to create a fake shape. This shape is used by
430+
// comm actor to determine the cast rank. Cast rank is not used by
431+
// DeviceMesh, but we still need a shape there to make the logic happy.
432+
let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())?
433+
.reshape(Limit::from(CASTING_FANOUT_SIZE))
434+
.shape;
435+
428436
let message = CastMessageEnvelope::from_serialized(
429437
ActorMeshId(
430438
ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
@@ -439,15 +447,10 @@ impl ControllerMessageHandler for ControllerActor {
439447
.name()
440448
.to_string(),
441449
),
442-
// Not reflective of the actual shape, but this is never actually used.
443-
Shape::unity(),
450+
made_up_shape,
444451
message,
445452
);
446453

447-
let slice = Slice::new(0usize, vec![self.world_size], vec![1])
448-
.unwrap()
449-
.reshape_with_limit(Limit::from(CASTING_FANOUT_SIZE));
450-
451454
self.comm_actor_ref.send(
452455
cx,
453456
CastMessage {

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ pub(crate) fn actor_mesh_cast<A, M>(
7575
comm_actor_ref: &ActorRef<CommActor>,
7676
selection_of_root: Selection,
7777
root_mesh_shape: &Shape,
78+
cast_mesh_shape: &Shape,
7879
message: M,
7980
) -> Result<(), CastError>
8081
where
@@ -89,10 +90,13 @@ where
8990
let message = CastMessageEnvelope::new::<A, M>(
9091
actor_mesh_id.clone(),
9192
sender.clone(),
92-
root_mesh_shape.clone(),
93+
cast_mesh_shape.clone(),
9394
message,
9495
)?;
9596
let cast_message = CastMessage {
97+
// Note: `dest` is on the root mesh's shape, which could be different
98+
// from the cast mesh's shape if the cast is on a view, e.g. a sliced
99+
// mesh.
96100
dest: Uslice {
97101
slice: root_mesh_shape.slice().clone(),
98102
selection: selection_of_root,
@@ -147,6 +151,7 @@ where
147151
comm_actor_ref,
148152
sel_of_root,
149153
root_mesh_shape,
154+
sliced_shape,
150155
message,
151156
)
152157
}
@@ -172,6 +177,7 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
172177
self.proc_mesh().comm_actor(), // comm actor
173178
selection, // the selected actors
174179
self.shape(), // root mesh shape
180+
self.shape(), // cast mesh shape
175181
message, // the message
176182
)
177183
}
@@ -419,7 +425,7 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
419425
/*sel_of_sliced=*/ &sel,
420426
/*message=*/ message,
421427
/*sliced_shape=*/ self.shape(),
422-
/*base_shape=*/ self.0.shape(),
428+
/*root_mesh_shape=*/ self.0.shape(),
423429
)
424430
}
425431
}

hyperactor_mesh/src/comm.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,14 @@ impl CommActor {
249249

250250
// Deliver message here, if necessary.
251251
if deliver_here {
252+
let rank_on_root_mesh = mode.self_rank(cx.self_id());
253+
let cast_rank = message.relative_rank(rank_on_root_mesh)?;
254+
let cast_shape = message.shape();
252255
let mut headers = cx.headers().clone();
253256
set_cast_info_on_headers(
254257
&mut headers,
255-
mode.self_rank(cx.self_id()),
256-
message.shape().clone(),
258+
cast_rank,
259+
cast_shape.clone(),
257260
message.sender().clone(),
258261
);
259262
cx.post(

hyperactor_mesh/src/comm/multicast.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use hyperactor::message::Castable;
2121
use hyperactor::message::ErasedUnbound;
2222
use hyperactor::message::IndexedErasedUnbound;
2323
use hyperactor::reference::ActorId;
24+
use ndslice::Extent;
2425
use ndslice::Shape;
2526
use ndslice::Slice;
2627
use ndslice::selection::Selection;
@@ -120,6 +121,39 @@ impl CastMessageEnvelope {
120121
&self.shape
121122
}
122123

124+
/// Given a rank in the root shape, return the rank of the corresponding
125+
/// point in this envelope's shape, which is a view of the root shape.
126+
pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
127+
let shape = self.shape();
128+
let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
129+
anyhow::anyhow!(
130+
"fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
131+
rank_on_root_mesh,
132+
e,
133+
shape,
134+
)
135+
})?;
136+
let extent =
137+
Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
138+
anyhow::anyhow!(
139+
"fail to calculate extent for root rank {} due to error: {}; shape is {}",
140+
rank_on_root_mesh,
141+
e,
142+
shape,
143+
)
144+
})?;
145+
let point = extent.point(coords).map_err(|e| {
146+
anyhow::anyhow!(
147+
"fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
148+
rank_on_root_mesh,
149+
e,
150+
extent,
151+
shape,
152+
)
153+
})?;
154+
Ok(point.rank())
155+
}
156+
123157
/// The unique key used to indicate the stream to which to deliver this message.
124158
/// Concretely, the comm actors along the path should use this key to manage
125159
/// sequence numbers and reorder buffers.
@@ -203,9 +237,14 @@ declare_attrs! {
203237
pub attr CAST_ORIGINATING_SENDER: ActorId;
204238
}
205239

206-
pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape, sender: ActorId) {
207-
headers.set(CAST_RANK, rank);
208-
headers.set(CAST_SHAPE, shape);
240+
pub fn set_cast_info_on_headers(
241+
headers: &mut Attrs,
242+
cast_rank: usize,
243+
cast_shape: Shape,
244+
sender: ActorId,
245+
) {
246+
headers.set(CAST_RANK, cast_rank);
247+
headers.set(CAST_SHAPE, cast_shape);
209248
headers.set(CAST_ORIGINATING_SENDER, sender);
210249
}
211250

hyperactor_mesh/src/reference.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ impl<A: RemoteActor> ActorMeshRef<A> {
148148
&self.comm_actor_ref,
149149
selection,
150150
&self.root,
151+
&self.root,
151152
message,
152153
),
153154
}

python/tests/_monarch/test_actor_mesh.py

Lines changed: 81 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88

99
import pickle
10-
from typing import Any, Callable, Coroutine, Iterable, List, TYPE_CHECKING
10+
from typing import Any, Callable, cast, Coroutine, Iterable, List, TYPE_CHECKING
1111

1212
import monarch
1313
import pytest
@@ -57,6 +57,12 @@ async def allocate() -> ProcMesh:
5757

5858

5959
class MyActor:
60+
def __init__(self) -> None:
61+
# Note: for the same actor, its rank on the root mesh could be different
62+
# from its rank on the mesh it is cast to. This is because the cast
63+
# mesh could be a sliced mesh.
64+
self._rank_on_root_mesh: int = -1
65+
6066
async def handle(
6167
self,
6268
mailbox: Mailbox,
@@ -68,8 +74,21 @@ async def handle(
6874
local_state: Iterable[Any],
6975
response_port: "PortProtocol[Any]",
7076
) -> None:
71-
assert rank is not None
72-
response_port.send(f"rank: {rank}")
77+
match method:
78+
case MethodSpecifier.Init():
79+
# Since this actor is spawned from the root proc mesh, the rank
80+
# passed from init should be the rank on the root mesh.
81+
self._rank_on_root_mesh = rank
82+
response_port.send(None)
83+
return None
84+
case MethodSpecifier.ReturnsResponse(name=_):
85+
response_port.send(self._rank_on_root_mesh)
86+
return None
87+
case MethodSpecifier.ExplicitPort(name=_):
88+
response_port.exception(
89+
NotImplementedError("ExplicitPort is not supported yet")
90+
)
91+
return None
7392

7493

7594
# TODO - re-enable after resolving T232206970
@@ -95,35 +114,70 @@ async def run() -> None:
95114
run()
96115

97116

98-
async def verify_cast(
117+
async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh:
118+
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
119+
# init actors to record their root ranks
120+
receiver: PortReceiver
121+
handle, receiver = proc_mesh.client.open_port()
122+
port_ref = handle.bind()
123+
124+
message = PythonMessage(
125+
PythonMessageKind.CallMethod(MethodSpecifier.Init(), port_ref),
126+
pickle.dumps(None),
127+
)
128+
actor_mesh.cast(Selection.all(), message)
129+
# wait for init to complete
130+
for _ in range(len(actor_mesh.shape.ndslice)):
131+
await receiver.recv_task()
132+
133+
return actor_mesh
134+
135+
136+
async def cast_to_call(
137+
actor_mesh: PythonActorMesh | PythonActorMeshRef,
138+
mailbox: Mailbox,
139+
message: PythonMessage,
140+
) -> None:
141+
sel = Selection.all()
142+
if isinstance(actor_mesh, PythonActorMesh):
143+
actor_mesh.cast(sel, message)
144+
elif isinstance(actor_mesh, PythonActorMeshRef):
145+
actor_mesh.cast(mailbox, sel, message)
146+
147+
148+
async def verify_cast_to_call(
99149
actor_mesh: PythonActorMesh | PythonActorMeshRef,
100150
mailbox: Mailbox,
101-
cast_ranks: List[int],
151+
root_ranks: List[int],
102152
) -> None:
103153
receiver: PortReceiver
104154
handle, receiver = mailbox.open_port()
105155
port_ref = handle.bind()
106156

157+
# Now send the real message
107158
message = PythonMessage(
108159
PythonMessageKind.CallMethod(MethodSpecifier.ReturnsResponse("echo"), port_ref),
109160
pickle.dumps("ping"),
110161
)
111-
sel = Selection.from_string("*")
112-
if isinstance(actor_mesh, PythonActorMesh):
113-
actor_mesh.cast(sel, message)
114-
elif isinstance(actor_mesh, PythonActorMeshRef):
115-
actor_mesh.cast(mailbox, sel, message)
162+
await cast_to_call(actor_mesh, mailbox, message)
116163

117164
rcv_ranks = []
118-
for _ in range(len(cast_ranks)):
165+
for _ in range(len(root_ranks)):
119166
message = await receiver.recv_task()
120167
result_kind = message.kind
121168
assert isinstance(result_kind, PythonMessageKind.Result)
122-
rank = result_kind.rank
123-
assert rank is not None
124-
rcv_ranks.append(rank)
125-
rcv_ranks.sort()
126-
assert rcv_ranks == cast_ranks
169+
cast_rank = result_kind.rank
170+
assert cast_rank is not None
171+
root_rank = cast(int, pickle.loads(message.message))
172+
rcv_ranks.append((cast_rank, root_rank))
173+
rcv_ranks.sort(key=lambda pair: pair[0])
174+
recv_cast_ranks, recv_root_ranks = zip(*rcv_ranks)
175+
assert recv_root_ranks == tuple(
176+
root_ranks
177+
), f"recv_root_ranks={recv_root_ranks}, root_ranks={tuple(root_ranks)}"
178+
assert recv_cast_ranks == tuple(
179+
range(len(root_ranks))
180+
), f"recv_cast_ranks={recv_cast_ranks}, root_ranks={tuple(root_ranks)}"
127181
# verify no more messages are received
128182
with pytest.raises(TimeoutError):
129183
await receiver.recv_task().with_timeout(1)
@@ -136,8 +190,8 @@ async def test_cast_handle() -> None:
136190
@run_on_tokio
137191
async def run() -> None:
138192
proc_mesh = await allocate()
139-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
140-
await verify_cast(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8)))
193+
actor_mesh = await spawn_actor_mesh(proc_mesh)
194+
await verify_cast_to_call(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8)))
141195

142196
await proc_mesh.stop_nonblocking()
143197

@@ -151,9 +205,11 @@ async def test_cast_ref() -> None:
151205
@run_on_tokio
152206
async def run() -> None:
153207
proc_mesh = await allocate()
154-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
208+
actor_mesh = await spawn_actor_mesh(proc_mesh)
155209
actor_mesh_ref = actor_mesh.bind()
156-
await verify_cast(actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8)))
210+
await verify_cast_to_call(
211+
actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8))
212+
)
157213

158214
await proc_mesh.stop_nonblocking()
159215

@@ -184,7 +240,7 @@ async def verify_slice(
184240
assert (
185241
sliced_shape.ranks() == replica_0_ranks + replica_1_ranks
186242
), f"left is {sliced_shape.ranks()}"
187-
await verify_cast(sliced_mesh, mailbox, sliced_shape.ranks())
243+
await verify_cast_to_call(sliced_mesh, mailbox, sliced_shape.ranks())
188244

189245
assert sliced_shape.labels == ["replicas", "hosts", "gpus"]
190246
assert sliced_shape.ndslice.sizes == [2, 4, 3]
@@ -224,7 +280,8 @@ async def test_slice_actor_mesh_handle() -> None:
224280
@run_on_tokio
225281
async def run() -> None:
226282
proc_mesh = await allocate()
227-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
283+
actor_mesh = await spawn_actor_mesh(proc_mesh)
284+
228285
await verify_slice(actor_mesh, proc_mesh.client)
229286

230287
await proc_mesh.stop_nonblocking()
@@ -239,7 +296,8 @@ async def test_slice_actor_mesh_ref() -> None:
239296
@run_on_tokio
240297
async def run() -> None:
241298
proc_mesh = await allocate()
242-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
299+
actor_mesh = await spawn_actor_mesh(proc_mesh)
300+
243301
actor_mesh_ref = actor_mesh.bind()
244302
await verify_slice(actor_mesh_ref, proc_mesh.client)
245303

0 commit comments

Comments
 (0)