Skip to content

Commit d735523

Browse files
authored
explicit num_warmup, num_samples arguments for MCMC (#1040)
* request for explicit num_warmup, num_samples * force num_warmup num_samples in all tests * ping nbsphinx version * fix remaining bugs * run make format * revert docs requirement * fix wrong changes
1 parent 9404041 commit d735523

35 files changed

+165
-126
lines changed

examples/annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ def main(args):
304304

305305
mcmc = MCMC(
306306
NUTS(model),
307-
args.num_warmup,
308-
args.num_samples,
307+
num_warmup=args.num_warmup,
308+
num_samples=args.num_samples,
309309
num_chains=args.num_chains,
310310
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
311311
)

examples/baseball.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def run_inference(model, at_bats, hits, rng_key, args):
145145
kernel = SA(model)
146146
mcmc = MCMC(
147147
kernel,
148-
args.num_warmup,
149-
args.num_samples,
148+
num_warmup=args.num_warmup,
149+
num_samples=args.num_samples,
150150
num_chains=args.num_chains,
151151
progress_bar=False
152152
if ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar)

examples/bnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def run_inference(model, args, rng_key, X, Y, D_H):
7676
kernel = NUTS(model)
7777
mcmc = MCMC(
7878
kernel,
79-
args.num_warmup,
80-
args.num_samples,
79+
num_warmup=args.num_warmup,
80+
num_samples=args.num_samples,
8181
num_chains=args.num_chains,
8282
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
8383
)

examples/capture_recapture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ def run_inference(model, capture_history, sex, rng_key, args):
288288
kernel = HMC(model)
289289
mcmc = MCMC(
290290
kernel,
291-
args.num_warmup,
292-
args.num_samples,
291+
num_warmup=args.num_warmup,
292+
num_samples=args.num_samples,
293293
num_chains=args.num_chains,
294294
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
295295
)

examples/covtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def benchmark_hmc(args, features, labels):
192192
)
193193
else:
194194
raise ValueError("Invalid algorithm, either 'HMC', 'NUTS', or 'HMCECS'.")
195-
mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
195+
mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
196196
mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
197197
print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
198198
mcmc.print_summary(exclude_deterministic=False)

examples/funnel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def run_inference(model, args, rng_key):
5353
kernel = NUTS(model)
5454
mcmc = MCMC(
5555
kernel,
56-
args.num_warmup,
57-
args.num_samples,
56+
num_warmup=args.num_warmup,
57+
num_samples=args.num_samples,
5858
num_chains=args.num_chains,
5959
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
6060
)

examples/gp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def run_inference(model, args, rng_key, X, Y):
8585
kernel = NUTS(model, init_strategy=init_strategy)
8686
mcmc = MCMC(
8787
kernel,
88-
args.num_warmup,
89-
args.num_samples,
88+
num_warmup=args.num_warmup,
89+
num_samples=args.num_samples,
9090
num_chains=args.num_chains,
9191
thinning=args.thinning,
9292
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,

examples/hmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ def main(args):
221221
kernel = NUTS(semi_supervised_hmm)
222222
mcmc = MCMC(
223223
kernel,
224-
args.num_warmup,
225-
args.num_samples,
224+
num_warmup=args.num_warmup,
225+
num_samples=args.num_samples,
226226
num_chains=args.num_chains,
227227
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
228228
)

examples/hmm_enum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,9 @@ def main(args):
326326
kernel = {"nuts": NUTS, "hmc": HMC}[args.kernel](model)
327327
mcmc = MCMC(
328328
kernel,
329-
args.num_warmup,
330-
args.num_samples,
331-
args.num_chains,
329+
num_warmup=args.num_warmup,
330+
num_samples=args.num_samples,
331+
num_chains=args.num_chains,
332332
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
333333
)
334334
mcmc.run(rng_key, sequences, lengths, args=args)

examples/neutra.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def main(args):
6464
nuts_kernel = NUTS(dual_moon_model)
6565
mcmc = MCMC(
6666
nuts_kernel,
67-
args.num_warmup,
68-
args.num_samples,
67+
num_warmup=args.num_warmup,
68+
num_samples=args.num_samples,
6969
num_chains=args.num_chains,
7070
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
7171
)
@@ -91,8 +91,8 @@ def main(args):
9191
nuts_kernel = NUTS(neutra_model)
9292
mcmc = MCMC(
9393
nuts_kernel,
94-
args.num_warmup,
95-
args.num_samples,
94+
num_warmup=args.num_warmup,
95+
num_samples=args.num_samples,
9696
num_chains=args.num_chains,
9797
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
9898
)

0 commit comments

Comments
 (0)