Skip to content

Commit 4fee32e

Browse files
authored
Merge pull request #5 from martiningram/add_mcmc_comparison
Add code for MCMC comparison
2 parents 2896a5d + b8cbd36 commit 4fee32e

15 files changed

+2978
-0
lines changed

mcmc_comparison/Compare runtimes -- multi-run.ipynb

Lines changed: 1138 additions & 0 deletions
Large diffs are not rendered by default.

mcmc_comparison/Compare runtimes.ipynb

Lines changed: 1056 additions & 0 deletions
Large diffs are not rendered by default.

mcmc_comparison/README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# MCMC benchmarks
2+
3+
![ESS](images/ess_values.png)
4+
5+
This code compares Stan, PyMC, and PyMC + JAX numpyro sampler on a model for
6+
tennis. It accompanies the blog post available
7+
[here](https://martiningram.github.io/mcmc-comparison/).
8+
9+
This is a copy of the repository here:
10+
https://github.com/martiningram/mcmc_runtime_comparison.
11+
12+
### Setup notes
13+
14+
This benchmark uses Jeff Sackmann's tennis data. You can obtain it as follows:
15+
16+
```
17+
git clone https://github.com/JeffSackmann/tennis_atp.git
18+
19+
# If you want to reproduce the results in the blog post, check out this commit:
20+
cd tennis_atp && git checkout 89c20f1ef56f69db1b73b5782671ee85203b068a
21+
```
22+
23+
Requirements that can be installed using pip are listed in
24+
`requirements.txt`. Please install these first.
25+
26+
Once these are done, here are the steps I followed to setup PyMC v4 with JAX support:
27+
28+
* PyMC v4 installed using the instructions here: https://github.com/pymc-devs/pymc/wiki/Installation-Guide-(Linux)#pymc-v4-installation
29+
* `blackjax` and `numpyro` were also installed using those instructions.
30+
31+
To run the Stan code, it's best to install `cmdstanpy`. Instructions for
32+
installing it can be found [here](https://mc-stan.org/cmdstanpy/installation.html).
33+
34+
### How to run
35+
36+
It's easiest to run the benchmarks using the `fit_all.sh` script. Make sure to
37+
first edit the `target_dir` variable in it and amend it to a directory that
38+
makes sense for you. All the model runs will be stored in it under
39+
subdirectories.
40+
41+
Once benchmarks have been run, you can analyse the results and make plots using
42+
the `Compare_runtimes.ipynb` notebook.
43+
44+
If you run into any problems, please raise an issue!

mcmc_comparison/fetch_data.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from sackmann import get_data
2+
import pymc as pm
3+
from sklearn.preprocessing import LabelEncoder
4+
import numpy as np
5+
6+
7+
def create_arrays(
8+
start_year=1960,
9+
data_dir="./tennis_atp",
10+
include_qualifying_and_challengers=False,
11+
include_futures=False,
12+
):
13+
14+
df = get_data(
15+
data_dir,
16+
include_qualifying_and_challengers=include_qualifying_and_challengers,
17+
include_futures=include_futures,
18+
)
19+
20+
rel_df = df[df["tourney_date"].dt.year >= start_year]
21+
22+
encoder = LabelEncoder()
23+
24+
encoder.fit(
25+
rel_df["winner_name"].values.tolist() + rel_df["loser_name"].values.tolist()
26+
)
27+
28+
winner_ids = encoder.transform(rel_df["winner_name"])
29+
loser_ids = encoder.transform(rel_df["loser_name"])
30+
31+
return {
32+
"winner_ids": winner_ids,
33+
"loser_ids": loser_ids,
34+
"player_encoder": encoder,
35+
}
36+
37+
38+
def get_pymc_model(start_year=1960, data_dir="./tennis_atp"):
39+
40+
arrays = create_arrays(start_year=start_year, data_dir=data_dir)
41+
42+
n_players = len(arrays["player_encoder"].classes_)
43+
44+
winner_ids = arrays["winner_ids"]
45+
loser_ids = arrays["loser_ids"]
46+
47+
with pm.Model() as model:
48+
49+
player_sd = pm.HalfNormal("player_sd", sigma=1.0)
50+
51+
player_skills_raw = pm.Normal(
52+
"player_skills_raw", 0.0, sigma=1.0, shape=(n_players,)
53+
)
54+
55+
player_skills = pm.Deterministic("player_skills", player_skills_raw * player_sd)
56+
logit_skills = player_skills[winner_ids] - player_skills[loser_ids]
57+
58+
lik = pm.Bernoulli(
59+
"win_lik", logit_p=logit_skills, observed=np.ones(winner_ids.shape[0])
60+
)
61+
62+
return model

mcmc_comparison/fit_all.sh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Modify as desired; will be created if it does not exist
2+
base_target_dir="/media/martin/External Drive/projects/pymc_vs_stan/multi_run/fits"
3+
n_runs=10
4+
5+
for cur_run in `seq 1 $n_runs`; do
6+
7+
echo "Running $cur_run"
8+
9+
random_seed=$cur_run
10+
target_dir="$base_target_dir"/"$cur_run"
11+
12+
for start_year in 2020 2019 2015 2010 2000 1990 1980 1968; do
13+
echo "Fitting $start_year"
14+
echo "PyMC JAX GPU parallel" && python fit_pymc_numpyro.py $start_year gpu parallel "$target_dir" $random_seed
15+
echo "PyMC JAX GPU vectorized" && python fit_pymc_numpyro.py $start_year gpu vectorized "$target_dir" $random_seed
16+
echo "PyMC JAX CPU parallel" && python fit_pymc_numpyro.py $start_year cpu parallel "$target_dir" $random_seed
17+
echo "PyMC JAX CPU vectorized" && python fit_pymc_numpyro.py $start_year cpu vectorized "$target_dir" $random_seed
18+
echo "PyMC BlackJAX CPU" && python fit_pymc_blackjax.py $start_year cpu "$target_dir" $random_seed parallel
19+
echo "PyMC BlackJAX GPU" && python fit_pymc_blackjax.py $start_year gpu "$target_dir" $random_seed vectorized
20+
echo "PyMC" && python fit_pymc.py $start_year "$target_dir" $random_seed
21+
echo "cmdstanpy" && python fit_cmdstanpy.py $start_year "$target_dir" $random_seed
22+
done
23+
24+
done

mcmc_comparison/fit_cmdstanpy.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import sys
2+
import os
3+
from fetch_data import create_arrays
4+
from time import time
5+
import numpy as np
6+
import arviz as az
7+
from cmdstanpy import CmdStanModel
8+
9+
start_year = int(sys.argv[1])
10+
target_dir = sys.argv[2] + "/cmdstanpy"
11+
seed = int(sys.argv[3])
12+
13+
os.makedirs(target_dir, exist_ok=True)
14+
15+
arrays = create_arrays(start_year=start_year)
16+
17+
start_time = time()
18+
19+
winner_ids = arrays["winner_ids"]
20+
loser_ids = arrays["loser_ids"]
21+
player_encoder = arrays["player_encoder"]
22+
23+
stan_data = {
24+
"n_matches": len(winner_ids),
25+
"n_players": len(player_encoder.classes_),
26+
"winner_ids": winner_ids + 1,
27+
"loser_ids": loser_ids + 1,
28+
}
29+
30+
model = CmdStanModel(stan_file="stan_model_optimised.stan")
31+
model.compile()
32+
33+
fit = model.sample(data=stan_data, parallel_chains=4, seed=seed)
34+
35+
runtime = time() - start_time
36+
37+
arviz_version = az.from_cmdstanpy(posterior=fit)
38+
39+
az.to_netcdf(arviz_version, os.path.join(target_dir, f"samples_{start_year}.netcdf"))
40+
print(runtime, file=open(os.path.join(target_dir, f"runtime_{start_year}.txt"), "w"))

mcmc_comparison/fit_pymc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import sys
2+
import os
3+
from fetch_data import get_pymc_model
4+
from time import time
5+
import pymc as pm
6+
7+
start_year = int(sys.argv[1])
8+
target_dir = sys.argv[2] + "/pymc"
9+
seed = int(sys.argv[3])
10+
11+
os.makedirs(target_dir, exist_ok=True)
12+
13+
model = get_pymc_model(start_year=start_year)
14+
15+
start_time = time()
16+
17+
with model:
18+
hierarchical_trace = pm.sample(
19+
1000,
20+
tune=1000,
21+
return_inferencedata=True,
22+
compute_convergence_checks=False,
23+
random_seed=seed,
24+
)
25+
26+
runtime = time() - start_time
27+
28+
hierarchical_trace.to_netcdf(os.path.join(target_dir, f"samples_{start_year}.netcdf"))
29+
print(runtime, file=open(os.path.join(target_dir, f"runtime_{start_year}.txt"), "w"))
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import sys
2+
import os
3+
import pymc.sampling_jax
4+
from fetch_data import get_pymc_model
5+
from time import time
6+
import pymc as pm
7+
8+
start_year = int(sys.argv[1])
9+
platform = sys.argv[2]
10+
base_dir = sys.argv[3]
11+
seed = int(sys.argv[4])
12+
chain_method = sys.argv[5]
13+
14+
assert platform in ["cpu", "gpu"]
15+
16+
if platform == "cpu":
17+
# Disable GPU
18+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
19+
20+
target_dir = f"{base_dir}/pymc_blackjax_{platform}_{chain_method}"
21+
22+
os.makedirs(target_dir, exist_ok=True)
23+
24+
model = get_pymc_model(start_year=start_year)
25+
26+
start_time = time()
27+
28+
with model:
29+
# No progress bar?
30+
hierarchical_trace = pymc.sampling_jax.sample_blackjax_nuts(
31+
random_seed=seed, chain_method=chain_method,
32+
idata_kwargs={'log_likelihood': False})
33+
34+
runtime = time() - start_time
35+
36+
hierarchical_trace.to_netcdf(os.path.join(target_dir, f"samples_{start_year}.netcdf"))
37+
print(runtime, file=open(os.path.join(target_dir, f"runtime_{start_year}.txt"), "w"))
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import sys
2+
import os
3+
import pymc.sampling_jax
4+
from fetch_data import get_pymc_model
5+
from time import time
6+
import pymc as pm
7+
8+
start_year = int(sys.argv[1])
9+
platform = sys.argv[2]
10+
chain_method = sys.argv[3]
11+
base_dir = sys.argv[4]
12+
seed = int(sys.argv[5])
13+
14+
assert platform in ["cpu", "gpu"]
15+
16+
if platform == "cpu":
17+
# Disable GPU
18+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
19+
20+
target_dir = f"{base_dir}/pymc_numpyro_{platform}_{chain_method}"
21+
22+
os.makedirs(target_dir, exist_ok=True)
23+
24+
model = get_pymc_model(start_year=start_year)
25+
26+
start_time = time()
27+
28+
with model:
29+
hierarchical_trace = pymc.sampling_jax.sample_numpyro_nuts(
30+
chain_method=chain_method, random_seed=seed,
31+
idata_kwargs={'log_likelihood': False}
32+
)
33+
34+
runtime = time() - start_time
35+
36+
hierarchical_trace.to_netcdf(os.path.join(target_dir, f"samples_{start_year}.netcdf"))
37+
print(runtime, file=open(os.path.join(target_dir, f"runtime_{start_year}.txt"), "w"))

mcmc_comparison/fit_stan.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import sys
2+
import os
3+
from fetch_data import create_arrays
4+
from time import time
5+
import stan
6+
import numpy as np
7+
import arviz as az
8+
9+
start_year = int(sys.argv[1])
10+
target_dir = sys.argv[2] + "/stan"
11+
12+
os.makedirs(target_dir, exist_ok=True)
13+
14+
arrays = create_arrays(start_year=start_year)
15+
16+
start_time = time()
17+
18+
winner_ids = arrays["winner_ids"]
19+
loser_ids = arrays["loser_ids"]
20+
player_encoder = arrays["player_encoder"]
21+
22+
stan_data = {
23+
"n_matches": len(winner_ids),
24+
"n_players": len(player_encoder.classes_),
25+
"winner_ids": winner_ids + 1,
26+
"loser_ids": loser_ids + 1,
27+
}
28+
29+
with open("./stan_model.stan", "r") as f:
30+
posterior = stan.build(program_code=f.read(), data=stan_data)
31+
32+
fit = posterior.sample(num_chains=4, num_samples=1000)
33+
34+
runtime = time() - start_time
35+
36+
arviz_version = az.from_pystan(fit)
37+
38+
az.to_netcdf(arviz_version, os.path.join(target_dir, f"samples_{start_year}.netcdf"))
39+
print(runtime, file=open(os.path.join(target_dir, f"runtime_{start_year}.txt"), "w"))

0 commit comments

Comments
 (0)