Skip to content

Commit a57c16e

Browse files
zdevito and meta-codesync[bot]
authored and committed
python simple bootstrap (#1415)
Summary: Pull Request resolved: #1415 Adds the python implementation for simple bootstrapping. This lets us create host meshes out of any list of workers running the worker loop. ghstack-source-id: 314295600 exported-using-ghexport Reviewed By: colin2328 Differential Revision: D83721433 fbshipit-source-id: 52b4cc0e556e78a1bfe861ec1d3aab0838bebc94
1 parent 2898d76 commit a57c16e

File tree

8 files changed

+452
-4
lines changed

8 files changed

+452
-4
lines changed

hyperactor/src/channel.rs

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,6 @@ impl FromStr for ChannelAddr {
577577
type Err = anyhow::Error;
578578

579579
fn from_str(addr: &str) -> Result<Self, Self::Err> {
580-
// "!" is the legacy delimiter; ":" is preferred
581580
match addr.split_once('!').or_else(|| addr.split_once(':')) {
582581
Some(("local", rest)) => rest
583582
.parse::<u64>()
@@ -596,6 +595,102 @@ impl FromStr for ChannelAddr {
596595
}
597596
}
598597

598+
impl ChannelAddr {
599+
/// Parse ZMQ-style URL format: scheme://address
600+
/// Supports:
601+
/// - tcp://hostname:port or tcp://*:port (wildcard binding)
602+
/// - inproc://endpoint-name (equivalent to local)
603+
/// - ipc://path (equivalent to unix)
604+
/// - metatls://hostname:port or metatls://*:port
605+
pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
606+
// Try ZMQ-style URL format first (scheme://...)
607+
let (scheme, address) = address.split_once("://").ok_or_else(|| {
608+
anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
609+
})?;
610+
611+
match scheme {
612+
"tcp" => {
613+
let (host, port) = Self::split_host_port(address)?;
614+
615+
if host == "*" {
616+
// Wildcard binding - use IPv6 unspecified address
617+
Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
618+
} else {
619+
// Resolve hostname to IP address for proper SocketAddr creation
620+
let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
621+
Ok(Self::Tcp(socket_addr))
622+
}
623+
}
624+
"inproc" => {
625+
// inproc://port -> local:port
626+
// Port must be a valid u64 number
627+
let port = address.parse::<u64>().map_err(|_| {
628+
anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
629+
})?;
630+
Ok(Self::Local(port))
631+
}
632+
"ipc" => {
633+
// ipc://path -> unix:path
634+
Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
635+
}
636+
"metatls" => {
637+
let (host, port) = Self::split_host_port(address)?;
638+
639+
if host == "*" {
640+
// Wildcard binding - use IPv6 unspecified address directly without hostname resolution
641+
Ok(Self::MetaTls(MetaTlsAddr::Host {
642+
hostname: std::net::Ipv6Addr::UNSPECIFIED.to_string(),
643+
port,
644+
}))
645+
} else {
646+
Ok(Self::MetaTls(MetaTlsAddr::Host {
647+
hostname: host.to_string(),
648+
port,
649+
}))
650+
}
651+
}
652+
scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
653+
}
654+
}
655+
656+
/// Split host:port string, supporting IPv6 addresses
657+
fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
658+
if let Some((host, port_str)) = address.rsplit_once(':') {
659+
let port: u16 = port_str
660+
.parse()
661+
.map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
662+
Ok((host, port))
663+
} else {
664+
Err(anyhow::anyhow!("invalid address format: {}", address))
665+
}
666+
}
667+
668+
/// Resolve hostname to SocketAddr, handling both IP addresses and hostnames
669+
fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
670+
// Handle IPv6 addresses in brackets by stripping the brackets
671+
let host_clean = if host.starts_with('[') && host.ends_with(']') {
672+
&host[1..host.len() - 1]
673+
} else {
674+
host
675+
};
676+
677+
// First try to parse as an IP address directly
678+
if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
679+
return Ok(SocketAddr::new(ip_addr, port));
680+
}
681+
682+
// If not an IP, try hostname resolution
683+
use std::net::ToSocketAddrs;
684+
let mut addrs = (host_clean, port)
685+
.to_socket_addrs()
686+
.map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
687+
688+
addrs
689+
.next()
690+
.ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
691+
}
692+
}
693+
599694
/// Universal channel transmitter.
600695
#[derive(Debug)]
601696
pub struct ChannelTx<M: RemoteMessage> {
@@ -832,6 +927,78 @@ mod tests {
832927
}
833928
}
834929

