Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/production.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,13 @@ jobs:
source /root/.venv/bin/activate
pip install --no-input ".[dev,render]"

python tests/monitor_test_mem.py --out-csv-filepath "/mnt/data/artifacts/mem_test_${SLURM_JOB_NAME}.csv" &

pytest --print -x -m "benchmarks" ./tests
cat speed_test*.txt > "/mnt/data/artifacts/speed_test_${SLURM_JOB_NAME}.txt"

kill $(ps -ef | grep monitor_test_mem | grep -v grep | awk '{print $2}')

# tmate -S /tmp/tmate.sock wait tmate-exit
EOF
- name: Kill srun job systematically
Expand All @@ -158,3 +162,8 @@ jobs:
with:
name: speed-test-${{ matrix.GS_ENABLE_NDARRAY }}
path: "/mnt/data/artifacts/speed_test_${{ env.SLURM_JOB_NAME }}.txt"
- name: Upload benchmark mem stats as artifact
uses: actions/upload-artifact@v4
with:
name: mem-test-${{ matrix.GS_ENABLE_NDARRAY }}
path: "/mnt/data/artifacts/mem_test_${{ env.SLURM_JOB_NAME }}.csv"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ dev = [
"pytest-print",
# - 16.0 is causing pytest-xdist to crash in case of failure or skipped tests
"pytest-rerunfailures!=16.0",
"setproctitle", # allows renaming the test processes on the cluster
"syrupy",
"huggingface_hub[hf_xet]",
"wandb",
Expand Down
10 changes: 9 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from io import BytesIO
from pathlib import Path

import setproctitle
import psutil
import pyglet
import pytest
Expand Down Expand Up @@ -64,7 +65,14 @@ def pytest_make_parametrize_id(config, val, argname):
return f"{val}"


@pytest.hookimpl
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
    """Rename the current process after the test that is about to be set up.

    The process title becomes ``pytest: <node id>`` with every space removed
    from the node id, so the running test can be read from a process listing.
    """
    compact_node_id = "".join(item.nodeid.split(" "))
    setproctitle.setproctitle(f"pytest: {compact_node_id}")


@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: pytest.Config) -> None:
# Force disabling forked for non-linux systems
if not sys.platform.startswith("linux"):
Expand Down
94 changes: 94 additions & 0 deletions tests/monitor_test_mem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from collections import defaultdict
import csv
import subprocess
import time
import argparse


def grep(contents: list[str], target):
    """Return the entries of ``contents`` that contain ``target``, in order.

    Matching is a plain substring test (``target in entry``), not a regex.
    """
    return list(filter(lambda entry: target in entry, contents))


def get_cuda_usage(smi_output: "str | None" = None) -> dict[int, int]:
    """Return GPU memory usage in MiB per process id, parsed from ``nvidia-smi``.

    The human-readable ``nvidia-smi`` table is scanned for its second
    ``|====...`` separator, which is taken to open the process listing; rows
    up to the next ``+----...`` border are read as process rows.

    Args:
        smi_output: Text of an ``nvidia-smi`` report to parse. When ``None``
            (the default) ``nvidia-smi`` is executed and its output is used.

    Returns:
        Mapping from process id to the memory reported for it, in MiB. If a
        process id is listed on several rows (e.g. one row per GPU), the
        values are summed. Rows that do not look like process rows, including
        the "No running processes found" notice, are ignored.

    Raises:
        FileNotFoundError / subprocess.CalledProcessError: if ``nvidia-smi``
            has to be executed and cannot be run successfully.
    """
    if smi_output is None:
        smi_output = subprocess.check_output(["nvidia-smi"]).decode("utf-8")

    section = 0
    subsec = 0
    usage: dict[int, int] = {}
    for line in smi_output.splitlines():
        if line.startswith("|============"):
            section += 1
            subsec = 0
            continue
        if line.startswith("+-------"):
            subsec += 1
            continue
        if section != 2 or subsec != 0:
            continue

        fields = line.split()
        # A process row needs at least a fifth token (the PID column) and a
        # second-to-last token (the memory column, just before the closing "|").
        if len(fields) < 5:
            continue
        pid_field, mem_field = fields[4], fields[-2]
        # Validate instead of letting int() raise: an unexpected row must not
        # bring down a monitor that is meant to run for the whole test session.
        if not pid_field.isdecimal() or not mem_field.endswith("MiB"):
            continue
        mem_digits = mem_field[: -len("MiB")]
        if not mem_digits.isdecimal():
            continue

        pid = int(pid_field)
        usage[pid] = usage.get(pid, 0) + int(mem_digits)
    return usage


