normal distribution. The flow is trained during warmup.

For more information about the algorithm, see the paper (TODO).

Currently, a significant amount of time is spent compiling the various parts of
the normalizing flow, and for small models this compilation can account for a
large share of the total sampling time. We hope to reduce this overhead in the
future.

## Requirements

Install the optional dependencies for normalizing flow adaptation:

```
pip install 'nutpie[nnflow]'
```

If you use nutpie with PyMC, this will only work if the model is compiled using the jax backend.

it to sample from a difficult posterior:

```{python}
import pymc as pm
import nutpie
import numpy as np
import arviz

# Define a 100-dimensional funnel model
with pm.Model() as model:
    log_sigma = pm.Normal("log_sigma")
    pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100)

# Compile the model with the jax backend
compiled = nutpie.compile_pymc_model(
    model, backend="jax", gradient_backend="jax"
)
```

If we sample this model without normalizing flow adaptation, we will run into
convergence issues, typically divergences and always low effective sample sizes:

```{python}
# Sample without normalizing flow adaptation
trace_no_nf = nutpie.sample(compiled, seed=1)
assert (arviz.ess(trace_no_nf) < 100).any().to_array().any()
```

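To see what goes wrong in more detail, we can look at the usual convergence
diagnostics. This is plain ArviZ usage and not specific to nutpie, using the
`trace_no_nf` trace from the previous block:

```{python}
# Total number of divergent transitions across all chains
print(int(trace_no_nf.sample_stats.diverging.sum()))

# Effective sample size and R-hat for the funnel's scale parameter
print(arviz.summary(trace_no_nf, var_names=["log_sigma"]))
```
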
```{python}
# We can add further arguments for the normalizing flow:
compiled = compiled.with_transform_adapt(
    num_layers=5,  # Use 5 layers in the normalizing flow
    nn_width=32,  # Use neural networks with 32 hidden units
)

# Sample with normalizing flow adaptation
trace_nf = nutpie.sample(compiled, transform_adapt=True, seed=1, chains=2, cores=1)
assert trace_nf.sample_stats.diverging.sum() == 0
assert (arviz.ess(trace_nf) > 500).all().to_array().all()
```

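For a direct comparison, we can look at the smallest effective sample size
across all parameters for both runs. Again, this is just standard ArviZ usage
with the traces defined above:

```{python}
# Compare the worst-case effective sample size with and without flow adaptation
ess_no_nf = arviz.ess(trace_no_nf).to_array().min()
ess_nf = arviz.ess(trace_nf).to_array().min()
print(f"min ESS without flow adaptation: {float(ess_no_nf):.0f}")
print(f"min ESS with flow adaptation: {float(ess_nf):.0f}")
```
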
The flow adaptation occurs during warmup, so the number of warmup draws should