@@ -985,7 +985,102 @@ def _query_slurm_jobs_status(
hostname: str,
socket: str | None,
) -> Dict[str, str]:
"""Query SLURM for job statuses using sacct command.
"""Query SLURM for job statuses using squeue (for active jobs) and sacct (fallback).

This function first tries squeue which is more accurate for currently running jobs,
then falls back to sacct for completed/historical jobs that squeue doesn't show.

Args:
slurm_job_ids: List of SLURM job IDs to query.
Contributor

I suspect that most of the bugs come from the fact that we cannot reliably tell the SLURM job ID for a specific job. We are trying to read this from a file, but race conditions and manual restarts can make the file go out of sync with reality.

For the concrete case we discussed offline, will this fix the status?

Contributor

I don't think we should try to handle cases where a user does something manually, e.g. restarts the job.

Also I think the file with job IDs is the closest thing to the truth that we can get. If we tried to get the information from all of the user's jobs, we'd open a new can of worms: most folks run different things, not only evaluations, and it's hard to predict what corner cases we'd hit.

username: SSH username.
hostname: SSH hostname.
socket: Control socket location, or None.

Returns:
Dict mapping from slurm_job_id to returned slurm status.
"""
if len(slurm_job_ids) == 0:
return {}

# First, try squeue for active jobs (more accurate for running jobs)
squeue_statuses = _query_squeue_for_jobs(slurm_job_ids, username, hostname, socket)

# For jobs not found in squeue, fall back to sacct
missing_jobs = [job_id for job_id in slurm_job_ids if job_id not in squeue_statuses]
sacct_statuses = {}

if missing_jobs:
sacct_statuses = _query_sacct_for_jobs(missing_jobs, username, hostname, socket)

# Combine results, preferring squeue data
combined_statuses = {**sacct_statuses, **squeue_statuses}

return combined_statuses
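# Illustrative sketch with hypothetical job IDs: a job still visible to squeue
# keeps its live state, while a job that has already left the queue picks up
# its final state from sacct:
#
#   _query_slurm_jobs_status(["123456789", "123456790"], "user", "host", None)
#   -> {"123456789": "RUNNING", "123456790": "COMPLETED"}
#
# If both sources report the same job, the squeue entry wins because it is
# merged last into combined_statuses above.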


def _query_squeue_for_jobs(
slurm_job_ids: List[str],
username: str,
hostname: str,
socket: str | None,
) -> Dict[str, str]:
"""Query SLURM for active job statuses using squeue command.

Args:
slurm_job_ids: List of SLURM job IDs to query.
username: SSH username.
hostname: SSH hostname.
socket: Control socket location, or None.

Returns:
Dict mapping from slurm_job_id to returned slurm status for active jobs only.
"""
if len(slurm_job_ids) == 0:
return {}

# Use squeue to get active jobs - more accurate than sacct for running jobs
squeue_command = "squeue -u {} -h -o '%i %T'".format(username)

ssh_command = ["ssh"]
if socket is not None:
ssh_command.append(f"-S {socket}")
ssh_command.append(f"{username}@{hostname}")
# Quote the remote command so the space inside '%i %T' survives both the
# shlex.split() below and the remote shell's own word splitting.
ssh_command.append(shlex.quote(squeue_command))
ssh_command = " ".join(ssh_command)

completed_process = subprocess.run(
args=shlex.split(ssh_command),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

squeue_statuses = {}
if completed_process.returncode == 0:
squeue_output = completed_process.stdout.decode("utf-8")
squeue_output_lines = squeue_output.strip().split("\n")

for line in squeue_output_lines:
if not line.strip():
continue
parts = line.split()
if len(parts) >= 2:
job_id = parts[0]
status = parts[1]
# Extract base job ID (handle array jobs like 123456_0 -> 123456)
base_job_id = job_id.split("_")[0].split("[")[0]
if base_job_id in slurm_job_ids:
squeue_statuses[base_job_id] = status

return squeue_statuses
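# Illustrative sketch of the parsing above, assuming output of the form the
# unit tests mock for `squeue -u <user> -h -o '%i %T'`:
#
#   123456789 RUNNING
#   123456790_0 PENDING
#   123456791[1-10] PENDING
#
# Array tasks such as 123456790_0 collapse to their base ID (123456790), so
# the keys of the returned dict line up with the job IDs that were submitted;
# if several tasks of the same array are listed, the last one seen wins.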


def _query_sacct_for_jobs(
slurm_job_ids: List[str],
username: str,
hostname: str,
socket: str | None,
) -> Dict[str, str]:
"""Query SLURM for job statuses using sacct command (for completed/historical jobs).

Args:
slurm_job_ids: List of SLURM job IDs to query.
@@ -998,6 +1093,7 @@ def _query_slurm_jobs_status(
"""
if len(slurm_job_ids) == 0:
return {}

sacct_command = "sacct -j {} --format='JobID,State%32' --noheader -P".format(
",".join(slurm_job_ids)
)
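# For reference (based on the mocked fixtures in the unit tests): with -P and
# --noheader, sacct prints one pipe-delimited "JobID|State" pair per line, e.g.
#
#   123456789|COMPLETED
#   123456790|FAILED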
@@ -1733,6 +1733,103 @@ def mock_subprocess_run(*args, **kwargs):
socket="/tmp/socket",
)

def test_query_squeue_for_jobs_success(self):
"""Test _query_squeue_for_jobs function with successful subprocess call."""
from nemo_evaluator_launcher.executors.slurm.executor import (
_query_squeue_for_jobs,
)

def mock_subprocess_run(*args, **kwargs):
"""Mock subprocess.run for squeue command."""
# Mock squeue output with various job formats
return Mock(
returncode=0,
stdout=b"123456789 RUNNING\n123456790_0 PENDING\n123456791[1-10] PENDING\n",
stderr=b"",
)

with patch("subprocess.run", side_effect=mock_subprocess_run):
result = _query_squeue_for_jobs(
slurm_job_ids=["123456789", "123456790", "123456791"],
username="testuser",
hostname="slurm.example.com",
socket="/tmp/socket",
)

assert result == {
"123456789": "RUNNING",
"123456790": "PENDING",
"123456791": "PENDING",
}

def test_query_slurm_jobs_status_combined_approach(self):
"""Test _query_slurm_jobs_status using combined squeue + sacct approach."""
from nemo_evaluator_launcher.executors.slurm.executor import (
_query_slurm_jobs_status,
)

def mock_subprocess_run(*args, **kwargs):
"""Mock subprocess.run for both squeue and sacct commands."""
# Get the command from kwargs['args'] since that's how subprocess.run is called
cmd_args = kwargs.get("args", [])
if not cmd_args:
return Mock(returncode=1, stdout=b"", stderr=b"")

cmd_str = (
" ".join(cmd_args) if isinstance(cmd_args, list) else str(cmd_args)
)

if "squeue" in cmd_str:
# Mock squeue showing only running jobs
return Mock(
returncode=0,
stdout=b"123456789 RUNNING\n",
stderr=b"",
)
elif "sacct" in cmd_str:
# Mock sacct showing completed job that's not in squeue
return Mock(
returncode=0,
stdout=b"123456790|COMPLETED\n",
stderr=b"",
)
return Mock(returncode=1, stdout=b"", stderr=b"")

with patch("subprocess.run", side_effect=mock_subprocess_run):
result = _query_slurm_jobs_status(
slurm_job_ids=["123456789", "123456790"],
username="testuser",
hostname="slurm.example.com",
socket="/tmp/socket",
)

# Should get running job from squeue and completed job from sacct
assert result == {"123456789": "RUNNING", "123456790": "COMPLETED"}

def test_query_sacct_for_jobs_success(self):
"""Test _query_sacct_for_jobs function with successful subprocess call."""
from nemo_evaluator_launcher.executors.slurm.executor import (
_query_sacct_for_jobs,
)

def mock_subprocess_run(*args, **kwargs):
"""Mock subprocess.run for sacct command."""
return Mock(
returncode=0,
stdout=b"123456789|COMPLETED\n123456790|FAILED\n",
stderr=b"",
)

with patch("subprocess.run", side_effect=mock_subprocess_run):
result = _query_sacct_for_jobs(
slurm_job_ids=["123456789", "123456790"],
username="testuser",
hostname="slurm.example.com",
socket="/tmp/socket",
)

assert result == {"123456789": "COMPLETED", "123456790": "FAILED"}

def test_sbatch_remote_runsubs_success(self):
"""Test _sbatch_remote_runsubs function with successful subprocess call."""
from pathlib import Path