Skip to content

Commit ee34696

Browse files
samluryemeta-codesync[bot]
authored andcommitted
Set Alloc transport from AllocSpec (#1389)
Summary: Pull Request resolved: #1389 All `Alloc` implementations now get their transport type from the new transport field on `AllocSpec`. When `PyAllocSpec` is converted to `AllocSpec`, it will use whatever the current default transport is, provided via the new `hyperactor_mesh::proc_mesh::default_transport` function, which reads the config var `HYPERACTOR_MESH_DEFAULT_TRANSPORT` (renamed from `HYPERACTOR_MESH_ROOT_CLIENT_TRANSPORT`). That way, if a user creates a `PyAllocSpec`, then reconfigures the default transport, then passes the old `PyAllocSpec` to an allocator, they won't get an unexpected result. `RemoteProcessAllocator` and `MastAllocator` have slightly special handling for now so that we don't break existing v0 examples/notebooks/workflows: - `MastAllocator` can be initialized with a transport override for testing purposes, which overrides the transport value from the `AllocSpec` passed to it. - We eventually want `PyMastAllocatorBase` to throw an exception if the default transport isn't metaTLS, but if we did that now a bunch of v0 examples would break unnecessarily. So instead, for now, if the default transport is wrong, the user is warned and the transport is updated on the `AllocSpec` for them. - Similarly for `RemoteAllocator`, we eventually want to throw an exception if the transport on the addresses returned by the initializer isn't equal to the default transport. But to prevent existing v0 examples from breaking, we now simply warn the user and update the transport on the `AllocSpec` to the expected value. ghstack-source-id: 314103746 Reviewed By: mariusae Differential Revision: D83628852 fbshipit-source-id: 6f3bb683c22677d17250698804716ae36f8997a1
1 parent c56b31e commit ee34696

File tree

27 files changed

+189
-117
lines changed

27 files changed

+189
-117
lines changed

hyperactor_mesh/benches/main.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use criterion::Criterion;
1313
use criterion::Throughput;
1414
use criterion::criterion_group;
1515
use criterion::criterion_main;
16+
use hyperactor::channel::ChannelTransport;
1617
use hyperactor_mesh::ProcMesh;
1718
use hyperactor_mesh::actor_mesh::ActorMesh;
1819
use hyperactor_mesh::actor_mesh::RootActorMesh;
@@ -45,6 +46,7 @@ fn bench_actor_scaling(c: &mut Criterion) {
4546
extent: extent!(hosts = host_count, gpus = gpus),
4647
constraints: Default::default(),
4748
proc_name: None,
49+
transport: ChannelTransport::Local,
4850
})
4951
.await
5052
.unwrap();
@@ -142,6 +144,7 @@ fn bench_actor_mesh_message_sizes(c: &mut Criterion) {
142144
extent: extent!(gpus = actor_count),
143145
constraints: Default::default(),
144146
proc_name: None,
147+
transport: ChannelTransport::Local,
145148
})
146149
.await
147150
.unwrap();

hyperactor_mesh/examples/dining_philosophers.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use hyperactor::Instance;
2121
use hyperactor::Named;
2222
use hyperactor::PortRef;
2323
use hyperactor::Unbind;
24+
use hyperactor::channel::ChannelTransport;
2425
use hyperactor_mesh::ProcMesh;
2526
use hyperactor_mesh::actor_mesh::ActorMesh;
2627
use hyperactor_mesh::alloc::AllocSpec;
@@ -232,6 +233,7 @@ async fn main() -> Result<ExitCode> {
232233
extent: extent! {replica = group_size},
233234
constraints: Default::default(),
234235
proc_name: None,
236+
transport: ChannelTransport::Local,
235237
})
236238
.await?;
237239

