Skip to content

Commit cea28b2

Browse files
committed
feat: add numpyro checkpointing
also allows users to resume sampling
1 parent a30bac9 commit cea28b2

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

src/discovery/samplers/numpyro.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import inspect
2+
import pickle
3+
from pathlib import Path
24

5+
import jax
6+
import jax.numpy as jnp
37
import pandas as pd
48

59
import numpyro
610
from numpyro import infer
711
from numpyro import distributions as dist
812

913
from .. import prior
14+
from ..pulsar import save_chain
1015

1116

1217
def makemodel_transformed(mylogl, transform=prior.makelogtransform_uniform, priordict={}):
@@ -48,3 +53,100 @@ def makesampler_nuts(numpyro_model, num_warmup=512, num_samples=1024, num_chains
4853
sampler.to_df = lambda: numpyro_model.to_df(sampler.get_samples())
4954

5055
return sampler
56+
57+
def run_nuts_with_checkpoints(
    sampler,
    num_samples_per_checkpoint,
    rng_key,
    outdir="chains",
    resume=False,
):
    """Run NumPyro MCMC in chunks, saving a checkpoint after each chunk.

    The total number of samples requested on ``sampler`` is drawn in chunks of
    at most ``num_samples_per_checkpoint``. After every chunk the accumulated
    samples are written to ``<outdir>/numpyro-samples.feather`` and the sampler
    state is pickled to ``<outdir>/numpyro-checkpoint.pickle``, so that an
    interrupted run can be continued with ``resume=True``.

    Parameters
    ----------
    sampler : numpyro.infer.MCMC
        A NumPyro MCMC sampler object. It must expose ``to_df()`` (as the
        samplers built by ``makesampler_nuts`` do). ``sampler.num_samples`` is
        taken as the total number of samples wanted.
    num_samples_per_checkpoint : int
        The maximum number of samples drawn between two checkpoints.
    rng_key : jax.random.PRNGKey
        The random number generator key for JAX. It is split after each chunk.
    outdir : str | Path
        The directory for output files. Created if it does not exist.
    resume : bool
        If True and both the samples file and the checkpoint file exist in
        ``outdir``, continue from them and only draw the samples still missing.
        Otherwise start from scratch.

    Returns
    -------
    None
        This function doesn't return any value but saves the results to disk.

    Raises
    ------
    ValueError
        If ``num_samples_per_checkpoint`` is smaller than 1.

    Side Effects
    ------------
    - Runs the MCMC sampler as many times as needed to reach the total sample number.
    - Saves the accumulated samples to a feather file after each chunk.
    - Writes the NumPyro sampler state to a pickle file after each chunk.
    - Leaves ``sampler.post_warmup_state`` set to the last state; restores
      ``sampler.num_samples`` to the value it had on entry.

    Example
    -------
    >>> import discovery.samplers.numpyro as ds_numpyro
    >>> # Assume `model` is configured
    >>> npsampler = ds_numpyro.makesampler_nuts(model, num_samples=100, num_warmup=50)
    >>> ds_numpyro.run_nuts_with_checkpoints(npsampler, 10, jax.random.key(42))

    """
    # Validate before touching the filesystem so a bad call has no side effects.
    if num_samples_per_checkpoint < 1:
        raise ValueError(
            f"num_samples_per_checkpoint must be >= 1, got {num_samples_per_checkpoint}"
        )

    # Path() accepts both str and Path; make the directory if it doesn't exist.
    outdir = Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    samples_file = outdir / "numpyro-samples.feather"
    checkpoint_file = outdir / "numpyro-checkpoint.pickle"

    # Remember the caller's configuration: num_samples is overwritten with the
    # chunk size below and is put back once we are done.
    total_requested = sampler.num_samples

    if resume and checkpoint_file.is_file() and samples_file.is_file():
        df = pd.read_feather(samples_file)

        # NOTE(security): pickle can execute arbitrary code on load; only
        # resume from checkpoint files that this function wrote itself.
        with checkpoint_file.open("rb") as f:
            sampler.post_warmup_state = pickle.load(f)

        # NOTE(review): assumes the saved frame has one row per sample; if
        # to_df() stacks several chains this count needs adjusting -- confirm.
        remaining = total_requested - df.shape[0]
    else:
        df = None
        remaining = total_requested

    first_chunk = True
    try:
        while remaining > 0:
            # Never draw more than what is still missing. This also covers the
            # case where everything fits in a single (first == last) chunk.
            chunk_size = min(num_samples_per_checkpoint, remaining)
            sampler.num_samples = chunk_size

            if first_chunk:
                # Need this to update num_samples for the very first run.
                # Later runs don't need it because the post warmup state is
                # set, and that accomplishes the same goal.
                sampler._set_collection_params()
                first_chunk = False

            sampler.run(rng_key)

            df_new = sampler.to_df()
            # ignore_index gives the accumulated chain a clean 0..N-1 index
            # instead of repeating each chunk's own row labels.
            df = df_new if df is None else pd.concat([df, df_new], ignore_index=True)

            save_chain(df, samples_file)

            with checkpoint_file.open("wb") as f:
                pickle.dump(sampler.last_state, f)

            # Continue the next chunk from where this one stopped (no warmup).
            sampler.post_warmup_state = sampler.last_state

            rng_key, _ = jax.random.split(rng_key)
            remaining -= chunk_size
    finally:
        # Hand the sampler back with its original sample count, even on error.
        sampler.num_samples = total_requested
        sampler._set_collection_params()

0 commit comments

Comments
 (0)