def get_test_name_by_pid() -> dict[int, str]:
    """Map process ids to test names, based on the output of ``ps -ef``.

    Only process rows containing ``pytest-xdist`` are considered. The test
    name is whatever follows the first ``::`` on the row; rows where that
    remainder is empty or whitespace-only are left out. The process id is
    read from the second whitespace-separated column of the row.
    """
    process_table = subprocess.check_output(["ps", "-ef"]).decode("utf-8")
    test_by_psid = {}
    for row in grep(process_table.split("\n"), "pytest-xdist"):
        psid = int(row.split()[1])
        _, _, test_name = row.partition("::")
        if test_name.strip() != "":
            test_by_psid[psid] = test_name
    return test_by_psid


def main() -> None:
    """Poll GPU memory once per second and log each test's peak usage to a CSV.

    Command line:
        --out-csv-filepath PATH   CSV file to (over)write, with the columns
                                  ``test`` and ``max_mem_mb``.

    On every poll the per-process GPU memory (``get_cuda_usage``) is joined
    with the running test processes (``get_test_name_by_pid``). A row is
    written for a test as soon as it was on the GPU in the previous poll but
    no longer is. The loop runs until the process is interrupted (Ctrl-C) or
    receives SIGTERM; at that point rows are also written for tests that were
    seen on the GPU but had not been written yet, so a test that ends right
    before the monitor is stopped is not lost.
    """
    # Local import: only this script entry point needs signal handling.
    import signal

    parser = argparse.ArgumentParser()
    parser.add_argument("--out-csv-filepath", type=str, required=True)
    args = parser.parse_args()

    def _stop_on_sigterm(signum, frame):
        # Funnel SIGTERM into the Ctrl-C path so both ways of stopping the
        # monitor go through the same flush-and-close logic below.
        raise KeyboardInterrupt

    signal.signal(signal.SIGTERM, _stop_on_sigterm)

    max_mem_by_test = defaultdict(int)
    old_mem_by_test = {}
    # Tests seen on the GPU whose row has not been written since they were
    # last seen; these are flushed when the monitor is stopped.
    unwritten = set()
    num_results_written = 0

    # newline="" is what the csv module expects of files it writes to.
    with open(args.out_csv_filepath, "w", newline="") as f:
        dict_writer = csv.DictWriter(f, fieldnames=["test", "max_mem_mb"])
        dict_writer.writeheader()
        try:
            while True:
                mem_by_pid = get_cuda_usage()
                test_by_psid = get_test_name_by_pid()
                num_tests = len(test_by_psid)
                _mem_by_test = {}
                for psid, test in test_by_psid.items():
                    if psid not in mem_by_pid:
                        continue
                    if test.strip() == "":
                        continue
                    _mem_by_test[test] = mem_by_pid[psid]
                for test, _mem in _mem_by_test.items():
                    max_mem_by_test[test] = max(_mem, max_mem_by_test[test])
                    unwritten.add(test)
                # A test that was on the GPU last poll and is gone now is
                # considered finished: record its peak immediately.
                for _test in old_mem_by_test:
                    if _test not in _mem_by_test:
                        dict_writer.writerow({"test": _test, "max_mem_mb": max_mem_by_test[_test]})
                        f.flush()
                        unwritten.discard(_test)
                        num_results_written += 1
                print(
                    num_tests,
                    "tests running, of which",
                    len(_mem_by_test),
                    "on gpu. Num results written: ",
                    num_results_written,
                    "[updating]",
                    "         ",
                    end="\r",
                    flush=True,
                )
                old_mem_by_test = _mem_by_test
                time.sleep(1.0)
        except KeyboardInterrupt:
            # Normal way to stop the monitor; fall through to the final flush.
            pass
        finally:
            for _test in sorted(unwritten):
                dict_writer.writerow({"test": _test, "max_mem_mb": max_mem_by_test[_test]})
            f.flush()


# Start the monitor only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading