Skip to content

Commit 1ec988a

Browse files
committed
Fix in nnflow docs
1 parent 9c0f04c commit 1ec988a

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

book.toml

Lines changed: 0 additions & 6 deletions
This file was deleted.

docs/nf-adapt.qmd

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@ normal distribution. The flow is trained during warmup.
1717

1818
For more information about the algorithm, see the paper todo
1919

20+
Currently, a lot of time is spent on compiling various parts of the normalizing
21+
flow, and for small models this can take a large amount of the total time.
22+
Hopefully, we will be able to reduce this overhead in the future.
23+
2024
## Requirements
2125

2226
Install the optional dependencies for normalizing flow adaptation:
2327

2428
```
25-
pip install 'nutpie[flow]'
29+
pip install 'nutpie[nnflow]'
2630
```
2731

2832
If you use with PyMC, this will only work if the model is compiled using the jax
@@ -50,34 +54,39 @@ it to sample from a difficult posterior:
5054
import pymc as pm
5155
import nutpie
5256
import numpy as np
57+
import arviz
5358
5459
# Define a 100-dimensional funnel model
5560
with pm.Model() as model:
5661
log_sigma = pm.Normal("log_sigma")
57-
x = pm.Normal("x", mu=0, sigma=pm.math.exp(y / 2), shape=100)
62+
pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100)
5863
5964
# Compile the model with the jax backend
6065
compiled = nutpie.compile_pymc_model(
6166
model, backend="jax", gradient_backend="jax"
6267
)
6368
```
6469

65-
If we sample this model without normalizing flow adaptation, we may encounter
66-
divergences and don't recover the actual posterior distribution:
70+
If we sample this model without normalizing flow adaptation, we will encounter
71+
convergence issues, often divergences and always low effective sample sizes:
6772

6873
```{python}
6974
# Sample without normalizing flow adaptation
70-
trace_no_nf = nutpie.sample(compiled_no_nf, seed=1)
71-
assert trace_no_nf.sample_stats.diverging.sum() > 0
75+
trace_no_nf = nutpie.sample(compiled, seed=1)
76+
assert (arviz.ess(trace_no_nf) < 100).any().to_array().any()
7277
```
7378

7479
```{python}
7580
# We can add further arguments for the normalizing flow:
76-
compiled = compiled.with_transform_adapt(num_layers=9)
81+
compiled = compiled.with_transform_adapt(
82+
num_layers=5, # Use 5 layers in the normalizing flow
83+
nn_width=32, # Use neural networks with 32 hidden units
84+
)
7785
7886
# Sample with normalizing flow adaptation
79-
trace_nf = nutpie.sample(compiled, transform_adapt=True, seed=1)
87+
trace_nf = nutpie.sample(compiled, transform_adapt=True, seed=1, chains=2, cores=1)
8088
assert trace_no_nf.sample_stats.diverging.sum() == 0
89+
assert (arviz.ess(trace_no_nf) > 500).all().to_array().all()
8190
```
8291

8392
The flow adaptation occurs during warmup, so the number of warmup draws should

0 commit comments

Comments
 (0)