diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index d9126424d..7db67b3b6 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -134,9 +134,12 @@ pub enum ProcState { /// Reference to this proc's mesh agent. In the future, we'll reserve a /// 'well known' PID (0) for this purpose. mesh_agent: ActorRef, - /// The address of this proc. The endpoint of this address is + /// The address of this proc which may be the true address or the address of a + /// forwarding proxy. The endpoint of this address is /// the proc's mailbox, which accepts [`hyperactor::mailbox::MessageEnvelope`]s. addr: ChannelAddr, + /// The true address of this proc to be used for direct peer communication + local_addr: ChannelAddr, }, /// A proc was stopped. Stopped { @@ -283,6 +286,7 @@ pub(crate) struct AllocatedProc { pub create_key: ShortUuid, pub proc_id: ProcId, pub addr: ChannelAddr, + pub local_addr: ChannelAddr, pub mesh_agent: ActorRef, } @@ -290,8 +294,8 @@ impl fmt::Display for AllocatedProc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, mesh_agent: {} }}", - self.create_key, self.proc_id, self.addr, self.mesh_agent + "AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, local_addr: {}, mesh_agent: {} }}", + self.create_key, self.proc_id, self.addr, self.local_addr, self.mesh_agent ) } } @@ -344,6 +348,7 @@ impl AllocExt for A { proc_id, mesh_agent, addr, + local_addr, } => { let Some(rank) = created.rank(&create_key) else { tracing::warn!( @@ -358,6 +363,7 @@ impl AllocExt for A { create_key, proc_id: proc_id.clone(), addr: addr.clone(), + local_addr, mesh_agent: mesh_agent.clone(), }; if let Some(old_allocated_proc) = running.insert(*rank, allocated_proc.clone()) @@ -744,6 +750,7 @@ pub(crate) mod testing { proc_id, mesh_agent, addr, + .. 
} => { router.bind(Reference::Proc(proc_id.clone()), addr.clone()); diff --git a/hyperactor_mesh/src/alloc/local.rs b/hyperactor_mesh/src/alloc/local.rs index 039442cae..3e3a267c2 100644 --- a/hyperactor_mesh/src/alloc/local.rs +++ b/hyperactor_mesh/src/alloc/local.rs @@ -215,7 +215,8 @@ impl Alloc for LocalAlloc { create_key, proc_id, mesh_agent: mesh_agent.bind(), - addr, + addr: addr.clone(), + local_addr: addr, }); break Some(created); } diff --git a/hyperactor_mesh/src/alloc/process.rs b/hyperactor_mesh/src/alloc/process.rs index 0a2bf6cde..54b014f0d 100644 --- a/hyperactor_mesh/src/alloc/process.rs +++ b/hyperactor_mesh/src/alloc/process.rs @@ -521,7 +521,8 @@ impl Alloc for ProcessAlloc { create_key: self.created[index].clone(), proc_id, mesh_agent, - addr, + addr: addr.clone(), + local_addr: addr, }); } Process2AllocatorMessage::Heartbeat => { diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index 621ed3de5..cbc4c6cc8 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -431,12 +431,12 @@ impl RemoteProcessAllocator { tracing::debug!(name = event.as_ref(), "got event: {:?}", event); let event = match event { ProcState::Created { .. 
} => event, - ProcState::Running { create_key, proc_id, mesh_agent, addr } => { + ProcState::Running { create_key, proc_id, mesh_agent, addr, local_addr } => { // TODO(meriksen, direct addressing): disable remapping in direct addressing mode tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr); mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone()); router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr); - ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() } + ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone(), local_addr } }, ProcState::Stopped { create_key, reason } => { match mesh_agents_by_create_key.remove(&create_key) { @@ -1272,6 +1272,7 @@ mod test { create_key, proc_id, addr: ChannelAddr::Unix("/proc0".parse().unwrap()), + local_addr: ChannelAddr::Unix("/proc0".parse().unwrap()), mesh_agent, }) }); diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 72c688219..84bfa795f 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -365,7 +365,7 @@ impl ProcMesh { // 6. Configure the mesh agents. This transmits the address book to all agents, // so that they can resolve and route traffic to all nodes in the mesh. - let address_book: HashMap<_, _> = running + let base_book: HashMap<_, _> = running .iter() .map( |AllocatedProc { @@ -374,15 +374,46 @@ impl ProcMesh { ) .collect(); + // Here addr != local_addr when the proc is behind a forwarding proxy + let local_addrs_by_proxy = running.iter().fold( + HashMap::<ChannelAddr, Vec<(ProcId, ChannelAddr)>>::new(), + |mut acc, + AllocatedProc { + addr, + local_addr, + mesh_agent, + .. + }| { + if addr != local_addr { + acc.entry(addr.clone()) + .or_default() + .push((mesh_agent.actor_id().proc_id().clone(), local_addr.clone())); + } + acc + }, + ); + let (config_handle, mut config_receiver) = client.open_port(); - for (rank, AllocatedProc { mesh_agent, .. 
}) in running.iter().enumerate() { + for ( + rank, + AllocatedProc { + mesh_agent, addr, .. + }, + ) in running.iter().enumerate() + { + let mut address_book = base_book.clone(); + // Overwrite addrs with local_addrs for procs that share a forwarding proxy + if let Some(local_addrs) = local_addrs_by_proxy.get(addr) { + address_book.extend(local_addrs.iter().cloned()); + } + mesh_agent .configure( &client, rank, router_channel_addr.clone(), Some(supervision_port.bind()), - address_book.clone(), + address_book, config_handle.bind(), false, ) @@ -458,6 +489,7 @@ impl ProcMesh { proc_id, addr, mesh_agent, + .. }| (create_key, proc_id, addr, mesh_agent), ) .collect(),