Commit ac4cf80

amirafzali authored and facebook-github-bot committed
Add SlurmJob (meta-pytorch#1470)
Summary: Add a SlurmJob variation of JobTrait to launch and manage meshes through the OSS SLURM scheduler.
Reviewed By: zdevito
Differential Revision: D84191209
1 parent ef62460 commit ac4cf80

2 files changed: +342, -2 lines

python/monarch/_src/job/slurm.py

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
import logging
import os
import subprocess
import sys
from typing import Any, cast, Dict, FrozenSet, List, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure

from monarch._src.actor.bootstrap import attach_to_workers
from monarch._src.actor.host_mesh import HostMesh
from monarch._src.job.job import JobState, JobTrait


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.propagate = False

# terminal states that indicate the job is no longer active
_SLURM_TERMINAL_STATES: FrozenSet[str] = frozenset(
    ["FAILED", "CANCELLED", "TIMEOUT", "PREEMPTED", "COMPLETED"]
)


class SlurmJob(JobTrait):
    """
    A job scheduler that uses SLURM command line tools to schedule jobs.

    This implementation:
    1. Uses sbatch to submit SLURM jobs that start monarch workers
    2. Queries job status with squeue to get allocated hostnames
    3. Uses the hostnames to connect to the started workers
    """

    def __init__(
        self,
        meshes: Dict[str, int],
        python_exe: str = "python",
        slurm_args: Sequence[str] = (),
        monarch_port: int = 22222,
        job_name: str = "monarch_job",
        ntasks_per_node: int = 1,
        time_limit: Optional[str] = "12:00:00",
        partition: Optional[str] = None,
        log_dir: Optional[str] = None,
        exclusive: bool = True,
    ) -> None:
        """
        Args:
            meshes: Dictionary mapping mesh names to number of nodes
            python_exe: Python executable to use for worker processes
            slurm_args: Additional SLURM arguments to pass to sbatch
            monarch_port: Port for TCP communication between workers
            job_name: Name for the SLURM job
            ntasks_per_node: Number of tasks per node
            time_limit: Maximum runtime in HH:MM:SS format. If None, uses SLURM's default time limit.
            partition: SLURM partition to submit to
            log_dir: Directory for SLURM log files
            exclusive: Whether to request exclusive node access (no other jobs can run on the nodes).
                Defaults to True for predictable performance and resource isolation,
                but may increase queue times and waste resources if nodes are underutilized.
        """
        configure(default_transport=ChannelTransport.Tcp)
        self._meshes = meshes
        self._python_exe = python_exe
        self._slurm_args = slurm_args
        self._port = monarch_port
        self._job_name = job_name
        self._ntasks_per_node = ntasks_per_node
        self._time_limit = time_limit
        self._partition = partition
        self._log_dir = log_dir if log_dir is not None else os.getcwd()
        self._exclusive = exclusive
        # Track the single SLURM job ID and all allocated hostnames
        self._slurm_job_id: Optional[str] = None
        self._all_hostnames: List[str] = []
        super().__init__()

    def add_mesh(self, name: str, num_nodes: int) -> None:
        self._meshes[name] = num_nodes

    def _create(self, client_script: Optional[str]) -> None:
        """Submit a single SLURM job for all meshes."""
        if client_script is not None:
            raise RuntimeError("SlurmJob cannot run batch-mode scripts")

        total_nodes = sum(self._meshes.values())
        self._slurm_job_id = self._submit_slurm_job(total_nodes)

    def _submit_slurm_job(self, num_nodes: int) -> str:
        """Submit a SLURM job for all nodes."""
        unique_job_name = f"{self._job_name}_{os.getpid()}"

        # Create log directory if it doesn't exist
        os.makedirs(self._log_dir, exist_ok=True)

        log_path_out = os.path.join(self._log_dir, f"slurm_%j_{unique_job_name}.out")
        log_path_err = os.path.join(self._log_dir, f"slurm_%j_{unique_job_name}.err")

        python_command = f'import socket; from monarch.actor import run_worker_loop_forever; hostname = socket.gethostname(); run_worker_loop_forever(address=f"tcp://{{hostname}}:{self._port}", ca="trust_all_connections")'

        # Build SBATCH directives
        sbatch_directives = [
            "#!/bin/bash",
            f"#SBATCH --job-name={unique_job_name}",
            f"#SBATCH --ntasks-per-node={self._ntasks_per_node}",
            f"#SBATCH --nodes={num_nodes}",
            f"#SBATCH --output={log_path_out}",
            f"#SBATCH --error={log_path_err}",
        ]

        if self._time_limit is not None:
            sbatch_directives.append(f"#SBATCH --time={self._time_limit}")

        if self._exclusive:
            sbatch_directives.append("#SBATCH --exclusive")

        if self._partition:
            sbatch_directives.append(f"#SBATCH --partition={self._partition}")

        # Add any additional slurm args as directives
        for arg in self._slurm_args:
            if arg.startswith("-"):
                sbatch_directives.append(f"#SBATCH {arg}")

        batch_script = "\n".join(sbatch_directives)
        batch_script += f"\nsrun {self._python_exe} -c '{python_command}'\n"

        logger.info(f"Submitting SLURM job with {num_nodes} nodes")

        try:
            result = subprocess.run(
                ["sbatch"],
                input=batch_script,
                capture_output=True,
                text=True,
                check=True,
            )

            # Parse the job ID from sbatch output (typically "Submitted batch job 12345")
            job_id = None
            for line in result.stdout.strip().split("\n"):
                if "Submitted batch job" in line:
                    job_id = line.split()[-1]
                    break

            if not job_id:
                raise RuntimeError(
                    f"Failed to parse job ID from sbatch output: {result.stdout}"
                )

            logger.info(
                f"SLURM job {job_id} submitted. Logs will be written to: {self._log_dir}/slurm_{job_id}_{unique_job_name}.out"
            )
            return job_id

        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to submit SLURM job: {e.stderr}") from e

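    # Editorial illustration (not executed): with default arguments, a PID of
    # 12345, and num_nodes=2, the batch script assembled above would look
    # roughly like this before being piped to sbatch; the PID, node count, and
    # <cwd> placeholder are illustrative, not real values:
    #
    #   #!/bin/bash
    #   #SBATCH --job-name=monarch_job_12345
    #   #SBATCH --ntasks-per-node=1
    #   #SBATCH --nodes=2
    #   #SBATCH --output=<cwd>/slurm_%j_monarch_job_12345.out
    #   #SBATCH --error=<cwd>/slurm_%j_monarch_job_12345.err
    #   #SBATCH --time=12:00:00
    #   #SBATCH --exclusive
    #   srun python -c '<worker bootstrap one-liner built in python_command>'
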
    def _get_job_info_json(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Get job information using squeue --json."""
        try:
            result = subprocess.run(
                ["squeue", "--job", job_id, "--json"],
                capture_output=True,
                text=True,
                check=True,
            )

            if result.stdout.strip():
                data = json.loads(result.stdout)
                jobs = data.get("jobs", [])
                return jobs[0] if jobs else None
            return None

        except subprocess.CalledProcessError as e:
            logger.warning(f"Error checking job {job_id} status: {e.stderr}")
            return None
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Error parsing JSON response for job {job_id}: {e}")
            return None

    def _wait_for_job_start(
        self, job_id: str, expected_nodes: int, timeout: int = 300
    ) -> List[str]:
        """
        Wait for the SLURM job to start and return the allocated hostnames.
        Requires Slurm 20.02+ for squeue --json support.
        """
        import time

        start_time = time.time()

        try:
            while time.time() - start_time < timeout:
                job_info = self._get_job_info_json(job_id)

                if not job_info:
                    raise RuntimeError(f"SLURM job {job_id} not found in queue")

                job_state = job_info.get("job_state", [])

                if "RUNNING" in job_state:
                    # Extract hostnames from job_resources.nodes.allocation
                    job_resources = job_info.get("job_resources", {})
                    nodes_info = job_resources.get("nodes", {})
                    allocation = nodes_info.get("allocation", [])
                    hostnames = [node["name"] for node in allocation]

                    logger.info(
                        f"SLURM job {job_id} is running on {len(hostnames)} nodes: {hostnames}"
                    )

                    if len(hostnames) != expected_nodes:
                        raise RuntimeError(
                            f"Expected {expected_nodes} nodes but got {len(hostnames)}. "
                            f"Partial allocation not supported."
                        )

                    return hostnames
                elif any(state in job_state for state in _SLURM_TERMINAL_STATES):
                    raise RuntimeError(
                        f"SLURM job {job_id} failed with status: {job_state}"
                    )
                else:
                    logger.debug(f"SLURM job {job_id} status: {job_state}, waiting...")

                time.sleep(2)  # Check every 2 seconds

            raise RuntimeError(f"Timeout waiting for SLURM job {job_id} to start")

        except Exception:
            # Cleanup on failure - reuse _kill() logic
            logger.error(f"Failed to start SLURM job {job_id}, cancelling job")
            self._kill()
            raise

    def _state(self) -> JobState:
        if not self._jobs_active():
            raise RuntimeError("SLURM job is no longer active")

        # Wait for job to start and get hostnames if not already done
        if not self._all_hostnames:
            job_id = self._slurm_job_id
            if job_id is None:
                raise RuntimeError("SLURM job ID is not set")
            total_nodes = sum(self._meshes.values())
            self._all_hostnames = self._wait_for_job_start(job_id, total_nodes)

        # Distribute the allocated hostnames among meshes
        host_meshes = {}
        hostname_idx = 0

        for mesh_name, num_nodes in self._meshes.items():
            mesh_hostnames = self._all_hostnames[
                hostname_idx : hostname_idx + num_nodes
            ]
            hostname_idx += num_nodes

            workers = [f"tcp://{hostname}:{self._port}" for hostname in mesh_hostnames]
            host_mesh = cast(
                "HostMesh",
                attach_to_workers(
                    name=mesh_name,
                    ca="trust_all_connections",
                    workers=workers,  # type: ignore[arg-type]
                ),
            )
            host_meshes[mesh_name] = host_mesh

        return JobState(host_meshes)

    def can_run(self, spec: "JobTrait") -> bool:
        """Check if this job can run the given spec."""
        return (
            isinstance(spec, SlurmJob)
            and spec._meshes == self._meshes
            and spec._python_exe == self._python_exe
            and spec._port == self._port
            and spec._slurm_args == self._slurm_args
            and spec._job_name == self._job_name
            and spec._ntasks_per_node == self._ntasks_per_node
            and spec._time_limit == self._time_limit
            and spec._partition == self._partition
            and self._jobs_active()
        )

    def _jobs_active(self) -> bool:
        """Check if SLURM job is still active by querying squeue."""
        if not self.active or self._slurm_job_id is None:
            return False

        job_info = self._get_job_info_json(self._slurm_job_id)

        if not job_info:
            logger.warning(f"SLURM job {self._slurm_job_id} not found in queue")
            return False

        job_state = job_info.get("job_state", [])
        if any(state in job_state for state in _SLURM_TERMINAL_STATES):
            logger.warning(f"SLURM job {self._slurm_job_id} has status: {job_state}")
            return False

        return True

    def _kill(self) -> None:
        """Cancel the SLURM job."""
        if self._slurm_job_id is not None:
            try:
                subprocess.run(
                    ["scancel", self._slurm_job_id],
                    capture_output=True,
                    text=True,
                    check=True,
                )
                logger.info(f"Cancelled SLURM job {self._slurm_job_id}")
            except subprocess.CalledProcessError as e:
                logger.warning(
                    f"Failed to cancel SLURM job {self._slurm_job_id}: {e.stderr}"
                )

        self._slurm_job_id = None
        self._all_hostnames.clear()
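
For readers unfamiliar with squeue --json, the sketch below shows the rough response shape that _get_job_info_json, _wait_for_job_start, and _jobs_active rely on. Only the keys the code above reads are included; the state and node names are made-up placeholders, and real responses contain many more fields that can vary by Slurm version.

# Hypothetical, trimmed-down `squeue --job <id> --json` payload with
# placeholder values; only the keys slurm.py actually reads are shown.
example_squeue_payload = {
    "jobs": [
        {
            "job_state": ["RUNNING"],  # list of job state strings
            "job_resources": {
                "nodes": {
                    "allocation": [  # one entry per allocated node
                        {"name": "node001"},
                        {"name": "node002"},
                    ]
                }
            },
        }
    ]
}

# _wait_for_job_start extracts the allocated hostnames along these lines:
job_info = example_squeue_payload["jobs"][0]
allocation = job_info["job_resources"]["nodes"]["allocation"]
hostnames = [node["name"] for node in allocation]
assert hostnames == ["node001", "node002"]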

python/monarch/job/__init__.py

Lines changed: 9 additions & 2 deletions
@@ -5,7 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 # Re-export the job module directly
-from monarch._src.job.job import job_load, job_loads, JobState, JobTrait, LocalJob
+from monarch._src.job.job import (
+    job_load,
+    job_loads,
+    JobState,
+    JobTrait,
+    LocalJob,
+    SlurmJob,
+)
 
 # Define exports
-__all__ = ["JobTrait", "job_load", "job_loads", "JobState", "LocalJob"]
+__all__ = ["JobTrait", "job_load", "job_loads", "JobState", "LocalJob", "SlurmJob"]
