Skip to content

Commit cc5edb1

Browse files
committed
feat: add block production graph script and deps
Add `block_production_graph.py` script to generate a bar chart of total blocks per backend from an SQLite database. Update dependencies in `pyproject.toml` and `poetry.lock` to include matplotlib and seaborn, enabling statistical data visualization. Register the script as a CLI entry point under `block-production-graph`.
1 parent ecf6c12 commit cc5edb1

File tree

3 files changed

+881
-1
lines changed

3 files changed

+881
-1
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#!/usr/bin/env python3
2+
"""Generate a bar chart of total blocks per backend from an SQLite database.
3+
4+
The script retrieves the latest run_id from the runs table, aggregates the total
5+
blocks per backend from the blocks table, and generates a bar chart saved as an image file.
6+
"""
7+
8+
import argparse
9+
import contextlib
10+
import pathlib as pl
11+
import sqlite3
12+
import sys
13+
14+
import matplotlib.container as mcontainer
15+
import matplotlib.pyplot as plt
16+
import pandas as pd
17+
import seaborn as sns
18+
19+
20+
def get_latest_run_id(conn: sqlite3.Connection) -> str:
    """Fetch the run_id of the most recent entry in the runs table.

    "Most recent" is assumed to mean the row with the highest rowid,
    i.e. the row that was inserted last.

    Args:
        conn: Open SQLite connection.

    Returns:
        The run_id of that row, converted to ``str``.

    Raises:
        RuntimeError: If the runs table contains no rows.
    """
    latest = conn.execute("SELECT run_id FROM runs ORDER BY rowid DESC LIMIT 1;").fetchone()
    if latest is not None:
        return str(latest[0])

    err = "No runs found in 'runs' table."
    raise RuntimeError(err)
32+
33+
34+
def get_blocks_per_backend(conn: sqlite3.Connection, *, run_id: str) -> list[tuple[str, int]]:
    """Return a list of (backend, total_blocks) for the given run_id.

    Aggregates num_blocks over every row of the blocks table recorded for the
    run (i.e. across all epochs and pools), grouped by backend.

    Rows are ordered by total descending; backends with equal totals are ordered
    by name so the result is deterministic. A backend whose num_blocks values are
    all NULL is reported with a total of 0 instead of None.

    Args:
        conn: Open SQLite connection.
        run_id: Identifier of the run to aggregate.

    Returns:
        (backend, total_blocks) rows, largest total first.

    Raises:
        RuntimeError: If the blocks table has no rows for ``run_id``.
    """
    cur = conn.cursor()
    # SUM() yields NULL when every input is NULL; COALESCE keeps the total an int.
    # The secondary sort key fixes the order of backends that tie on total_blocks.
    cur.execute(
        """
        SELECT backend, COALESCE(SUM(num_blocks), 0) AS total_blocks
        FROM blocks
        WHERE run_id = ?
        GROUP BY backend
        ORDER BY total_blocks DESC, backend ASC;
        """,
        (run_id,),
    )
    rows = cur.fetchall()
    if not rows:
        err = f"No block data found in 'blocks' table for run_id={run_id}."
        raise RuntimeError(err)
    return rows
55+
56+
57+
def plot_backend_blocks(
    backend_data: list[tuple[str, int]], *, run_name: str, output_path: pl.Path
) -> None:
    """Plot a bar chart of total blocks per backend and save it as an image.

    Args:
        backend_data: (backend, total_blocks) pairs to plot, one bar per pair.
        run_name: Run label shown in the chart title.
        output_path: Destination image file, handed to ``Figure.savefig``.
    """
    backends = [b for b, _ in backend_data]
    totals = [t for _, t in backend_data]
    df = pd.DataFrame({"backend": backends, "total_blocks": totals})

    sns.set_theme(style="whitegrid")

    # Work on an explicit figure/axes pair instead of pyplot's implicit "current"
    # figure, and release it in ``finally`` so a failure while drawing or saving
    # does not leave the figure open.
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        sns.barplot(data=df, x="backend", y="total_blocks", ax=ax)

        ax.set_xlabel("Backend")
        ax.set_ylabel("Total blocks over run")
        ax.set_title(f"Total blocks per backend in run {run_name}")

        # Annotate bars with values (type-narrow containers to BarContainer)
        for c in ax.containers:
            if isinstance(c, mcontainer.BarContainer):
                ax.bar_label(c)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        plt.close(fig)
82+
83+
84+
def parse_args() -> argparse.Namespace:
    """Parse the script's command-line options from ``sys.argv``.

    Three options are defined, all of them required:
    ``-d/--dbpath``, ``-n/--name`` and ``-o/--output``.

    Returns:
        Namespace with the attributes ``dbpath``, ``name`` and ``output``.
    """
    description = (
        "Generate a bar chart of total blocks per backend for the latest run "
        "from an SQLite database."
    )
    # (short flag, long flag, help text) for every mandatory option.
    required_options = (
        ("-d", "--dbpath", "Path to the SQLite database file."),
        ("-n", "--name", "Name of the run (for labeling purposes)."),
        ("-o", "--output", "Output image filename."),
    )

    parser = argparse.ArgumentParser(description=description)
    for short_flag, long_flag, help_text in required_options:
        parser.add_argument(short_flag, long_flag, required=True, help=help_text)
    return parser.parse_args()
110+
111+
112+
def main() -> int:
    """Run the CLI: read the latest run from the database and save the chart.

    Returns:
        Process exit status: 0 on success; 1 when the database file is missing,
        a query fails or yields no data, or the chart cannot be written.
    """
    args = parse_args()
    dbpath = pl.Path(args.dbpath)
    output_path = pl.Path(args.output)

    # sqlite3.connect() creates a missing database file, so check up front.
    if not dbpath.exists():
        print(f"Error: database file '{args.dbpath}' does not exist.", file=sys.stderr)
        return 1

    try:
        with contextlib.closing(sqlite3.connect(dbpath)) as conn, conn:
            run_id = get_latest_run_id(conn)
            backend_data = get_blocks_per_backend(conn, run_id=run_id)
    except (sqlite3.Error, RuntimeError) as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # Rendering needs no database access, so it runs after the connection is closed.
    try:
        plot_backend_blocks(
            backend_data=backend_data, run_name=args.name, output_path=output_path
        )
    except (OSError, ValueError, RuntimeError) as e:
        # OSError: output location missing or not writable.
        # ValueError: matplotlib rejects an unsupported output format.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print(f"Saved graph to {args.output}")
    return 0
134+
135+
136+
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

0 commit comments

Comments
 (0)