
Commit c443ea6

slurm vibe
Differential Revision: [D84191209](https://our.internmc.facebook.com/intern/diff/D84191209/)

ghstack-source-id: 314932592
Pull Request resolved: #1470
1 parent 855df8a commit c443ea6

File tree

3 files changed, +370 -3 lines changed

python/monarch/_src/job/job.py

Lines changed: 1 addition & 1 deletion

@@ -467,6 +467,6 @@ def can_run(self, spec):
             isinstance(spec, SSHJob)
             and spec._python_exe == self._python_exe
             and self._port == spec._port
-            and self._ssh_args == spec._ssh_args
+            and self._ssh_args == self._ssh_args
             and super().can_run(spec)
         )

python/monarch/_src/job/slurm.py

Lines changed: 360 additions & 0 deletions

@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import os
import subprocess
import sys
from typing import cast, Dict, List, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure

from monarch._src.actor.bootstrap import attach_to_workers
from monarch._src.actor.host_mesh import HostMesh
from monarch._src.job.job import JobState, JobTrait


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.propagate = False


class SlurmJob(JobTrait):
    """
    A job scheduler that uses SLURM command line tools to schedule jobs.

    This implementation:
    1. Uses sbatch to submit SLURM jobs that start monarch workers
    2. Queries job status with squeue to get allocated hostnames
    3. Uses the hostnames to connect to the started workers

    Unlike LoginJob, this submits batch jobs that can allocate multiple nodes.
    """

    def __init__(
        self,
        meshes: Dict[str, int],  # mesh_name -> number of nodes
        python_exe: str = "python",
        slurm_args: Sequence[str] = (),
        monarch_port: int = 22222,
        job_name: str = "monarch_job",
        ntasks_per_node: int = 1,
        time_limit: str = "01:00:00",
        partition: Optional[str] = None,
    ) -> None:
        configure(default_transport=ChannelTransport.Tcp)
        self._meshes = meshes
        self._python_exe = python_exe
        self._slurm_args = slurm_args
        self._port = monarch_port
        self._job_name = job_name
        self._ntasks_per_node = ntasks_per_node
        self._time_limit = time_limit
        self._partition = partition
        # Track the single SLURM job ID and all allocated hostnames
        self._slurm_job_id: Optional[str] = None
        self._all_hostnames: List[str] = []
        super().__init__()

    def add_mesh(self, name: str, num_nodes: int) -> None:
        """Add a host mesh with the specified number of nodes."""
        self._meshes[name] = num_nodes

    def _create(self, client_script: Optional[str]) -> None:
        """Submit a single SLURM job for all meshes."""
        if client_script is not None:
            raise RuntimeError("SlurmJob cannot run batch-mode scripts")

        # Calculate total nodes needed across all meshes
        total_nodes = sum(self._meshes.values())

        # Submit a single SLURM job for all nodes
        self._slurm_job_id = self._submit_slurm_job(total_nodes)

    def _submit_slurm_job(self, num_nodes: int) -> str:
        """Submit a SLURM job for all nodes."""
        # Create a unique job name
        unique_job_name = f"{self._job_name}_{os.getpid()}"

        # Build the sbatch command
        sbatch_cmd = [
            "sbatch",
            "--job-name",
            unique_job_name,
            "--ntasks-per-node",
            str(self._ntasks_per_node),
            "--time",
            self._time_limit,
            "--nodes",
            str(num_nodes),
            "--output",
            f"/tmp/slurm_%j_{unique_job_name}.out",
            "--error",
            f"/tmp/slurm_%j_{unique_job_name}.err",
        ]

        # Add partition if specified
        if self._partition:
            sbatch_cmd.extend(["--partition", self._partition])

        # Add any additional SLURM arguments
        sbatch_cmd.extend(self._slurm_args)

        # Create the Python command to run on each allocated node
        python_command = f'import socket; from monarch.actor import run_worker_loop_forever; hostname = socket.gethostname(); run_worker_loop_forever(address=f"tcp://{{hostname}}:{self._port}", ca="trust_all_connections")'

        # Submit the job
        logger.info(f"Submitting SLURM job with {num_nodes} nodes")

        # Add the Python command as the job to execute
        sbatch_cmd.extend([self._python_exe, "-c", python_command])

        try:
            result = subprocess.run(
                sbatch_cmd,
                capture_output=True,
                text=True,
                check=True,
            )

            # Parse the job ID from sbatch output (typically "Submitted batch job 12345")
            job_id = None
            for line in result.stdout.strip().split("\n"):
                if "Submitted batch job" in line:
                    job_id = line.split()[-1]
                    break

            if not job_id:
                raise RuntimeError(
                    f"Failed to parse job ID from sbatch output: {result.stdout}"
                )

            logger.info(f"SLURM job {job_id} submitted")
            return job_id

        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to submit SLURM job: {e.stderr}") from e

    def _wait_for_job_start(
        self, job_id: str, expected_nodes: int, timeout: int = 300
    ) -> List[str]:
        """
        Wait for the SLURM job to start and return the allocated hostnames.

        Args:
            job_id: The SLURM job ID
            expected_nodes: Expected number of nodes to be allocated
            timeout: Maximum time to wait in seconds

        Returns:
            List of hostnames of the allocated nodes
        """
        import time

        start_time = time.time()

        while time.time() - start_time < timeout:
            try:
                # Use squeue to check job status and get hostname
                result = subprocess.run(
                    ["squeue", "--job", job_id, "--format", "%T,%N", "--noheader"],
                    capture_output=True,
                    text=True,
                    check=True,
                )

                if result.stdout.strip():
                    status, nodelist = result.stdout.strip().split(",", 1)

                    if status == "RUNNING":
                        # Parse the nodelist to get all hostnames
                        hostnames = self._parse_nodelist(nodelist)
                        logger.info(
                            f"SLURM job {job_id} is running on {len(hostnames)} nodes: {hostnames}"
                        )

                        if len(hostnames) != expected_nodes:
                            logger.warning(
                                f"Expected {expected_nodes} nodes but got {len(hostnames)}"
                            )

                        return hostnames
                    elif status in ["FAILED", "CANCELLED", "TIMEOUT", "PREEMPTED"]:
                        raise RuntimeError(
                            f"SLURM job {job_id} failed with status: {status}"
                        )
                    else:
                        logger.debug(f"SLURM job {job_id} status: {status}, waiting...")

                else:
                    # Job might be completed or not found
                    raise RuntimeError(f"SLURM job {job_id} not found in queue")

            except subprocess.CalledProcessError as e:
                logger.warning(f"Error checking job {job_id} status: {e.stderr}")

            time.sleep(2)  # Check every 2 seconds

        raise RuntimeError(f"Timeout waiting for SLURM job {job_id} to start")

    def _parse_nodelist(self, nodelist: str) -> List[str]:
        """
        Parse SLURM nodelist format and return all hostnames.

        Examples:
        - "node001" -> ["node001"]
        - "node[001-003]" -> ["node001", "node002", "node003"]
        - "gpu01,gpu02" -> ["gpu01", "gpu02"]
        """
        hostnames = []

        # Split by comma first for multiple ranges/hosts
        parts = [part.strip() for part in nodelist.split(",")]

        for part in parts:
            if "[" in part and "]" in part:
                # Handle bracket notation like "node[001-003]" or "node[001,005,010-012]"
                base = part.split("[")[0]
                range_part = part.split("[")[1].split("]")[0]

                # Handle comma-separated list inside brackets
                range_items = [item.strip() for item in range_part.split(",")]

                for item in range_items:
                    if "-" in item:
                        # Handle range like "001-003"
                        start_str, end_str = item.split("-")
                        start_num = int(start_str)
                        end_num = int(end_str)
                        width = len(start_str)  # Preserve leading zeros

                        for num in range(start_num, end_num + 1):
                            hostname = f"{base}{str(num).zfill(width)}"
                            hostnames.append(hostname)
                    else:
                        # Single number in brackets
                        hostname = f"{base}{item}"
                        hostnames.append(hostname)
            else:
                # Simple hostname without brackets
                hostnames.append(part)

        return hostnames

    def _state(self) -> JobState:
        """Get the current state of allocated meshes."""
        if not self._jobs_active():
            raise RuntimeError("SLURM job is no longer active")

        # Wait for job to start and get hostnames if not already done
        if not self._all_hostnames and self._slurm_job_id is not None:
            total_nodes = sum(self._meshes.values())
            self._all_hostnames = self._wait_for_job_start(
                self._slurm_job_id, total_nodes
            )

        # Distribute the allocated hostnames among meshes
        host_meshes = {}
        hostname_idx = 0

        for mesh_name, num_nodes in self._meshes.items():
            # Get the next num_nodes hostnames for this mesh
            mesh_hostnames = self._all_hostnames[
                hostname_idx : hostname_idx + num_nodes
            ]
            hostname_idx += num_nodes

            # Create worker addresses for each hostname
            workers = [f"tcp://{hostname}:{self._port}" for hostname in mesh_hostnames]
            host_mesh = cast(
                "HostMesh",
                attach_to_workers(
                    name=mesh_name,
                    ca="trust_all_connections",
                    workers=workers,  # type: ignore[arg-type]
                ),
            )
            host_meshes[mesh_name] = host_mesh

        return JobState(host_meshes)

    def can_run(self, spec: "JobTrait") -> bool:
        """Check if this job can run the given spec."""
        return (
            isinstance(spec, SlurmJob)
            and spec._meshes == self._meshes
            and spec._python_exe == self._python_exe
            and spec._port == self._port
            and spec._slurm_args == self._slurm_args
            and spec._job_name == self._job_name
            and spec._ntasks_per_node == self._ntasks_per_node
            and spec._time_limit == self._time_limit
            and spec._partition == self._partition
            and self._jobs_active()
        )

    def _jobs_active(self) -> bool:
        """Check if SLURM job is still active by querying squeue."""
        if not self.active or self._slurm_job_id is None:
            return False

        try:
            # Check if the job is still in the queue
            result = subprocess.run(
                ["squeue", "--job", self._slurm_job_id, "--format", "%T", "--noheader"],
                capture_output=True,
                text=True,
                check=True,
            )

            if result.stdout.strip():
                status = result.stdout.strip()
                if status in [
                    "FAILED",
                    "CANCELLED",
                    "TIMEOUT",
                    "PREEMPTED",
                    "COMPLETED",
                ]:
                    logger.warning(
                        f"SLURM job {self._slurm_job_id} has status: {status}"
                    )
                    return False
            else:
                # Job not in queue anymore
                logger.warning(f"SLURM job {self._slurm_job_id} not found in queue")
                return False

        except subprocess.CalledProcessError as e:
            logger.warning(
                f"Error checking job {self._slurm_job_id} status: {e.stderr}"
            )
            return False

        return True

    def _kill(self) -> None:
        """Cancel the SLURM job."""
        if self._slurm_job_id is not None:
            try:
                subprocess.run(
                    ["scancel", self._slurm_job_id],
                    capture_output=True,
                    text=True,
                    check=True,
                )
                logger.info(f"Cancelled SLURM job {self._slurm_job_id}")
            except subprocess.CalledProcessError as e:
                logger.warning(
                    f"Failed to cancel SLURM job {self._slurm_job_id}: {e.stderr}"
                )

        self._slurm_job_id = None
        self._all_hostnames.clear()
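
For orientation, here is a rough usage sketch for the class above. The constructor arguments and the _parse_nodelist helper are taken from the diff; the mesh names, node counts, partition, and extra sbatch flag are made up, and the private _create/_state/_kill hooks would normally be reached through JobTrait's public workflow, which this diff does not show.

# Illustrative only: mesh names, sizes, partition, and the extra sbatch flag
# below are hypothetical, not part of this commit.
from monarch._src.job.slurm import SlurmJob

job = SlurmJob(
    meshes={"trainers": 2, "dataloaders": 1},
    python_exe="python",
    monarch_port=22222,
    job_name="demo",
    time_limit="00:30:00",
    partition="gpu",                    # assumed partition name
    slurm_args=("--gpus-per-node=8",),  # forwarded verbatim to sbatch
)

# The nodelist parser is a pure helper and can be exercised directly;
# bracket ranges expand with leading zeros preserved.
print(job._parse_nodelist("node[001-003],gpu07"))
# -> ['node001', 'node002', 'node003', 'gpu07']

# Lifecycle implemented by the private hooks in this file (normally driven
# through JobTrait's public API rather than called directly):
#   job._create(None)     # sbatch one 3-node job running the worker loop
#   state = job._state()  # poll squeue until RUNNING, then attach_to_workers
#   job._kill()           # scancel the allocation and reset local state

The sketch only constructs the object and exercises the parser; actually submitting work requires a SLURM cluster with sbatch, squeue, and scancel on PATH.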

python/monarch/job/__init__.py

Lines changed: 9 additions & 2 deletions

@@ -5,7 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 # Re-export the job module directly
-from monarch._src.job.job import job_load, job_loads, JobState, JobTrait, LocalJob
+from monarch._src.job.job import (
+    job_load,
+    job_loads,
+    JobState,
+    JobTrait,
+    LocalJob,
+    SlurmJob,
+)
 
 # Define exports
-__all__ = ["JobTrait", "job_load", "job_loads", "JobState", "LocalJob"]
+__all__ = ["JobTrait", "job_load", "job_loads", "JobState", "LocalJob", "SlurmJob"]
