Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions hyperactor_mesh/src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,12 @@ pub enum ProcState {
/// Reference to this proc's mesh agent. In the future, we'll reserve a
/// 'well known' PID (0) for this purpose.
mesh_agent: ActorRef<ProcMeshAgent>,
/// The address of this proc. The endpoint of this address is
/// The address of this proc, which may be either its true address or the address of a
/// forwarding proxy. The endpoint of this address is
/// the proc's mailbox, which accepts [`hyperactor::mailbox::MessageEnvelope`]s.
addr: ChannelAddr,
/// The true address of this proc, to be used for direct peer communication.
local_addr: ChannelAddr,
},
/// A proc was stopped.
Stopped {
Expand Down Expand Up @@ -283,15 +286,16 @@ pub(crate) struct AllocatedProc {
pub create_key: ShortUuid,
pub proc_id: ProcId,
pub addr: ChannelAddr,
pub local_addr: ChannelAddr,
pub mesh_agent: ActorRef<ProcMeshAgent>,
}

impl fmt::Display for AllocatedProc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, mesh_agent: {} }}",
self.create_key, self.proc_id, self.addr, self.mesh_agent
"AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, local_addr: {}, mesh_agent: {} }}",
self.create_key, self.proc_id, self.addr, self.local_addr, self.mesh_agent
)
}
}
Expand Down Expand Up @@ -344,6 +348,7 @@ impl<A: ?Sized + Send + Alloc> AllocExt for A {
proc_id,
mesh_agent,
addr,
local_addr,
} => {
let Some(rank) = created.rank(&create_key) else {
tracing::warn!(
Expand All @@ -358,6 +363,7 @@ impl<A: ?Sized + Send + Alloc> AllocExt for A {
create_key,
proc_id: proc_id.clone(),
addr: addr.clone(),
local_addr,
mesh_agent: mesh_agent.clone(),
};
if let Some(old_allocated_proc) = running.insert(*rank, allocated_proc.clone())
Expand Down Expand Up @@ -744,6 +750,7 @@ pub(crate) mod testing {
proc_id,
mesh_agent,
addr,
..
} => {
router.bind(Reference::Proc(proc_id.clone()), addr.clone());

Expand Down
3 changes: 2 additions & 1 deletion hyperactor_mesh/src/alloc/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ impl Alloc for LocalAlloc {
create_key,
proc_id,
mesh_agent: mesh_agent.bind(),
addr,
addr: addr.clone(),
local_addr: addr,
});
break Some(created);
}
Expand Down
3 changes: 2 additions & 1 deletion hyperactor_mesh/src/alloc/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ impl Alloc for ProcessAlloc {
create_key: self.created[index].clone(),
proc_id,
mesh_agent,
addr,
addr: addr.clone(),
local_addr: addr,
});
}
Process2AllocatorMessage::Heartbeat => {
Expand Down
5 changes: 3 additions & 2 deletions hyperactor_mesh/src/alloc/remoteprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,12 @@ impl RemoteProcessAllocator {
tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
let event = match event {
ProcState::Created { .. } => event,
ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
ProcState::Running { create_key, proc_id, mesh_agent, addr, local_addr } => {
// TODO(meriksen, direct addressing): disable remapping in direct addressing mode
tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone(), local_addr }
},
ProcState::Stopped { create_key, reason } => {
match mesh_agents_by_create_key.remove(&create_key) {
Expand Down Expand Up @@ -1272,6 +1272,7 @@ mod test {
create_key,
proc_id,
addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
local_addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
mesh_agent,
})
});
Expand Down
38 changes: 35 additions & 3 deletions hyperactor_mesh/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ impl ProcMesh {

// 6. Configure the mesh agents. This transmits the address book to all agents,
// so that they can resolve and route traffic to all nodes in the mesh.
let address_book: HashMap<_, _> = running
let base_book: HashMap<_, _> = running
.iter()
.map(
|AllocatedProc {
Expand All @@ -374,15 +374,46 @@ impl ProcMesh {
)
.collect();

// Group the local addresses of proxied procs by their proxy address.
// Note that addr != local_addr when the proc is behind a forwarding proxy.
let local_addrs_by_proxy = running.iter().fold(
HashMap::<ChannelAddr, Vec<(ProcId, ChannelAddr)>>::new(),
|mut acc,
AllocatedProc {
addr,
local_addr,
mesh_agent,
..
}| {
if addr != local_addr {
acc.entry(addr.clone())
.or_default()
.push((mesh_agent.actor_id().proc_id().clone(), local_addr.clone()));
}
acc
},
);

let (config_handle, mut config_receiver) = client.open_port();
for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
for (
rank,
AllocatedProc {
mesh_agent, addr, ..
},
) in running.iter().enumerate()
{
let mut address_book = base_book.clone();
// Overwrite addrs with local_addrs for procs that share a forwarding proxy
if let Some(local_addrs) = local_addrs_by_proxy.get(addr) {
address_book.extend(local_addrs.iter().cloned());
}

mesh_agent
.configure(
&client,
rank,
router_channel_addr.clone(),
Some(supervision_port.bind()),
address_book.clone(),
address_book,
config_handle.bind(),
false,
)
Expand Down Expand Up @@ -458,6 +489,7 @@ impl ProcMesh {
proc_id,
addr,
mesh_agent,
..
}| (create_key, proc_id, addr, mesh_agent),
)
.collect(),
Expand Down