Skip to content

Commit b5f7d86

Browse files
committed
add notebook runner
1 parent fc03395 commit b5f7d86

File tree

3 files changed

+281
-0
lines changed

3 files changed

+281
-0
lines changed

environment.yml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name: pymc-examples
2+
channels:
3+
- conda-forge
4+
dependencies:
5+
- python=3.11
6+
- pymc
7+
- pymc-bart
8+
- nutpie
9+
# spatial notebooks
10+
- geopandas
11+
- folium
12+
- libpysal
13+
- rasterio
14+
- pip:
15+
- pymc-experimental
16+
- preliz
17+
- bambi
18+
- jax
19+
- papermill
20+
- joblib
21+
- jupyter
22+
- seaborn
23+
- watermark
24+
- lifelines

scripts/run_notebooks/injected.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Injected code to the top of each notebook to mock long running code."""
2+
3+
import os
4+
import numpy as np
5+
import pymc as pm
6+
import xarray as xr
7+
8+
9+
def mock_sample(*args, **kwargs):
10+
if len(args) > 0:
11+
draws = args[0]
12+
else:
13+
draws = kwargs.get("draws", 1000)
14+
random_seed = kwargs.get("random_seed", None)
15+
rng = np.random.default_rng(random_seed)
16+
model = kwargs.get("model", None)
17+
chains = kwargs.get("chains", os.cpu_count())
18+
idata = pm.sample_prior_predictive(
19+
model=model,
20+
random_seed=random_seed,
21+
samples=draws,
22+
)
23+
n_chains = chains
24+
expanded_chains = xr.DataArray(
25+
np.ones(n_chains),
26+
coords={"chain": np.arange(n_chains)},
27+
)
28+
idata.add_groups(
29+
posterior=(idata.prior.mean("chain") * expanded_chains).transpose(
30+
"chain", "draw", ...
31+
)
32+
)
33+
if "prior" in idata:
34+
del idata.prior
35+
if "prior_predictive" in idata:
36+
del idata.prior_predictive
37+
38+
# Create mock sample stats with diverging data
39+
if "sample_stats" not in idata:
40+
n_chains = chains
41+
n_draws = draws
42+
sample_stats = xr.Dataset(
43+
{
44+
"diverging": xr.DataArray(
45+
np.zeros((n_chains, n_draws), dtype=int),
46+
dims=("chain", "draw"),
47+
),
48+
"energy": xr.DataArray(
49+
rng.normal(loc=150, scale=2.5, size=(n_chains, n_draws)),
50+
dims=("chain", "draw"),
51+
),
52+
"tree_depth": xr.DataArray(
53+
rng.choice(
54+
[1, 2, 3], p=[0.01, 0.86, 0.13], size=(n_chains, n_draws)
55+
),
56+
dims=("chain", "draw"),
57+
),
58+
"acceptance_rate": xr.DataArray(
59+
rng.beta(0.5, 0.5, size=(n_chains, n_draws)),
60+
dims=("chain", "draw"),
61+
),
62+
# Different sampler
63+
"accept": xr.DataArray(
64+
rng.choice([0, 1], size=(n_chains, n_draws)),
65+
dims=("chain", "draw"),
66+
),
67+
}
68+
)
69+
idata.add_groups(sample_stats=sample_stats)
70+
71+
return idata
72+
73+
74+
pm.sample = mock_sample
75+
pm.HalfFlat = pm.HalfNormal
76+
pm.Flat = pm.Normal

scripts/run_notebooks/runner.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""Script to run all notebooks in the docs/source/notebooks directory."""
2+
3+
from argparse import ArgumentParser
4+
5+
from rich.console import Console
6+
import logging
7+
from pathlib import Path
8+
from tempfile import NamedTemporaryFile
9+
from typing import TypedDict
10+
from uuid import uuid4
11+
12+
import papermill
13+
from joblib import Parallel, delayed
14+
from nbformat.notebooknode import NotebookNode
15+
from papermill.iorw import load_notebook_node, write_ipynb
16+
17+
KERNEL_NAME: str = "python3"
18+
19+
HERE = Path(__file__).parent
20+
INJECTED_CODE_FILE = HERE / "injected.py"
21+
INJECTED_CODE = INJECTED_CODE_FILE.read_text()
22+
23+
24+
def setup_logging() -> None:
25+
logging.basicConfig(
26+
level=logging.INFO,
27+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
28+
)
29+
30+
31+
def generate_random_id() -> str:
32+
return str(uuid4())
33+
34+
35+
def inject_pymc_sample_mock_code(cells: list) -> None:
36+
cells.insert(
37+
0,
38+
NotebookNode(
39+
id=f"code-injection-{generate_random_id()}",
40+
execution_count=sum(map(ord, "Mock pm.sample")),
41+
cell_type="code",
42+
metadata={"tags": []},
43+
outputs=[],
44+
source=INJECTED_CODE,
45+
),
46+
)
47+
48+
49+
def mock_run(notebook_path: Path, i: int, total: int) -> None:
50+
nb = load_notebook_node(str(notebook_path))
51+
inject_pymc_sample_mock_code(nb.cells)
52+
with NamedTemporaryFile(suffix=".ipynb") as f:
53+
write_ipynb(nb, f.name)
54+
desc = f"({i} / {total}) Mocked {notebook_path.name}"
55+
papermill.execute_notebook(
56+
input_path=f.name,
57+
output_path=None,
58+
progress_bar=dict(desc=desc),
59+
kernel_name=KERNEL_NAME,
60+
cwd=notebook_path.parent,
61+
)
62+
63+
64+
def actual_run(notebook_path: Path, i: int, total: int) -> None:
65+
papermill.execute_notebook(
66+
input_path=notebook_path,
67+
output_path=None,
68+
kernel_name=KERNEL_NAME,
69+
progress_bar={"desc": f"({i} / {total}) Running {notebook_path.name}"},
70+
cwd=notebook_path.parent,
71+
)
72+
73+
74+
class NotebookFailure(TypedDict):
75+
notebook_path: Path
76+
error: str
77+
78+
79+
def run_notebook(
80+
notebook_path: Path,
81+
i: int,
82+
total: int,
83+
mock: bool = True,
84+
) -> NotebookFailure | None:
85+
logging.info(f"Running notebook: {notebook_path.name}")
86+
run = mock_run if mock else actual_run
87+
88+
try:
89+
run(notebook_path, i=i, total=total)
90+
except Exception as e:
91+
logging.error(
92+
f"{e.__class__.__name__} encountered running notebook: {str(notebook_path)}"
93+
)
94+
return NotebookFailure(notebook_path=notebook_path, error=str(e))
95+
else:
96+
return
97+
98+
99+
class RunParams(TypedDict):
100+
notebook_path: Path
101+
mock: bool
102+
i: int
103+
total: int
104+
105+
106+
def run_parameters(notebook_paths: list[Path], mock: bool = True) -> list[RunParams]:
107+
def to_mock(notebook_path: Path, i: int) -> RunParams:
108+
return RunParams(
109+
notebook_path=notebook_path, mock=mock, i=i, total=len(notebook_paths)
110+
)
111+
112+
return [
113+
to_mock(notebook_path, i=i)
114+
for i, notebook_path in enumerate(notebook_paths, start=1)
115+
]
116+
117+
118+
def main(notebooks_to_run: list[Path], mock: bool = True) -> None:
119+
console = Console()
120+
errors: list[NotebookFailure]
121+
setup_logging()
122+
logging.info("Starting notebook runner")
123+
logging.info(f"Running {len(notebooks_to_run)} notebook(s).")
124+
results = Parallel(n_jobs=-1)(
125+
delayed(run_notebook)(**run_params)
126+
for run_params in run_parameters(notebooks_to_run, mock=mock)
127+
)
128+
errors = [result for result in results if result is not None]
129+
130+
if not errors:
131+
logging.info("Notebooks run successfully!")
132+
return
133+
134+
for error in errors:
135+
console.rule(f"[bold red]Error running {error['notebook_path']}[/bold red]")
136+
console.print(error["error"])
137+
138+
logging.error(f"{len(errors)} / {len(notebooks_to_run)} notebooks failed")
139+
140+
141+
def parse_args():
142+
parser = ArgumentParser()
143+
parser.add_argument(
144+
"--notebooks",
145+
nargs="+",
146+
help="List of notebooks to run. If not provided, all notebooks will be run.",
147+
)
148+
mock_group = parser.add_mutually_exclusive_group()
149+
mock_group.add_argument(
150+
"--mock",
151+
action="store_true",
152+
help="Run notebooks with mock code",
153+
dest="mock",
154+
)
155+
mock_group.add_argument(
156+
"--no-mock",
157+
action="store_false",
158+
help="Run notebooks without mock code",
159+
dest="mock",
160+
)
161+
parser.set_defaults(mock=True)
162+
args = parser.parse_args()
163+
164+
notebooks_to_run = []
165+
notebooks = args.notebooks
166+
notebooks = [Path(notebook) for notebook in notebooks]
167+
for notebook in notebooks:
168+
if notebook.is_dir():
169+
notebooks_to_run.extend(notebook.glob("*.ipynb"))
170+
notebooks_to_run.extend(notebook.glob("*/*.ipynb"))
171+
else:
172+
notebooks_to_run.append(notebook)
173+
174+
args.notebooks = notebooks_to_run
175+
176+
return args
177+
178+
179+
if __name__ == "__main__":
180+
args = parse_args()
181+
main(args.notebooks, mock=args.mock)

0 commit comments

Comments
 (0)