|
1 | 1 | import inspect |
| 2 | +import pickle |
| 3 | +from pathlib import Path |
2 | 4 |
|
| 5 | +import jax |
| 6 | +import jax.numpy as jnp |
3 | 7 | import pandas as pd |
4 | 8 |
|
5 | 9 | import numpyro |
6 | 10 | from numpyro import infer |
7 | 11 | from numpyro import distributions as dist |
8 | 12 |
|
9 | 13 | from .. import prior |
| 14 | +from ..pulsar import save_chain |
10 | 15 |
|
11 | 16 |
|
12 | 17 | def makemodel_transformed(mylogl, transform=prior.makelogtransform_uniform, priordict={}): |
@@ -48,3 +53,100 @@ def makesampler_nuts(numpyro_model, num_warmup=512, num_samples=1024, num_chains |
48 | 53 | sampler.to_df = lambda: numpyro_model.to_df(sampler.get_samples()) |
49 | 54 |
|
50 | 55 | return sampler |
| 56 | + |
| 57 | +def run_nuts_with_checkpoints( |
| 58 | + sampler, |
| 59 | + num_samples_per_checkpoint, |
| 60 | + rng_key, |
| 61 | + outdir="chains", |
| 62 | + resume=False, |
| 63 | +): |
| 64 | + """Run NumPyro MCMC and save checkpoints. |
| 65 | +
|
| 66 | + This function performs multiple iterations of MCMC sampling, saving checkpoints |
| 67 | + after each iteration. It saves samples to feather files and the NumPyro MCMC |
| 68 | + state to JSON. |
| 69 | +
|
| 70 | + Parameters |
| 71 | + ---------- |
| 72 | + sampler : numpyro.infer.MCMC |
| 73 | + A NumPyro MCMC sampler object. |
| 74 | + num_samples_per_checkpoint : int |
| 75 | + The number of samples to save in each checkpoint. |
| 76 | + rng_key : jax.random.PRNGKey |
| 77 | + The random number generator key for JAX. |
| 78 | + outdir : str | Path |
| 79 | + The directory for output files. |
| 80 | + resume : bool |
| 81 | + Whether to look for a state to resume from. |
| 82 | +
|
| 83 | + Returns |
| 84 | + ------- |
| 85 | + None |
| 86 | + This function doesn't return any value but saves the results to disk. |
| 87 | +
|
| 88 | + Side Effects |
| 89 | + ------------ |
| 90 | + - Runs the MCMC sampler for the number of iterations required to reach the total sample number. |
| 91 | + - Saves samples data to feather files after each iteration. |
| 92 | + - Writes the NumPyro sampler state to a pickle file after each iteration. |
| 93 | +
|
| 94 | + Example |
| 95 | + ------- |
| 96 | + >>> import discovery.samplers.numpyro as ds_numpyro |
| 97 | + >>> # Assume `model` is configured |
| 98 | + >>> npsampler = ds_numpyro.makesampler_nuts(model, num_samples =100, num_warmup=50) |
| 99 | + >>> ds_numpyro.run_nuts_with_checkpoints(npsampler, 10, jax.random.key(42)) |
| 100 | +
|
| 101 | + """ |
| 102 | + # convert to pathlib object |
| 103 | + # make directory if it doesn't exist |
| 104 | + if not isinstance(outdir, Path): |
| 105 | + outdir = Path(outdir) |
| 106 | + outdir.mkdir(exist_ok=True, parents=True) |
| 107 | + |
| 108 | + samples_file = outdir / "numpyro-samples.feather" |
| 109 | + checkpoint_file = outdir / "numpyro-checkpoint.pickle" |
| 110 | + |
| 111 | + if checkpoint_file.is_file() and samples_file.is_file() and resume: |
| 112 | + df = pd.read_feather(samples_file) |
| 113 | + num_samples_saved = df.shape[0] |
| 114 | + |
| 115 | + with checkpoint_file.open("rb") as f: |
| 116 | + checkpoint = pickle.load(f) |
| 117 | + |
| 118 | + total_sample_num = sampler.num_samples - num_samples_saved |
| 119 | + |
| 120 | + sampler.post_warmup_state = checkpoint |
| 121 | + |
| 122 | + else: |
| 123 | + df = None |
| 124 | + num_samples_saved = 0 |
| 125 | + total_sample_num = sampler.num_samples |
| 126 | + |
| 127 | + num_checkpoints = int(jnp.ceil(total_sample_num / num_samples_per_checkpoint)) |
| 128 | + remainder_samples = int(total_sample_num % num_samples_per_checkpoint) |
| 129 | + |
| 130 | + for checkpoint in range(num_checkpoints): |
| 131 | + if checkpoint == 0: |
| 132 | + sampler.num_samples = num_samples_per_checkpoint |
| 133 | + sampler._set_collection_params() # Need this to update num_samples |
| 134 | + elif checkpoint == num_checkpoints - 1: |
| 135 | + # We won't need to update the collection params because we've set the post warmup state, |
| 136 | + # and that accomplishes the same goal. |
| 137 | + sampler.num_samples = remainder_samples if remainder_samples != 0 else num_samples_per_checkpoint |
| 138 | + |
| 139 | + sampler.run(rng_key) |
| 140 | + |
| 141 | + df_new = sampler.to_df() |
| 142 | + |
| 143 | + df = pd.concat([df, df_new]) if df is not None else df_new |
| 144 | + |
| 145 | + save_chain(df, samples_file) |
| 146 | + |
| 147 | + with checkpoint_file.open("wb") as f: |
| 148 | + pickle.dump(sampler.last_state, f) |
| 149 | + |
| 150 | + sampler.post_warmup_state = sampler.last_state |
| 151 | + |
| 152 | + rng_key, _ = jax.random.split(rng_key) |
0 commit comments