diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 8f740cfa6..8ec7b51c1 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -139,9 +139,13 @@ jobs: source /root/.venv/bin/activate pip install --no-input ".[dev,render]" + python tests/monitor_test_mem.py --out-csv-filepath "/mnt/data/artifacts/mem_test_${SLURM_JOB_NAME}.csv" & + pytest --print -x -m "benchmarks" ./tests cat speed_test*.txt > "/mnt/data/artifacts/speed_test_${SLURM_JOB_NAME}.txt" + kill $(ps -ef | grep monitor_test_mem | grep -v grep | awk '{print $2}') + # tmate -S /tmp/tmate.sock wait tmate-exit EOF - name: Kill srun job systematically @@ -158,3 +162,8 @@ jobs: with: name: speed-test-${{ matrix.GS_ENABLE_NDARRAY }} path: "/mnt/data/artifacts/speed_test_${{ env.SLURM_JOB_NAME }}.txt" + - name: Upload benchmark mem stats as artifact + uses: actions/upload-artifact@v4 + with: + name: mem-test-${{ matrix.GS_ENABLE_NDARRAY }} + path: "/mnt/data/artifacts/mem_test_${{ env.SLURM_JOB_NAME }}.csv" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 24a20956f..0a7c55214 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dev = [ "pytest-print", # - 16.0 is causing pytest-xdist to crash in case of failure or skipped tests "pytest-rerunfailures!=16.0", + "setproctitle", # allows renaming the test processes on the cluster "syrupy", "huggingface_hub[hf_xet]", "wandb", diff --git a/tests/conftest.py b/tests/conftest.py index 2335d4c67..3f3cf5de8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from io import BytesIO from pathlib import Path +import setproctitle import psutil import pyglet import pytest @@ -64,7 +65,14 @@ def pytest_make_parametrize_id(config, val, argname): return f"{val}" -@pytest.hookimpl +@pytest.hookimpl(tryfirst=True) +def pytest_runtest_setup(item): + # Include test name in process title + test_name = item.nodeid.replace(" ", "") + setproctitle.setproctitle(f"pytest: {test_name}") + + +@pytest.hookimpl(tryfirst=True) def pytest_cmdline_main(config: pytest.Config) -> None: # Force disabling forked for non-linux systems if not sys.platform.startswith("linux"): diff --git a/tests/monitor_test_mem.py b/tests/monitor_test_mem.py new file mode 100644 index 000000000..c60dd07ec --- /dev/null +++ b/tests/monitor_test_mem.py @@ -0,0 +1,94 @@ +from collections import defaultdict +import csv +import subprocess +import time +import argparse + + +def grep(contents: list[str], target): + return [l for l in contents if target in l] + + +def get_cuda_usage() -> dict[int, int]: + output = subprocess.check_output(["nvidia-smi"]).decode("utf-8") + section = 0 + subsec = 0 + res = {} + for line in output.split("\n"): + if line.startswith("|============"): + section += 1 + subsec = 0 + continue + if line.startswith("+-------"): + subsec += 1 + continue + if section == 2 and subsec == 0: + if "No running processes" in line: + continue + split_line = line.split() + pid = int(split_line[4]) + mem = int(split_line[-2].split("MiB")[0]) + res[pid] = mem + return res + + +def get_test_name_by_pid() -> dict[int, str]: + ps_ef = subprocess.check_output(["ps", "-ef"]).decode("utf-8").split("\n") + test_lines = grep(ps_ef, "pytest-xdist") + tests = [line.partition("::")[2] for line in test_lines] + psids = [int(line.split()[1]) for line in test_lines] + test_by_psid = {psid: test for test, psid in zip(tests, psids) if test.strip() != ""} + return test_by_psid + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--out-csv-filepath", type=str, required=True) + args = parser.parse_args() + + max_mem_by_test = defaultdict(int) + + f = open(args.out_csv_filepath, "w") + dict_writer = csv.DictWriter(f, fieldnames=["test", "max_mem_mb"]) + dict_writer.writeheader() + old_mem_by_test = {} + num_results_written = 0 + disp = False + while True: + mem_by_pid = get_cuda_usage() + test_by_psid = get_test_name_by_pid() + num_tests = len(test_by_psid) + _mem_by_test = {} + for psid, test in test_by_psid.items(): + if psid not in mem_by_pid: + continue + if test.strip() == "": + continue + _mem = mem_by_pid[psid] + _mem_by_test[test] = _mem + for test, _mem in _mem_by_test.items(): + max_mem_by_test[test] = max(_mem, max_mem_by_test[test]) + for _test, _mem in old_mem_by_test.items(): + if _test not in _mem_by_test: + dict_writer.writerow({"test": _test, "max_mem_mb": max_mem_by_test[_test]}) + f.flush() + num_results_written += 1 + spinny = "x" if disp else "+" + print( + num_tests, + "tests running, of which", + len(_mem_by_test), + "on gpu. Num results written: ", + num_results_written, + "[updating]", + " ", + end="\r", + flush=True, + ) + old_mem_by_test = _mem_by_test + disp = not disp + time.sleep(1.0) + + +if __name__ == "__main__": + main()