hyperactor_mesh/examples/sieve.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use hyperactor::Context;
2222
use hyperactor::Handler;
2323
use hyperactor::Named;
2424
use hyperactor::PortRef;
25+
use hyperactor::channel::ChannelTransport;
2526
use hyperactor_mesh::Mesh;
2627
use hyperactor_mesh::ProcMesh;
2728
use hyperactor_mesh::alloc::AllocSpec;
@@ -109,6 +110,7 @@ async fn main() -> Result<ExitCode> {
109110
extent: extent! { replica = 1 },
110111
constraints: Default::default(),
111112
proc_name: None,
113+
transport: ChannelTransport::Local,
112114
})
113115
.await?;
114116

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,8 @@ pub(crate) mod test_util {
646646
// The actor creates a mesh.
647647
use std::sync::Arc;
648648

649+
use hyperactor::channel::ChannelTransport;
650+
649651
use crate::alloc::AllocSpec;
650652
use crate::alloc::Allocator;
651653
use crate::alloc::LocalAllocator;
@@ -656,6 +658,7 @@ pub(crate) mod test_util {
656658
extent: extent! { replica = 1 },
657659
constraints: Default::default(),
658660
proc_name: None,
661+
transport: ChannelTransport::Local,
659662
})
660663
.await
661664
.unwrap();
@@ -741,6 +744,7 @@ mod tests {
741744
use $crate::proc_mesh::SharedSpawnable;
742745
use std::collections::VecDeque;
743746
use hyperactor::data::Serialized;
747+
use $crate::proc_mesh::default_transport;
744748

745749
use super::*;
746750
use super::test_util::*;
@@ -758,6 +762,7 @@ mod tests {
758762
extent: extent! { replica = 1 },
759763
constraints: Default::default(),
760764
proc_name: None,
765+
transport: default_transport()
761766
})
762767
.await
763768
.unwrap();
@@ -781,6 +786,7 @@ mod tests {
781786
extent: extent!(replica = 4),
782787
constraints: Default::default(),
783788
proc_name: None,
789+
transport: default_transport()
784790
})
785791
.await
786792
.unwrap();
@@ -807,6 +813,7 @@ mod tests {
807813
extent: extent!(replica = 2),
808814
constraints: Default::default(),
809815
proc_name: None,
816+
transport: default_transport(),
810817
})
811818
.await
812819
.unwrap();
@@ -843,6 +850,7 @@ mod tests {
843850
extent: extent!(x = X, y = Y, z = Z),
844851
constraints: Default::default(),
845852
proc_name: None,
853+
transport: default_transport(),
846854
})
847855
.await
848856
.unwrap();
@@ -887,6 +895,7 @@ mod tests {
887895
extent: extent!(replica = 2, host = 2, gpu = 8),
888896
constraints: Default::default(),
889897
proc_name: None,
898+
transport: default_transport(),
890899
})
891900
.await
892901
.unwrap();
@@ -929,6 +938,7 @@ mod tests {
929938
extent: extent!(replica = 2, host = 2, gpu = 8),
930939
constraints: Default::default(),
931940
proc_name: None,
941+
transport: default_transport(),
932942
})
933943
.await
934944
.unwrap();
@@ -960,6 +970,7 @@ mod tests {
960970
extent: extent!(replica = 1),
961971
constraints: Default::default(),
962972
proc_name: None,
973+
transport: default_transport(),
963974
})
964975
.await
965976
.unwrap();
@@ -1008,6 +1019,7 @@ mod tests {
10081019
extent,
10091020
constraints: Default::default(),
10101021
proc_name: None,
1022+
transport: default_transport(),
10111023
})
10121024
.await
10131025
.unwrap();
@@ -1038,6 +1050,7 @@ mod tests {
10381050
extent: extent!(replica = 1 ),
10391051
constraints: Default::default(),
10401052
proc_name: None,
1053+
transport: default_transport(),
10411054
})
10421055
.await
10431056
.unwrap();
@@ -1068,6 +1081,7 @@ mod tests {
10681081
extent: extent.clone(),
10691082
constraints: Default::default(),
10701083
proc_name: None,
1084+
transport: default_transport(),
10711085
})
10721086
.await
10731087
.unwrap();
@@ -1104,6 +1118,8 @@ mod tests {
11041118
}
11051119

