diff --git a/src/cloudai/_core/command_gen_strategy.py b/src/cloudai/_core/command_gen_strategy.py index 5238bb675..0268fbca5 100644 --- a/src/cloudai/_core/command_gen_strategy.py +++ b/src/cloudai/_core/command_gen_strategy.py @@ -49,6 +49,10 @@ def store_test_run(self) -> None: """ pass + def cleanup_job_artifacts(self) -> None: + """Best-effort cleanup hook run after the job has fully completed.""" + return + @property def final_env_vars(self) -> dict[str, str | list[str]]: if not self._final_env_vars: diff --git a/src/cloudai/systems/slurm/single_sbatch_runner.py b/src/cloudai/systems/slurm/single_sbatch_runner.py index c346a0e64..6f763f4b8 100644 --- a/src/cloudai/systems/slurm/single_sbatch_runner.py +++ b/src/cloudai/systems/slurm/single_sbatch_runner.py @@ -22,7 +22,7 @@ from typing import Generator, Optional, cast from cloudai.configurator import CloudAIGymEnv, TrajectoryEntry -from cloudai.core import JobIdRetrievalError, System, TestRun, TestScenario +from cloudai.core import BaseJob, JobIdRetrievalError, System, TestRun, TestScenario from cloudai.util import CommandShell, format_time_limit, parse_time_limit from .slurm_command_gen_strategy import SlurmCommandGenStrategy @@ -221,6 +221,9 @@ def handle_dse(self): ) ) + def completed_test_runs(self, job: BaseJob) -> list[TestRun]: + return list(self.all_trs) + def _submit_test(self, tr: TestRun) -> SlurmJob: with open(self.scenario_root / "cloudai_sbatch_script.sh", "w") as f: f.write(self.gen_sbatch_content()) diff --git a/src/cloudai/systems/slurm/slurm_runner.py b/src/cloudai/systems/slurm/slurm_runner.py index 50a70082d..fd8f0902e 100644 --- a/src/cloudai/systems/slurm/slurm_runner.py +++ b/src/cloudai/systems/slurm/slurm_runner.py @@ -77,10 +77,18 @@ def on_job_submit(self, tr: TestRun) -> None: cmd_gen = self.get_cmd_gen_strategy(self.system, tr) cmd_gen.store_test_run() + def completed_test_runs(self, job: BaseJob) -> list[TestRun]: + return [cast(SlurmJob, job).test_run] + def on_job_completion(self, job: BaseJob) -> None: logging.debug(f"Job completion callback for job {job.id}") self.system.complete_job(cast(SlurmJob, job)) self.store_job_metadata(cast(SlurmJob, job)) + for tr in self.completed_test_runs(job): + try: + self.get_cmd_gen_strategy(self.system, tr).cleanup_job_artifacts() + except Exception: + logging.warning(f"Cleanup failed for test run at {tr.output_path}", exc_info=True) def _mock_job_metadata(self) -> SlurmStepMetadata: return SlurmStepMetadata( diff --git a/src/cloudai/workloads/common/nixl.py b/src/cloudai/workloads/common/nixl.py index fc35d6dab..950f91707 100644 --- a/src/cloudai/workloads/common/nixl.py +++ b/src/cloudai/workloads/common/nixl.py @@ -17,6 +17,7 @@ import logging import re +import shutil from functools import cache from pathlib import Path from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast @@ -231,6 +232,24 @@ def _unique_file_name(self, file_name: str, used_filenames: set[str]) -> str: used_filenames.add(candidate) return candidate + def cleanup_job_artifacts(self) -> None: + for cleanup_target in self._cleanup_targets(): + if cleanup_target.exists(): + shutil.rmtree(cleanup_target) + + def _cleanup_targets(self) -> list[Path]: + cleanup_targets: list[Path] = [] + + filepath_raw: str | None = cast(str | None, self.test_run.test.cmd_args_dict.get("filepath")) + if filepath_raw: + cleanup_targets.append((self.test_run.output_path / "filepath_mount").resolve()) + + device_list_raw: str | None = cast(str | None, self.test_run.test.cmd_args_dict.get("device_list")) + if device_list_raw and get_files_from_device_list(device_list_raw): + cleanup_targets.append((self.test_run.output_path / "device_list_mounts").resolve()) + + return cleanup_targets + @property def final_env_vars(self) -> dict[str, str | list[str]]: env_vars = super().final_env_vars diff --git a/tests/test_get_job_id.py b/tests/test_get_job_id.py index 260593dea..ecdf6ced3 100644 --- a/tests/test_get_job_id.py +++ b/tests/test_get_job_id.py @@ -16,14 +16,14 @@ import subprocess from pathlib import Path -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from cloudai.core import JobIdRetrievalError, TestRun, TestScenario from cloudai.systems.lsf.lsf_runner import LSFRunner from cloudai.systems.lsf.lsf_system import LSFSystem -from cloudai.systems.slurm import SlurmRunner, SlurmSystem +from cloudai.systems.slurm import SlurmJob, SlurmRunner, SlurmSystem from cloudai.util import CommandShell from cloudai.workloads.sleep.sleep import SleepCmdArgs, SleepTestDefinition @@ -88,6 +88,21 @@ def test_slurm_get_job_id(slurm_runner: SlurmRunner, stdout: str, stderr: str, e assert res == expected_job_id +def test_slurm_runner_on_job_completion_calls_cleanup(slurm_runner: SlurmRunner): + tr = slurm_runner.test_scenario.test_runs[0] + job = SlurmJob(tr, id=1) + slurm_runner.store_job_metadata = Mock() + cleanup = Mock() + slurm_runner.get_cmd_gen_strategy = Mock(return_value=Mock(cleanup_job_artifacts=cleanup)) + + with patch.object(SlurmSystem, "complete_job") as complete_job: + slurm_runner.on_job_completion(job) + + complete_job.assert_called_once_with(job) + slurm_runner.store_job_metadata.assert_called_once_with(job) + cleanup.assert_called_once() + + @pytest.mark.parametrize( "stdout, stderr, expected_job_id", [ diff --git a/tests/test_single_sbatch_runner.py b/tests/test_single_sbatch_runner.py index 72ad93f79..91d3cdf27 100644 --- a/tests/test_single_sbatch_runner.py +++ b/tests/test_single_sbatch_runner.py @@ -16,8 +16,9 @@ import copy import re +from pathlib import Path from typing import Generator, Optional, cast -from unittest.mock import Mock +from unittest.mock import Mock, patch import pandas as pd import pytest @@ -506,6 +507,30 @@ def test_store_job_metadata(nccl_tr: TestRun, slurm_system: SlurmSystem) -> None assert sjm == SlurmJobMetadata.model_validate(toml.loads(toml.dumps(sjm.model_dump()))) +def test_on_job_completion_cleans_all_effective_test_runs( + dse_tr: TestRun, nccl_tr: TestRun, slurm_system: SlurmSystem +) -> None: + tc = TestScenario(name="tc", test_runs=[dse_tr, nccl_tr]) + runner = SingleSbatchRunner(mode="run", system=slurm_system, test_scenario=tc, output_path=slurm_system.output_path) + runner.mode = "dry-run" + runner.store_job_metadata = Mock() + + cleanup_calls: list[Path] = [] + + def _cmd_gen(_, tr: TestRun): + return Mock(cleanup_job_artifacts=Mock(side_effect=lambda: cleanup_calls.append(tr.output_path))) + + runner.get_cmd_gen_strategy = Mock(side_effect=_cmd_gen) + + expected_paths = [tr.output_path for tr in runner.all_trs] + job = SlurmJob(nccl_tr, id=1) + + with patch.object(SlurmSystem, "complete_job"): + runner.on_job_completion(job) + + assert cleanup_calls == expected_paths + + def test_pre_test(nccl_tr: TestRun, sleep_tr: TestRun, slurm_system: SlurmSystem) -> None: nccl_tr.pre_test = TestScenario(name="pre_test", test_runs=[sleep_tr]) tc = TestScenario(name="tc", test_runs=[nccl_tr]) diff --git a/tests/workloads/nixl_bench/test_command_gen_strategy_slurm.py b/tests/workloads/nixl_bench/test_command_gen_strategy_slurm.py index f63802e6e..814b20e7b 100644 --- a/tests/workloads/nixl_bench/test_command_gen_strategy_slurm.py +++ b/tests/workloads/nixl_bench/test_command_gen_strategy_slurm.py @@ -105,6 +105,32 @@ def test_container_mounts(self, nixl_bench_tr: TestRun, slurm_system: SlurmSyste assert (nixl_bench_tr.output_path / "device_list_mounts" / local_device_filename).is_file() assert (nixl_bench_tr.output_path / "device_list_mounts" / local_device_filename).stat().st_size == 1024 + def test_cleanup_job_artifacts(self, nixl_bench_tr: TestRun, slurm_system: SlurmSystem): + nixl_bench_tr.test.cmd_args = NIXLBenchCmdArgs.model_validate( + { + "docker_image_url": "docker.io/library/ubuntu:22.04", + "path_to_benchmark": "/nixlbench", + "backend": "GUSLI", + "device_list": "11:K:/dev/nvme0n1,12:F:/p1/store0.bin,13:F:/p2/store0.bin", + "filepath": "/data", + } + ) + strategy = NIXLBenchSlurmCommandGenStrategy(slurm_system, nixl_bench_tr) + filepath_dir = nixl_bench_tr.output_path / "filepath_mount" + device_list_dir = nixl_bench_tr.output_path / "device_list_mounts" + other_file = nixl_bench_tr.output_path / "keep.txt" + filepath_dir.mkdir(parents=True, exist_ok=True) + device_list_dir.mkdir(parents=True, exist_ok=True) + (filepath_dir / "a.txt").write_text("x") + (device_list_dir / "b.txt").write_text("x") + other_file.write_text("keep") + + strategy.cleanup_job_artifacts() + + assert not filepath_dir.exists() + assert not device_list_dir.exists() + assert other_file.exists() + @pytest.mark.parametrize( ("override", "expected_error_match", "expected_total_buffer_size"), ( diff --git a/tests/workloads/nixl_kvbench/test_command_gen_slurm.py b/tests/workloads/nixl_kvbench/test_command_gen_slurm.py index e9c595828..fecf1d371 100644 --- a/tests/workloads/nixl_kvbench/test_command_gen_slurm.py +++ b/tests/workloads/nixl_kvbench/test_command_gen_slurm.py @@ -38,8 +38,10 @@ def kvbench() -> NIXLKVBenchTestDefinition: @pytest.fixture -def kvbench_tr(kvbench: NIXLKVBenchTestDefinition) -> TestRun: - return TestRun(name="nixl-bench", num_nodes=2, nodes=[], test=kvbench) +def kvbench_tr(kvbench: NIXLKVBenchTestDefinition, tmp_path) -> TestRun: + output_path = tmp_path / "nixl-kvbench" + output_path.mkdir(parents=True, exist_ok=True) + return TestRun(name="nixl-bench", num_nodes=2, nodes=[], test=kvbench, output_path=output_path) def test_gen_kvbench_ucx(kvbench_tr: TestRun, slurm_system: SlurmSystem): @@ -124,3 +126,29 @@ def test_get_etcd_srun_command_with_etcd_image(kvbench_tr: TestRun, slurm_system cmd = " ".join(strategy.gen_etcd_srun_command(tdef.cmd_args.etcd_path)) assert tdef.etcd_image is not None assert f"--container-image={tdef.etcd_image.installed_path}" in cmd + + +def test_kvbench_cleanup_job_artifacts(kvbench_tr: TestRun, slurm_system: SlurmSystem): + kvbench_tr.test.cmd_args = NIXLKVBenchCmdArgs.model_validate( + { + "docker_image_url": "docker://image/url", + "backend": "GUSLI", + "filepath": "/data", + "device_list": "11:F:/store0.bin", + } + ) + strategy = NIXLKVBenchSlurmCommandGenStrategy(slurm_system, kvbench_tr) + filepath_dir = kvbench_tr.output_path / "filepath_mount" + device_list_dir = kvbench_tr.output_path / "device_list_mounts" + other_file = kvbench_tr.output_path / "keep.txt" + filepath_dir.mkdir(parents=True, exist_ok=True) + device_list_dir.mkdir(parents=True, exist_ok=True) + (filepath_dir / "a.txt").write_text("x") + (device_list_dir / "b.txt").write_text("x") + other_file.write_text("keep") + + strategy.cleanup_job_artifacts() + + assert not filepath_dir.exists() + assert not device_list_dir.exists() + assert other_file.exists()