930+
#[test]
fn test_zmq_style_channel_addr() {
    // Shorthand for the expected MetaTls host address.
    let metatls = |hostname: &str, port: u16| {
        ChannelAddr::MetaTls(MetaTlsAddr::Host {
            hostname: hostname.to_string(),
            port,
        })
    };

    // URL -> expected address, for inputs that need no DNS lookup.
    let expected = [
        // Plain TCP with an IPv4 literal.
        (
            "tcp://127.0.0.1:8080",
            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()),
        ),
        // TCP wildcard binds to the IPv6 unspecified address.
        (
            "tcp://*:5555",
            ChannelAddr::Tcp("[::]:5555".parse().unwrap()),
        ),
        // inproc maps to a local channel with a numeric endpoint.
        ("inproc://12345", ChannelAddr::Local(12345)),
        // ipc maps to a unix socket path.
        (
            "ipc:///tmp/my-socket",
            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap()),
        ),
        // metatls keeps a hostname as given.
        ("metatls://example.com:443", metatls("example.com", 443)),
        // metatls keeps an IP literal as its hostname text.
        ("metatls://192.168.1.1:443", metatls("192.168.1.1", 443)),
        // metatls wildcard becomes the IPv6 unspecified address.
        ("metatls://*:8443", metatls("::", 8443)),
    ];
    for (url, want) in &expected {
        assert_eq!(&ChannelAddr::from_zmq_url(url).unwrap(), want);
    }

    // TCP with a hostname goes through resolution. The resolved IP depends on
    // the environment (and may be unavailable without working name
    // resolution), so only success is checked.
    let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
    assert!(tcp_hostname_result.is_ok());

    // Bracketed IPv6 literal.
    assert_eq!(
        ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
        ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
    );

    // Malformed inputs must be rejected.
    for bad_url in [
        "invalid://scheme",
        "tcp://invalid-port",
        "metatls://no-port",
        "inproc://not-a-number",
    ]
    .iter()
    {
        assert!(ChannelAddr::from_zmq_url(bad_url).is_err());
    }
}
1001+
8351002
#[tokio::test]
8361003
async fn test_multiple_connections() {
8371004
for addr in ChannelTransport::all().map(ChannelAddr::any) {

monarch_hyperactor/src/bootstrap.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use futures::future::try_join_all;
10+
use hyperactor::channel::ChannelAddr;
11+
use hyperactor_mesh::Bootstrap;
12+
use hyperactor_mesh::bootstrap::BootstrapCommand;
913
use hyperactor_mesh::bootstrap_or_die;
14+
use hyperactor_mesh::v1::HostMeshRef;
15+
use hyperactor_mesh::v1::Name;
16+
use hyperactor_mesh::v1::host_mesh::HostMesh;
17+
use monarch_types::MapPyErr;
1018
use pyo3::Bound;
1119
use pyo3::PyAny;
1220
use pyo3::PyResult;
@@ -17,6 +25,9 @@ use pyo3::types::PyModule;
1725
use pyo3::types::PyModuleMethods;
1826
use pyo3::wrap_pyfunction;
1927

28+
use crate::pytokio::PyPythonTask;
29+
use crate::v1::host_mesh::PyHostMesh;
30+
2031
#[pyfunction]
2132
#[pyo3(signature = ())]
2233
pub fn bootstrap_main(py: Python) -> PyResult<Bound<PyAny>> {
@@ -36,6 +47,93 @@ pub fn bootstrap_main(py: Python) -> PyResult<Bound<PyAny>> {
3647
})
3748
}
3849

50+
#[pyfunction]
51+
pub fn run_worker_loop_forever(py: Python<'_>, address: &str) -> PyResult<PyPythonTask> {
52+
let addr = ChannelAddr::from_zmq_url(address)?;
53+
54+
// Check if we're running in a PAR/XAR build by looking for FB_XAR_INVOKED_NAME environment variable
55+
let invoked_name = std::env::var("FB_XAR_INVOKED_NAME");
56+
57+
let mut env: std::collections::HashMap<String, String> = std::env::vars().collect();
58+
59+
let command = Some(if let Ok(invoked_name) = invoked_name {
60+
// For PAR/XAR builds: use argv[0] from Python's sys.argv as the current executable
61+
let current_exe = std::path::PathBuf::from(&invoked_name);
62+
63+
// For PAR/XAR builds: set PAR_MAIN_OVERRIDE and no additional args
64+
env.insert(
65+
"PAR_MAIN_OVERRIDE".to_string(),
66+
"monarch._src.actor.bootstrap_main".to_string(),
67+
);
68+
BootstrapCommand {
69+
program: current_exe,
70+
arg0: Some(invoked_name),
71+
args: vec![],
72+
env,
73+
}
74+
} else {
75+
// For regular Python builds: use current_exe() and -m arguments
76+
let current_exe = std::env::current_exe().map_err(|e| {
77+
pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
78+
"Failed to get current executable: {}",
79+
e
80+
))
81+
})?;
82+
let current_exe_str = current_exe.to_string_lossy().to_string();
83+
BootstrapCommand {
84+
program: current_exe,
85+
arg0: Some(current_exe_str),
86+
args: vec![
87+
"-m".to_string(),
88+
"monarch._src.actor.bootstrap_main".to_string(),
89+
],
90+
env,
91+
}
92+
});
93+
94+
let boot = Bootstrap::Host {
95+
addr,
96+
config: None,
97+
command,
98+
};
99+
100+
PyPythonTask::new(async {
101+
let err = boot.bootstrap().await;
102+
Err(err).map_pyerr()?;
103+
Ok(())
104+
})
105+
}
106+
107+
#[pyfunction]
108+
pub fn attach_to_workers<'py>(
109+
workers: Vec<Bound<'py, PyPythonTask>>,
110+
name: Option<&str>,
111+
) -> PyResult<PyPythonTask> {
112+
let tasks = workers
113+
.into_iter()
114+
.map(|x| x.borrow_mut().take_task())
115+
.collect::<PyResult<Vec<_>>>()?;
116+
117+
let name = Name::new(name.unwrap_or("hosts"));
118+
PyPythonTask::new(async move {
119+
let results = try_join_all(tasks).await?;
120+
121+
let addresses: Result<Vec<ChannelAddr>, anyhow::Error> = Python::with_gil(|py| {
122+
results
123+
.into_iter()
124+
.map(|result| {
125+
let url_str: String = result.bind(py).extract()?;
126+
ChannelAddr::from_zmq_url(&url_str)
127+
})
128+
.collect()
129+
});
130+
let addresses = addresses?;
131+
132+
let host_mesh = HostMesh::take(name, HostMeshRef::from_hosts(addresses));
133+
Ok(PyHostMesh::new_owned(host_mesh))
134+
})
135+
}
136+
39137
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
40138
let f = wrap_pyfunction!(bootstrap_main, hyperactor_mod)?;
41139
f.setattr(
@@ -44,5 +142,19 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
44142
)?;
45143
hyperactor_mod.add_function(f)?;
46144

