4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import logging
7
8
import os
8
9
import pickle
10
+ import shlex
11
+ import signal
9
12
import subprocess
10
13
import sys
11
14
import tempfile
12
15
from abc import ABC , abstractmethod
13
- from typing import Dict , Literal , NamedTuple , Optional , Sequence
16
+ from typing import cast , Dict , List , Literal , NamedTuple , Optional , Sequence
17
+
18
+ from monarch ._rust_bindings .monarch_hyperactor .channel import ChannelTransport
19
+ from monarch ._rust_bindings .monarch_hyperactor .config import configure
20
+
21
+ from monarch ._src .actor .bootstrap import attach_to_workers
14
22
15
23
# note: the jobs api is intended as a library so it should
16
24
# only be importing _public_ monarch API functions.
17
25
from monarch ._src .actor .host_mesh import HostMesh , this_host
26
+
18
27
from typing_extensions import Self
19
28
20
29
@@ -39,6 +48,12 @@ class CachedRunning(NamedTuple):
39
48
job : "JobTrait"
40
49
41
50
51
# Module-level logger: emit INFO+ straight to stderr and do not propagate to
# the root logger, so job progress is visible even without app logging config.
logger = logging.getLogger(__name__)
_stderr_handler = logging.StreamHandler(sys.stderr)
logger.setLevel(logging.INFO)
logger.addHandler(_stderr_handler)
logger.propagate = False
55
+
56
+
42
57
class JobTrait (ABC ):
43
58
def __init__ (self ):
44
59
super ().__init__ ()
@@ -102,6 +117,10 @@ def apply(self, client_script: Optional[str] = None):
102
117
self ._create (client_script )
103
118
self ._status = "running"
104
119
120
@property
def active(self) -> bool:
    """Whether this job currently has a running instance attached."""
    current = self._running
    return current is not None
123
+
105
124
def state (self , cached_path : Optional [str ] = ".monarch/job_state.pkl" ) -> JobState :
106
125
"""
107
126
Get the current state of this job, containing the host mesh objects of its requires that were requested
@@ -124,30 +143,44 @@ def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobSta
124
143
# calls to attach_to_workers and return the HostMeshes
125
144
running_job = self ._running
126
145
if running_job is not None :
146
+ logger .info ("Job is running, returning current state" )
127
147
return running_job ._state ()
128
148
129
149
cached = self ._load_cached (cached_path )
130
150
if cached is not None :
131
151
self ._status = CachedRunning (cached )
152
+ logger .info ("Connecting to cached job" )
132
153
return cached ._state ()
154
+ logger .info ("Applying current job" )
133
155
self .apply ()
134
156
if cached_path is not None :
135
157
# Create the directory for cached_path if it doesn't exist
136
158
cache_dir = os .path .dirname (cached_path )
137
159
if cache_dir : # Only create if there's a directory component
138
160
os .makedirs (cache_dir , exist_ok = True )
161
+ logger .info ("Saving job to cache at %s" , cached_path )
139
162
self .dump (cached_path )
163
+ logger .info ("Job has started, connecting to current state" )
140
164
return self ._state ()
141
165
142
166
def _load_cached(self, cached_path: Optional[str]) -> "Optional[JobTrait]":
    """Try to reuse a previously cached job.

    Returns the cached JobTrait when it is still running and able to run this
    spec. Otherwise returns None; a cached job that can no longer run the spec
    is killed and its cache file removed so a fresh job gets created.
    """
    if cached_path is None:
        logger.info("No cached path provided")
        return None
    # Keep the try body minimal: only the load itself can raise FileNotFoundError.
    try:
        job = job_load(cached_path)
    except FileNotFoundError:
        logger.info("No cached job found at path: %s", cached_path)
        return None
    logger.info("Found cached job at path: %s", cached_path)
    running = job._running
    if running is None:
        logger.info("Cached job is not running")
        return None
    if not running.can_run(self):
        logger.info("Cached job cannot run this spec, removing cache")
        # Stale reservation: tear it down and drop the cache file so the
        # caller re-applies this spec from scratch.
        running._kill()
        os.remove(cached_path)
        return None
    return job
153
186
@@ -164,6 +197,12 @@ def dumps(self) -> bytes:
164
197
# @lint-ignore PYTHONPICKLEISBAD
165
198
return pickle .dumps (self )
166
199
200
def kill(self):
    """Tear down the running job instance (if any) and mark this trait as not running."""
    current = self._running
    if current is not None:
        current._kill()
    self._status = "not_running"
205
+
167
206
@abstractmethod
168
207
def _state (self ) -> JobState : ...
169
208
@@ -181,11 +220,6 @@ def can_run(self, spec: "JobTrait") -> bool:
181
220
182
221
...
183
222
184
- def kill (self ):
185
- running = self ._running
186
- if running is not None :
187
- running ._kill ()
188
-
189
223
@abstractmethod
190
224
def _kill (self ):
191
225
"""
@@ -244,8 +278,10 @@ def _create(self, client_script: Optional[str]):
244
278
log_dir = self ._setup_log_directory ()
245
279
self ._run_client_as_daemon (client_script , log_dir )
246
280
247
- print (f"Started client script { client_script } with PID: { self .process .pid } " )
248
- print (f"Logs available at: { log_dir } " )
281
+ logger .info (
282
+ "Started client script %s with PID: %d" , client_script , self .process .pid
283
+ )
284
+ logger .info ("Logs available at: %s" , log_dir )
249
285
250
286
def _setup_log_directory (self ) -> str :
251
287
"""Create a log directory for the batch job."""
@@ -323,5 +359,150 @@ def _create(self, client_script: Optional[str] = None):
323
359
return self ._job ._create (client_script )
324
360
325
361
def _kill(self):
    # Log for interactive users, then delegate teardown to the wrapped job.
    logger.info("Stopping Batch Job")
    return self._job._kill()