11061120
mod local {
1121+
use hyperactor::channel::ChannelTransport;
1122+
11071123
use crate::alloc::local::LocalAllocator;
11081124

11091125
actor_mesh_test_suite!(LocalAllocator);
@@ -1130,6 +1146,7 @@ mod tests {
11301146
extent: extent!(replica = 2),
11311147
constraints: Default::default(),
11321148
proc_name: None,
1149+
transport: ChannelTransport::Local,
11331150
})
11341151
.await
11351152
.unwrap();
@@ -1198,6 +1215,7 @@ mod tests {
11981215
extent: extent!(replica = 1),
11991216
constraints: Default::default(),
12001217
proc_name: None,
1218+
transport: ChannelTransport::Local,
12011219
})
12021220
.await
12031221
.unwrap();
@@ -1265,6 +1283,7 @@ mod tests {
12651283
extent: extent!(replica = 2),
12661284
constraints: Default::default(),
12671285
proc_name: None,
1286+
transport: ChannelTransport::Local,
12681287
})
12691288
.await
12701289
.unwrap();
@@ -1316,6 +1335,7 @@ mod tests {
13161335

13171336
use bytes::Bytes;
13181337
use hyperactor::PortId;
1338+
use hyperactor::channel::ChannelTransport;
13191339
use hyperactor::clock::Clock;
13201340
use hyperactor::clock::RealClock;
13211341
use hyperactor::mailbox::MessageEnvelope;
@@ -1375,6 +1395,7 @@ mod tests {
13751395
extent: extent!(replica = 1),
13761396
constraints: Default::default(),
13771397
proc_name: None,
1398+
transport: ChannelTransport::Unix,
13781399
})
13791400
.await
13801401
.unwrap();
@@ -1458,6 +1479,7 @@ mod tests {
14581479
extent: extent! { replica = 1 },
14591480
constraints: Default::default(),
14601481
proc_name: None,
1482+
transport: ChannelTransport::Unix,
14611483
})
14621484
.await
14631485
.unwrap();
@@ -1620,6 +1642,7 @@ mod tests {
16201642
extent,
16211643
constraints: Default::default(),
16221644
proc_name: None,
1645+
transport: ChannelTransport::Local
16231646
}))
16241647
.unwrap();
16251648
let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
@@ -1651,6 +1674,7 @@ mod tests {
16511674
extent: extent.clone(),
16521675
constraints: Default::default(),
16531676
proc_name: None,
1677+
transport: ChannelTransport::Local
16541678
}))
16551679
.unwrap();
16561680
let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
@@ -1723,6 +1747,7 @@ mod tests {
17231747
extent,
17241748
constraints: Default::default(),
17251749
proc_name: None,
1750+
transport: ChannelTransport::Local
17261751
}))
17271752
.unwrap();
17281753
let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();

hyperactor_mesh/src/alloc.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ pub struct AllocSpec {
8383
/// If specified, return procs using direct addressing with
8484
/// the provided proc name.
8585
pub proc_name: Option<String>,
86+
87+
/// The transport to use for the procs in this alloc.
88+
pub transport: ChannelTransport,
8689
}
8790

