Skip to content

Commit 1ee5366

Browse files
authored
feat: added script for processing wandb system metrics (#167)
* feat: added script
* fix: moved script
* fix
* fix
1 parent 931bd04 commit 1ee5366

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,4 @@ module = [
210210
"autointent._callbacks.*",
211211
"autointent.modules.abc.*",
212212
]
213-
warn_unreachable = false
213+
warn_unreachable = false

scripts/wandb_resources_cli.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""CLI tool for processing wandb runs and computing system metrics."""
2+
3+
import argparse
4+
import re
5+
from collections.abc import Sequence
6+
from typing import Any
7+
8+
from scipy.integrate import trapezoid
9+
10+
import wandb
11+
12+
13+
def calculate_area(metrics: Sequence[float], timestamps: Sequence[float]) -> Any:
    """Integrate sampled metric values over their timestamps (trapezoidal rule).

    Args:
        metrics: Metric values, one per sample.
        timestamps: Sample positions on the x-axis, one per metric value.

    Returns:
        The area under the curve as computed by ``scipy.integrate.trapezoid``.

    Raises:
        ValueError: If ``metrics`` and ``timestamps`` differ in length.
    """
    if len(metrics) == len(timestamps):
        return trapezoid(metrics, timestamps)

    # Build the message first, then raise, matching the file's error style.
    error = "Metrics and timestamps dimensions do not match!"
    raise ValueError(error)
30+
31+
32+
def sanitize_filename(filename: str) -> str:
    """Return ``filename`` with every disallowed character replaced by ``_``.

    Args:
        filename: The filename to sanitize.

    Returns:
        The input with each character other than an ASCII letter, an ASCII
        digit, an underscore, or a hyphen replaced by a single underscore.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9_\-]")
    return disallowed.sub("_", filename)
42+
43+
44+
def process_run(run: wandb.Api.Run) -> dict[str, dict[str, float]]:
    """Process a wandb run to extract system metrics and compute statistics.

    Args:
        run: A wandb run object. Its ``history(stream="system_metrics")`` result
            is used as a table with one column per recorded metric.

    Returns:
        A dictionary mapping each column whose name contains ``"system"`` and
        that has at least one non-missing sample to its statistics:
        ``max``, ``min``, ``avg``, ``area`` and ``median``.
    """
    system_metrics = run.history(stream="system_metrics")
    results = {}
    for column in system_metrics.columns:
        if "system" not in column:
            continue
        metrics_values = system_metrics[column].dropna()
        if len(metrics_values) == 0:
            # Nothing was recorded for this metric; leave it out of the results.
            continue
        # dropna() keeps the index labels of the rows that survive, so use them
        # as the x-axis. Taking the first len(metrics_values) labels of the full
        # table instead would pair values with the wrong positions whenever the
        # missing samples are not all at the end of the column.
        timestamps = metrics_values.index
        results[column] = {
            "max": metrics_values.max(),
            "min": metrics_values.min(),
            "avg": metrics_values.mean(),
            "area": calculate_area(metrics_values, timestamps),
            "median": metrics_values.median(),
        }
    return results
70+
71+
72+
def main() -> None:
    """Parse CLI arguments and log system-resource statistics for a wandb group.

    For every run in the requested project/group whose name does not contain
    ``"final_metrics"``, a new wandb run named ``system_resources_<run name>``
    is started, the statistics from ``process_run`` are logged (one ``wandb.log``
    call per metric column, restricted to the statistics selected with
    ``--metrics``), and the new run is finished.

    Raises:
        ValueError: If no runs are found for the given group and project.
    """
    cli = argparse.ArgumentParser(description="Processing wandb resource data")
    cli.add_argument("--project", type=str, required=True, help="Wandb project name")
    cli.add_argument("--group", type=str, required=True, help="Wandb group name")
    cli.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["max", "area", "avg"],
        help="Metrics to compute (e.g. min, max, area, avg, median)",
    )
    options = cli.parse_args()

    source_runs = wandb.Api().runs(options.project, filters={"group": options.group})
    if not source_runs:
        error = f"No runs found for group {options.group} in project {options.project}."
        raise ValueError(error)

    for source_run in source_runs:
        if "final_metrics" in source_run.name:
            continue

        wandb.init(
            project=options.project,
            group=options.group,
            name=f"system_resources_{source_run.name}",
        )
        per_column_stats = process_run(source_run)

        for column_name, stats in per_column_stats.items():
            # "/" is swapped for "-" before the column name is embedded in the logged key.
            safe_column = column_name.replace("/", "-")
            payload = {
                f"system_resources/{safe_column}_{stat_name}": stats[stat_name]
                for stat_name in options.metrics
                if stat_name in stats  # requested names that were not computed are skipped
            }
            wandb.log(payload)

        wandb.finish()
109+
110+
111+
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)