364
+
365
+
366
class ProcessState(NamedTuple):
    """Record of one started worker bootstrap process."""

    # OS process id of the launcher process (used for liveness checks and SIGKILL).
    pid: int
    # Channel address (e.g. "tcp://host:port") that attach_to_workers connects to.
    channel: str
369
+
370
+
371
class LoginJob(JobTrait):
    """
    Makes a connections directly to hosts via an explicit list.

    Subclasses implement `_start_host` to launch a monarch worker loop on a
    single host; this base class tracks the resulting processes and builds
    host meshes by attaching to their channels.
    """

    def __init__(self):
        super().__init__()
        # mesh name -> ordered list of host identifiers
        self._meshes: Dict[str, List[str]] = {}
        # host identifier -> ProcessState of its bootstrap process
        self._host_to_pid: Dict[str, ProcessState] = {}

    def add_mesh(self, name: str, hosts: List[str]):
        """Register a named mesh consisting of the given hosts."""
        self._meshes[name] = hosts

    def _state(self) -> JobState:
        """Attach to every started worker and return the resulting JobState.

        Raises:
            RuntimeError: if any previously started worker process has died.
        """
        if not self._pids_active():
            raise RuntimeError("lost connection")
        hosts = {
            name: cast(
                "HostMesh",
                attach_to_workers(
                    name=name,
                    ca="trust_all_connections",
                    workers=[self._host_to_pid[v].channel for v in values],
                ),
            )
            for name, values in self._meshes.items()
        }
        return JobState(hosts)

    def _create(self, client_script: Optional[str]):
        """Start a worker process on every host of every registered mesh."""
        if client_script is not None:
            raise RuntimeError("LoginJob cannot run batch-mode scripts")

        for hosts in self._meshes.values():
            for host in hosts:
                self._host_to_pid[host] = self._start_host(host)

    @abstractmethod
    def _start_host(self, host: str) -> ProcessState: ...

    def can_run(self, spec: "JobTrait") -> bool:
        """
        Is this job capable of running the job spec? This is used to check if a
        cached job can be used to run `spec` instead of creating a new reserveration.

        It is also used by the batch run infrastructure to indicate that the batch job can certainly run itself.
        """
        return (
            isinstance(spec, LoginJob)
            and spec._meshes == self._meshes
            and self._pids_active()
        )

    def _pids_active(self) -> bool:
        """Return True if the job is active and every worker process still exists."""
        if not self.active:
            return False
        # Fixed: iterate .values() directly — the previous .items() loop discarded the key.
        for p in self._host_to_pid.values():
            try:
                # Check if process exists by sending signal 0 (error checking
                # only; no signal is actually delivered).
                os.kill(p.pid, 0)
            except OSError:
                # Process doesn't exist or we don't have permission to signal it
                return False
        return True

    def _kill(self):
        """Best-effort SIGKILL of every worker process we started."""
        for p in self._host_to_pid.values():
            try:
                os.kill(p.pid, signal.SIGKILL)
            except OSError:
                # Already gone (or not signalable) — nothing more to do.
                pass
442
+
443
+
444
class FakeLocalLoginJob(LoginJob):
    """
    Fake it that we are logging in by just making a local process that runs the bootstrap.

    Each "host" is simply a subprocess on this machine running the monarch
    worker loop on its own local TCP port.
    """

    def __init__(self, first_port: int = 12345):
        """
        Args:
            first_port: TCP port assigned to the first fake host; each
                subsequent host gets the next consecutive port. Defaults to
                12345 (the previously hard-coded value).
        """
        super().__init__()
        configure(default_transport=ChannelTransport.Tcp)

        self._next_port = first_port

    def _start_host(self, host: str) -> ProcessState:
        # Hand out one port per fake host.
        port = self._next_port
        self._next_port += 1

        env = {**os.environ}
        # When running from an XAR (FB_XAR_INVOKED_NAME set), forward our
        # sys.path — presumably so the child can import monarch; TODO confirm.
        if "FB_XAR_INVOKED_NAME" in os.environ:
            env["PYTHONPATH"] = ":".join(sys.path)
        addr = f"tcp://[::1]:{port}"
        bind_addr = f"tcp://[::1]:{port}"
        proc = subprocess.Popen(
            [
                sys.executable,
                "-c",
                f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(bind_addr)}, ca="trust_all_connections")',
            ],
            env=env,
            # Detach into its own session so it outlives/ignores our signals.
            start_new_session=True,
        )
        return ProcessState(proc.pid, addr)
475
+
476
+
477
class SSHJob(LoginJob):
    """Login job that sshes to each host and starts the monarch worker loop there."""

    def __init__(
        self,
        python_exe: str = "python",
        ssh_args: Sequence[str] = (),
        monarch_port: int = 22222,
    ):
        """
        Args:
            python_exe: interpreter invoked on the remote host.
            ssh_args: extra arguments passed to the ssh client.
            monarch_port: TCP port the remote worker loop listens on.
        """
        configure(default_transport=ChannelTransport.Tcp)
        self._python_exe = python_exe
        # Normalize to a tuple so can_run's equality check is content-based:
        # previously ["-p"] and ("-p",) compared unequal, defeating cache reuse.
        self._ssh_args = tuple(ssh_args)
        self._port = monarch_port
        super().__init__()

    def _start_host(self, host: str) -> ProcessState:
        addr = f"tcp://{host}:{self._port}"
        startup = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")'

        # shlex.quote keeps the -c payload intact through the remote shell.
        command = f"{shlex.quote(self._python_exe)} -c {shlex.quote(startup)}"
        proc = subprocess.Popen(
            ["ssh", *self._ssh_args, host, "-n", command],
            start_new_session=True,
        )
        return ProcessState(proc.pid, addr)

    def can_run(self, spec: "JobTrait") -> bool:
        """A cached SSHJob can run `spec` only if its launch config matches exactly."""
        return (
            isinstance(spec, SSHJob)
            and spec._python_exe == self._python_exe
            and self._port == spec._port
            # Compare as tuples so list-vs-tuple callers still match.
            and tuple(spec._ssh_args) == self._ssh_args
            and super().can_run(spec)
        )
0 commit comments