145+
let f = wrap_pyfunction!(run_worker_loop_forever, hyperactor_mod)?;
146+
f.setattr(
147+
"__module__",
148+
"monarch._rust_bindings.monarch_hyperactor.bootstrap",
149+
)?;
150+
hyperactor_mod.add_function(f)?;
151+
152+
let f = wrap_pyfunction!(attach_to_workers, hyperactor_mod)?;
153+
f.setattr(
154+
"__module__",
155+
"monarch._rust_bindings.monarch_hyperactor.bootstrap",
156+
)?;
157+
hyperactor_mod.add_function(f)?;
158+
47159
Ok(())
48160
}

monarch_hyperactor/src/v1/host_mesh.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,13 @@ impl PyBootstrapCommand {
8787
name = "HostMesh",
8888
module = "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh"
8989
)]
90-
enum PyHostMesh {
90+
pub(crate) enum PyHostMesh {
9191
Owned(PyHostMeshImpl),
9292
Ref(PyHostMeshRefImpl),
9393
}
9494

9595
impl PyHostMesh {
96-
fn new_owned(inner: HostMesh) -> Self {
96+
pub(crate) fn new_owned(inner: HostMesh) -> Self {
9797
Self::Owned(PyHostMeshImpl(SharedCell::from(inner)))
9898
}
9999

monarch_types/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub trait MapPyErr<T> {
4545
}
4646
impl<T, E> MapPyErr<T> for Result<T, E>
4747
where
48-
E: Error,
48+
E: ToString,
4949
{
5050
fn map_pyerr(self) -> Result<T, PyErr> {
5151
self.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))

python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,17 @@
66

77
# pyre-strict
88

9+
from pathlib import Path
10+
from typing import List, Literal, Optional, Union
11+
12+
PrivateKey = Union[bytes, Path, None]
13+
CA = Union[bytes, Path, Literal["trust_all_connections"]]
14+
15+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
16+
from monarch._rust_bindings.monarch_hyperactor.v1.host_mesh import HostMesh
17+
918
def bootstrap_main() -> None: ...
19+
# `address` is a ZMQ-style URL (scheme://endpoint, e.g. "tcp://*:5555"); the
# returned task runs the host bootstrap for that address.
def run_worker_loop_forever(address: str) -> PythonTask[None]: ...
20+
def attach_to_workers(
    workers: List[PythonTask[str]],
    name: Optional[str] = None,
) -> PythonTask[HostMesh]:
    """
    Build a task that creates a HostMesh from running workers.

    Args:
        workers: Tasks that each produce a worker's address as a ZMQ-style
            URL string (scheme://endpoint).
        name: Name for the host mesh; "hosts" is used when omitted.

    Returns:
        A task that resolves to the HostMesh once every worker task has
        produced its address.
    """
    ...

0 commit comments

Comments
 (0)