8891
/// The core allocator trait, implemented by all allocators.
@@ -249,7 +252,9 @@ pub trait Alloc {
249252
fn world_id(&self) -> &WorldId;
250253

251254
/// The channel transport used the procs in this alloc.
252-
fn transport(&self) -> ChannelTransport;
255+
fn transport(&self) -> ChannelTransport {
256+
self.spec().transport.clone()
257+
}
253258

254259
/// Stop this alloc, shutting down all of its procs. A clean
255260
/// shutdown should result in Stop events from all allocs,
@@ -513,10 +518,6 @@ pub mod test_utils {
513518
self.alloc.world_id()
514519
}
515520

516-
fn transport(&self) -> ChannelTransport {
517-
self.alloc.transport()
518-
}
519-
520521
async fn stop(&mut self) -> Result<(), AllocatorError> {
521522
self.alloc.stop().await
522523
}
@@ -548,6 +549,7 @@ pub(crate) mod testing {
548549
use super::*;
549550
use crate::alloc::test_utils::TestActor;
550551
use crate::alloc::test_utils::Wait;
552+
use crate::proc_mesh::default_transport;
551553
use crate::proc_mesh::mesh_agent::GspawnResult;
552554
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
553555

@@ -568,6 +570,7 @@ pub(crate) mod testing {
568570
extent: extent.clone(),
569571
constraints: Default::default(),
570572
proc_name: None,
573+
transport: default_transport(),
571574
})
572575
.await
573576
.unwrap();
@@ -718,6 +721,7 @@ pub(crate) mod testing {
718721
extent: extent! { replica = 1 },
719722
constraints: Default::default(),
720723
proc_name: None,
724+
transport: ChannelTransport::Unix,
721725
})
722726
.await
723727
.unwrap();

hyperactor_mesh/src/alloc/local.rs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use hyperactor::ProcId;
1919
use hyperactor::WorldId;
2020
use hyperactor::channel;
2121
use hyperactor::channel::ChannelAddr;
22-
use hyperactor::channel::ChannelTransport;
2322
use hyperactor::mailbox::MailboxServer;
2423
use hyperactor::mailbox::MailboxServerHandle;
2524
use hyperactor::proc::Proc;
@@ -76,15 +75,10 @@ pub struct LocalAlloc {
7675
todo_rx: mpsc::UnboundedReceiver<Action>,
7776
stopped: bool,
7877
failed: bool,
79-
transport: ChannelTransport,
8078
}
8179

8280
impl LocalAlloc {
83-
fn new(spec: AllocSpec) -> Self {
84-
Self::new_with_transport(spec, ChannelTransport::Local)
85-
}
86-
87-
pub(crate) fn new_with_transport(spec: AllocSpec, transport: ChannelTransport) -> Self {
81+
pub(crate) fn new(spec: AllocSpec) -> Self {
8882
let name = ShortUuid::generate();
8983
let (todo_tx, todo_rx) = mpsc::unbounded_channel();
9084
for rank in 0..spec.extent.num_ranks() {
@@ -100,7 +94,6 @@ impl LocalAlloc {
10094
todo_rx,
10195
stopped: false,
10296
failed: false,
103-
transport,
10497
}
10598
}
10699

@@ -265,10 +258,6 @@ impl Alloc for LocalAlloc {
265258
&self.world_id
266259
}
267260

268-
fn transport(&self) -> ChannelTransport {
269-
self.transport.clone()
270-
}
271-
272261
async fn stop(&mut self) -> Result<(), AllocatorError> {
273262
for rank in 0..self.size() {
274263
self.todo_tx

hyperactor_mesh/src/alloc/process.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ impl Allocator for ProcessAllocator {
9292
let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
9393
.map_err(anyhow::Error::from)?;
9494

95+
if spec.transport == ChannelTransport::Local {
96+
return Err(AllocatorError::Other(anyhow::anyhow!(
97+
"ProcessAllocator does not support local transport"
98+
)));
99+
}
100+
95101
let name = ShortUuid::generate();
96102
Ok(ProcessAlloc {
97103
name: name.clone(),
@@ -566,10 +572,6 @@ impl Alloc for ProcessAlloc {
566572
&self.world_id
567573
}
568574

569-
fn transport(&self) -> ChannelTransport {
570-
ChannelTransport::Unix
571-
}
572-
573575
async fn stop(&mut self) -> Result<(), AllocatorError> {
574576
// We rely on the teardown here, and that the process should
575577
// exit on its own. We should have a hard timeout here as well,
@@ -614,6 +616,7 @@ mod tests {
614616
extent: ndslice::extent!(replica = 1),
615617
constraints: Default::default(),
616618
proc_name: None,
619+
transport: ChannelTransport::Unix,
617620
})
618621
.await
619622
.unwrap();

0 commit comments

Comments
 (0)