Skip to content

Commit a34a91e

Browse files
pzhan9facebook-github-bot
authored andcommitted
Refactoring: move verify_casting out so it can be re-used by other tests (#1303)
Summary: Pull Request resolved: #1303 Two refactorings to make the next diff D82537988 looks less busy: 1. extract `cast_v0`; 2. extract `verify_casting`. Reviewed By: mariusae Differential Revision: D83001963 fbshipit-source-id: 33cb027b73e26d8b1606ff4eb5baa13454d2de6a
1 parent 167c00b commit a34a91e

File tree

2 files changed

+70
-70
lines changed

2 files changed

+70
-70
lines changed

hyperactor_mesh/src/v1/actor_mesh.rs

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use serde::Deserializer;
3333
use serde::Serialize;
3434
use serde::Serializer;
3535

36+
use crate::CommActor;
3637
use crate::actor_mesh as v0_actor_mesh;
3738
use crate::comm::multicast;
3839
use crate::proc_mesh::mesh_agent::ActorState;
@@ -136,33 +137,7 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
136137
M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
137138
{
138139
if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
139-
let cast_mesh_shape = view::Ranked::region(self).into();
140-
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
141-
match &self.proc_mesh.root_region {
142-
Some(root_region) => {
143-
let root_mesh_shape = root_region.into();
144-
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
145-
cx,
146-
actor_mesh_id,
147-
root_comm_actor,
148-
&sel!(*),
149-
message,
150-
&cast_mesh_shape,
151-
&root_mesh_shape,
152-
)
153-
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
154-
}
155-
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
156-
cx,
157-
actor_mesh_id,
158-
root_comm_actor,
159-
sel!(*),
160-
&cast_mesh_shape,
161-
&cast_mesh_shape,
162-
message,
163-
)
164-
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
165-
}
140+
self.cast_v0(cx, message, root_comm_actor)
166141
} else {
167142
for (point, actor) in self.iter() {
168143
let mut headers = Attrs::new();
@@ -180,6 +155,45 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
180155
}
181156
}
182157

158+
fn cast_v0<M>(
159+
&self,
160+
cx: &impl context::Actor,
161+
message: M,
162+
root_comm_actor: &ActorRef<CommActor>,
163+
) -> v1::Result<()>
164+
where
165+
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
166+
M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
167+
{
168+
let cast_mesh_shape = view::Ranked::region(self).into();
169+
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
170+
match &self.proc_mesh.root_region {
171+
Some(root_region) => {
172+
let root_mesh_shape = root_region.into();
173+
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
174+
cx,
175+
actor_mesh_id,
176+
root_comm_actor,
177+
&sel!(*),
178+
message,
179+
&cast_mesh_shape,
180+
&root_mesh_shape,
181+
)
182+
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
183+
}
184+
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
185+
cx,
186+
actor_mesh_id,
187+
root_comm_actor,
188+
sel!(*),
189+
&cast_mesh_shape,
190+
&cast_mesh_shape,
191+
message,
192+
)
193+
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
194+
}
195+
}
196+
183197
pub async fn actor_states(
184198
&self,
185199
cx: &impl context::Actor,

hyperactor_mesh/src/v1/testactor.rs

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@ use hyperactor::mailbox;
3838
use hyperactor::supervision::ActorSupervisionEvent;
3939
use ndslice::Point;
4040
#[cfg(test)]
41-
use ndslice::ViewExt;
41+
use ndslice::ViewExt as _;
4242
use serde::Deserialize;
4343
use serde::Serialize;
4444

4545
use crate::comm::multicast::CastInfo;
4646
#[cfg(test)]
4747
use crate::v1::ActorMesh;
4848
#[cfg(test)]
49+
use crate::v1::ActorMeshRef;
50+
#[cfg(test)]
4951
use crate::v1::testing;
5052

5153
/// A simple test actor used by various unit tests.
@@ -218,28 +220,7 @@ impl Handler<GetCastInfo> for TestActor {
218220
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
219221
let instance = testing::instance().await;
220222
// Verify casting to the root actor mesh
221-
{
222-
let (port, mut rx) = mailbox::open_port(&instance);
223-
actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
224-
225-
let mut expected_actor_ids: HashSet<_> = actor_mesh
226-
.values()
227-
.map(|actor_ref| actor_ref.actor_id().clone())
228-
.collect();
229-
230-
while !expected_actor_ids.is_empty() {
231-
let actor_id = rx.recv().await.unwrap();
232-
assert!(
233-
expected_actor_ids.remove(&actor_id),
234-
"got {actor_id}, expect {expected_actor_ids:?}"
235-
);
236-
}
237-
238-
// No more messages
239-
RealClock.sleep(Duration::from_secs(1)).await;
240-
let result = rx.try_recv();
241-
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
242-
}
223+
assert_casting_correctness(&actor_mesh, instance).await;
243224

244225
// Just pick the first dimension. Slice half of it off.
245226
// actor_mesh.extent().
@@ -248,28 +229,33 @@ pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
248229

249230
// Verify casting to the sliced actor mesh
250231
let sliced_actor_mesh = actor_mesh.range(&label, 0..size).unwrap();
251-
{
252-
let (port, mut rx) = mailbox::open_port(instance);
253-
sliced_actor_mesh
254-
.cast(instance, GetActorId(port.bind()))
255-
.unwrap();
232+
assert_casting_correctness(&sliced_actor_mesh, instance).await;
233+
}
256234

257-
let mut expected_actor_ids: HashSet<_> = sliced_actor_mesh
258-
.values()
259-
.map(|actor_ref| actor_ref.actor_id().clone())
260-
.collect();
235+
#[cfg(test)]
236+
/// Cast to the actor mesh, and verify that all actors are reached.
237+
pub async fn assert_casting_correctness(
238+
actor_mesh: &ActorMeshRef<TestActor>,
239+
instance: &Instance<()>,
240+
) {
241+
let (port, mut rx) = mailbox::open_port(instance);
242+
actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
261243

262-
while !expected_actor_ids.is_empty() {
263-
let actor_id = rx.recv().await.unwrap();
264-
assert!(
265-
expected_actor_ids.remove(&actor_id),
266-
"got {actor_id}, expect {expected_actor_ids:?}"
267-
);
268-
}
244+
let mut expected_actor_ids: HashSet<_> = actor_mesh
245+
.values()
246+
.map(|actor_ref| actor_ref.actor_id().clone())
247+
.collect();
269248

270-
// No more messages
271-
RealClock.sleep(Duration::from_secs(1)).await;
272-
let result = rx.try_recv();
273-
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
249+
while !expected_actor_ids.is_empty() {
250+
let actor_id = rx.recv().await.unwrap();
251+
assert!(
252+
expected_actor_ids.remove(&actor_id),
253+
"got {actor_id}, expect {expected_actor_ids:?}"
254+
);
274255
}
256+
257+
// No more messages
258+
RealClock.sleep(Duration::from_secs(1)).await;
259+
let result = rx.try_recv();
260+
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
275261
}

0 commit comments

Comments
 (0)