Skip to content

Commit 7d6d542

Browse files
committed
SSHJob/LoginJob
Add a simple SSHJob variant that lets you establish a host mesh by ssh-ing directly into machines. This is probably too simple to use in practice, but it demonstrates what is necessary to get a monarch job running. Differential Revision: [D84016804](https://our.internmc.facebook.com/intern/diff/D84016804/) ghstack-source-id: 314612883 Pull Request resolved: #1451
1 parent 8a45901 commit 7d6d542

File tree

3 files changed

+239
-10
lines changed

3 files changed

+239
-10
lines changed

python/monarch/_src/actor/bootstrap.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def run_worker_loop_forever(
7373
raise NotImplementedError("TLS security plumbing")
7474
# we maybe want to actually return the future and let you do other stuff,
7575
# not sure ...
76+
if "tcp://*" in address:
77+
raise NotImplementedError(
78+
"implementation does not get the host name right if it was specified as a wild card. We have to fix this"
79+
)
80+
7681
_run_worker_loop_forever(address).block_on()
7782

7883

@@ -104,6 +109,7 @@ def attach_to_workers(
104109

105110
if private_key is not None or ca != "trust_all_connections":
106111
raise NotImplementedError("TLS security plumbing")
112+
107113
workers_tasks = [_as_python_task(w) for w in workers]
108114
host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name)
109115
extent = Extent(["hosts"], [len(workers)])

python/monarch/_src/job/job.py

Lines changed: 191 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,26 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
import os
89
import pickle
10+
import shlex
11+
import signal
912
import subprocess
1013
import sys
1114
import tempfile
1215
from abc import ABC, abstractmethod
13-
from typing import Dict, Literal, NamedTuple, Optional, Sequence
16+
from typing import cast, Dict, List, Literal, NamedTuple, Optional, Sequence
17+
18+
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
19+
from monarch._rust_bindings.monarch_hyperactor.config import configure
20+
21+
from monarch._src.actor.bootstrap import attach_to_workers
1422

1523
# note: the jobs api is intended as a library so it should
1624
# only be importing _public_ monarch API functions.
1725
from monarch._src.actor.host_mesh import HostMesh, this_host
26+
1827
from typing_extensions import Self
1928

2029

@@ -39,6 +48,12 @@ class CachedRunning(NamedTuple):
3948
job: "JobTrait"
4049

4150

51+
# Module-level logger for the jobs API. Records at INFO and above are written
# to stderr by a dedicated handler, and propagation is switched off so they
# are not also handed to ancestor loggers.
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.setLevel(logging.INFO)
55+
56+
4257
class JobTrait(ABC):
4358
def __init__(self):
4459
super().__init__()
@@ -102,6 +117,10 @@ def apply(self, client_script: Optional[str] = None):
102117
self._create(client_script)
103118
self._status = "running"
104119

120+
@property
def active(self) -> bool:
    """True when `self._running` is not None, i.e. a running job is available."""
    running = self._running
    return running is not None
123+
105124
def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobState:
106125
"""
107126
Get the current state of this job, containing the host mesh objects of its requires that were requested
@@ -124,30 +143,44 @@ def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobSta
124143
# calls to attach_to_workers and return the HostMeshes
125144
running_job = self._running
126145
if running_job is not None:
146+
logger.info("Job is running, returning current state")
127147
return running_job._state()
128148

129149
cached = self._load_cached(cached_path)
130150
if cached is not None:
131151
self._status = CachedRunning(cached)
152+
logger.info("Connecting to cached job")
132153
return cached._state()
154+
logger.info("Applying current job")
133155
self.apply()
134156
if cached_path is not None:
135157
# Create the directory for cached_path if it doesn't exist
136158
cache_dir = os.path.dirname(cached_path)
137159
if cache_dir: # Only create if there's a directory component
138160
os.makedirs(cache_dir, exist_ok=True)
161+
logger.info("Saving job to cache at %s", cached_path)
139162
self.dump(cached_path)
163+
logger.info("Job has started, connecting to current state")
140164
return self._state()
141165

142166
def _load_cached(self, cached_path: Optional[str]) -> "Optional[JobTrait]":
143167
if cached_path is None:
168+
logger.info("No cached path provided")
144169
return None
145170
try:
146171
job = job_load(cached_path)
172+
logger.info("Found cached job at path: %s", cached_path)
147173
except FileNotFoundError:
174+
logger.info("No cached job found at path: %s", cached_path)
148175
return None
149176
running = job._running
150-
if running is None or not running.can_run(self):
177+
if running is None:
178+
logger.info("Cached job is not running")
179+
return None
180+
if not running.can_run(self):
181+
logger.info("Cached job cannot run this spec, removing cache")
182+
running._kill()
183+
os.remove(cached_path)
151184
return None
152185
return job
153186

@@ -164,6 +197,12 @@ def dumps(self) -> bytes:
164197
# @lint-ignore PYTHONPICKLEISBAD
165198
return pickle.dumps(self)
166199

200+
def kill(self):
    """Kill the running job, if there is one, and mark this job as not running."""
    target = self._running
    if target is not None:
        target._kill()
    self._status = "not_running"
205+
167206
@abstractmethod
168207
def _state(self) -> JobState: ...
169208

@@ -181,11 +220,6 @@ def can_run(self, spec: "JobTrait") -> bool:
181220

182221
...
183222

184-
def kill(self):
185-
running = self._running
186-
if running is not None:
187-
running._kill()
188-
189223
@abstractmethod
190224
def _kill(self):
191225
"""
@@ -244,8 +278,10 @@ def _create(self, client_script: Optional[str]):
244278
log_dir = self._setup_log_directory()
245279
self._run_client_as_daemon(client_script, log_dir)
246280

247-
print(f"Started client script {client_script} with PID: {self.process.pid}")
248-
print(f"Logs available at: {log_dir}")
281+
logger.info(
282+
"Started client script %s with PID: %d", client_script, self.process.pid
283+
)
284+
logger.info("Logs available at: %s", log_dir)
249285

250286
def _setup_log_directory(self) -> str:
251287
"""Create a log directory for the batch job."""
@@ -323,5 +359,150 @@ def _create(self, client_script: Optional[str] = None):
323359
return self._job._create(client_script)
324360

325361
def _kill(self):
    """Log that the batch job is stopping, then delegate to the wrapped job's `_kill`."""
    logger.info("Stopping Batch Job")
    wrapped = self._job
    return wrapped._kill()
364+
365+
366+
class ProcessState(NamedTuple):
    """Record kept for each started host: the launched process and its channel."""

    # PID of the local process that was spawned to start the host.
    pid: int
    # Channel address string later handed to `attach_to_workers` for this host.
    channel: str
369+
370+
371+
class LoginJob(JobTrait):
    """
    Makes connections directly to hosts via an explicit list.

    Meshes are registered with `add_mesh`. Subclasses implement `_start_host`
    to launch the worker process for one host and report its pid and channel.
    """

    def __init__(self):
        super().__init__()
        # mesh name -> names of the hosts that make up that mesh
        self._meshes: Dict[str, List[str]] = {}
        # host name -> pid/channel of the process started for that host
        self._host_to_pid: Dict[str, ProcessState] = {}

    def add_mesh(self, name: str, hosts: List[str]):
        """Register the mesh `name` as `hosts`, replacing any mesh of that name."""
        self._meshes[name] = hosts

    def _state(self) -> JobState:
        """
        Attach to the started workers and build one host mesh per mesh name.

        Raises:
            RuntimeError: if `_pids_active` reports that the job is not active.
        """
        if not self._pids_active():
            raise RuntimeError("lost connection")
        hosts = {
            name: cast(
                "HostMesh",
                attach_to_workers(
                    name=name,
                    ca="trust_all_connections",
                    workers=[self._host_to_pid[v].channel for v in values],
                ),
            )
            for name, values in self._meshes.items()
        }
        return JobState(hosts)

    def _create(self, client_script: Optional[str]):
        """
        Start one worker process for every distinct host in the meshes.

        Raises:
            RuntimeError: if `client_script` is given; batch mode is unsupported.
        """
        if client_script is not None:
            raise RuntimeError("LoginJob cannot run batch-mode scripts")

        # A host may be listed twice, or shared between meshes. Start it only
        # once: a second start would overwrite the first ProcessState, leaving
        # a process that `_kill` can no longer reach.
        started = set()
        for hosts in self._meshes.values():
            for host in hosts:
                if host in started:
                    continue
                started.add(host)
                self._host_to_pid[host] = self._start_host(host)

    @abstractmethod
    def _start_host(self, host: str) -> ProcessState: ...

    def can_run(self, spec: "JobTrait") -> bool:
        """
        Is this job capable of running the job spec? This is used to check if a
        cached job can be used to run `spec` instead of creating a new reservation.

        It is also used by the batch run infrastructure to indicate that the batch job can certainly run itself.
        """
        return (
            isinstance(spec, LoginJob)
            and spec._meshes == self._meshes
            and self._pids_active()
        )

    def _pids_active(self) -> bool:
        """Return True only if the job is active and every recorded pid can be signalled."""
        if not self.active:
            return False
        for p in self._host_to_pid.values():
            try:
                # Check if process exists by sending signal 0
                os.kill(p.pid, 0)
            except OSError:
                # Process doesn't exist or we don't have permission to signal it
                return False
        return True

    def _kill(self):
        """Send SIGKILL to every recorded process, ignoring ones that cannot be signalled."""
        for p in self._host_to_pid.values():
            try:
                os.kill(p.pid, signal.SIGKILL)
            except OSError:
                # Best effort: the process may already be gone.
                pass
442+
443+
444+
class FakeLocalLoginJob(LoginJob):
    """
    Fake it that we are logging in by just making a local process that runs the bootstrap.

    Every "host" becomes a local worker process listening on its own loopback
    TCP port. Ports are handed out sequentially, beginning at `start_port`.
    """

    def __init__(self, start_port: int = 12345):
        """
        Args:
            start_port: port given to the first started host; each later host
                gets the next port. Pass a different value to avoid clashing
                with another job on the same machine.
        """
        super().__init__()
        configure(default_transport=ChannelTransport.Tcp)

        self._next_port = start_port

    def _start_host(self, host: str) -> ProcessState:
        """Spawn a detached local worker for `host` and return its pid and address."""
        port = self._next_port
        self._next_port += 1

        env = {**os.environ}
        if "FB_XAR_INVOKED_NAME" in os.environ:
            env["PYTHONPATH"] = os.pathsep.join(sys.path)
        # The worker binds to the same address that clients later attach to.
        addr = f"tcp://[::1]:{port}"
        startup = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")'
        proc = subprocess.Popen(
            [sys.executable, "-c", startup],
            env=env,
            start_new_session=True,
        )
        return ProcessState(proc.pid, addr)
475+
476+
477+
class SSHJob(LoginJob):
    """
    LoginJob that starts the monarch worker loop on each host by running
    `ssh <host> <python_exe> -c ...` from the local machine.
    """

    def __init__(
        self,
        python_exe: str = "python",
        ssh_args: Sequence[str] = (),
        monarch_port: int = 22222,
    ):
        """
        Args:
            python_exe: python executable to invoke on the remote host.
            ssh_args: extra arguments placed between `ssh` and the host.
            monarch_port: TCP port the worker loop is told to use on each host.
        """
        configure(default_transport=ChannelTransport.Tcp)
        self._python_exe = python_exe
        # Store as a tuple so that equal options compare equal in `can_run`
        # whether the caller passed a list or a tuple.
        self._ssh_args = tuple(ssh_args)
        self._port = monarch_port
        super().__init__()

    def _start_host(self, host: str) -> ProcessState:
        """Launch the worker loop on `host` over ssh; the pid is the local ssh client's."""
        addr = f"tcp://{host}:{self._port}"
        startup = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")'

        # The remote side runs this through a shell, so quote both pieces.
        command = f"{shlex.quote(self._python_exe)} -c {shlex.quote(startup)}"
        proc = subprocess.Popen(
            # Options go before the destination, as in the ssh(1) synopsis.
            ["ssh", *self._ssh_args, "-n", host, command],
            start_new_session=True,
        )
        return ProcessState(proc.pid, addr)

    def can_run(self, spec: "JobTrait") -> bool:
        """
        Return True when `spec` is an SSHJob with the same python executable,
        port and ssh arguments, and the LoginJob check also passes.
        """
        return (
            isinstance(spec, SSHJob)
            and spec._python_exe == self._python_exe
            and self._port == spec._port
            # tuple(...) on both sides also covers jobs unpickled from an
            # older cache, where the arguments may still be stored as a list.
            and tuple(self._ssh_args) == tuple(spec._ssh_args)
            and super().can_run(spec)
        )

python/tests/test_python_actors.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
this_proc as this_proc_v1,
6363
)
6464
from monarch._src.actor.v1.proc_mesh import ProcMesh as ProcMeshV1
65+
from monarch._src.job.job import JobState, LoginJob, ProcessState
6566

6667
from monarch.actor import (
6768
Accumulator,
@@ -1721,3 +1722,44 @@ def test_simple_bootstrap():
17211722
for proc in procs:
17221723
proc.kill()
17231724
proc.wait()
1725+
1726+
1727+
class FakeLocalLoginJob(LoginJob):
    """
    Test stand-in for a real login: each "host" is a local subprocess that
    runs the worker bootstrap on an ipc channel under the given directory.
    """

    def __init__(self, dir: str):
        super().__init__()
        # Directory used as the prefix for each host's ipc channel address.
        self._dir = dir

    def _start_host(self, host: str) -> ProcessState:
        """Spawn a detached local worker for `host` and return its pid and channel."""
        channel = f"ipc://{self._dir}/{host}"
        bootstrap = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(channel)}, ca="trust_all_connections")'

        child_env = dict(os.environ)
        if "FB_XAR_INVOKED_NAME" in os.environ:
            child_env["PYTHONPATH"] = ":".join(sys.path)

        worker = subprocess.Popen(
            [sys.executable, "-c", bootstrap],
            env=child_env,
            start_new_session=True,
        )
        return ProcessState(worker.pid, channel)
1752+
1753+
1754+
def test_login_job():
    """Start two fake hosts through a LoginJob, spawn an actor on them and call it."""
    with TemporaryDirectory() as temp_dir:
        j = FakeLocalLoginJob(temp_dir)
        j.add_mesh("hosts", ["fake", "hosts"])
        try:
            state = j.state(cached_path=None)

            hello = state.hosts.spawn_procs().spawn("hello", Hello)
            r = hello.doit.call().get()
            for _, v in r.items():
                assert v == "hello!"
        finally:
            # The workers run in their own sessions, so they would outlive a
            # failing test unless they are killed here.
            j.kill()

0 commit comments

Comments
 (0)