
Commit 2d5b9c5

update docs
1 parent b8b229b commit 2d5b9c5

14 files changed: +448 additions, −239 deletions

docs/make.jl

Lines changed: 6 additions & 1 deletion
@@ -15,8 +15,13 @@ makedocs(;
     format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"),
     pages=[
         "Home" => "index.md",
+        "General usage" => "usage.md",
         "API" => "api.md",
-        "Example" => "example.md",
+        "Example" => [
+            "Planar Flow" => "PlanarFlow.md",
+            "RealNVP" => "RealNVP.md",
+            "Neural Spline Flow" => "NSF.md",
+        ],
         "Customize your own flow layer" => "customized_layer.md",
     ],
     checkdocs=:exports,

docs/src/NSF.md

Lines changed: 47 additions & 0 deletions
# Demo of NSF on 2D Banana Distribution

```julia
using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Zygote
using NormalizingFlows

T = Float32  # element type for the example (assumed here; Float64 also works)

# the 2D banana target is defined in `example/targets/banana.jl`
target = Banana(2, one(T), 100one(T))
logp = Base.Fix1(logpdf, target)

######################################
# learn the target using Neural Spline Flow
######################################
@leaf MvNormal  # keep the base distribution's parameters fixed during training
q0 = MvNormal(zeros(T, 2), I)

flow = nsf(q0; paramtype=T)
flow_untrained = deepcopy(flow)

######################################
# start training
######################################
sample_per_iter = 64

# nsf only supports AutoZygote
adtype = ADTypes.AutoZygote()
# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# convergence check: stop once the gradient norm drops below 1e-3
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000

flow_trained, stats, _ = train_flow(
    elbo_batch,
    flow,
    logp,
    sample_per_iter;
    max_iters=10, # change to a larger number of iterations (e.g., 50_000) for better results
    optimiser=Optimisers.Adam(1e-4),
    ADbackend=adtype,
    show_progress=true,
    callback=cb,
    hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)
```
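
As in the planar-flow example, the trained object is a `Bijectors.TransformedDistribution`, so sampling and density evaluation follow the usual `Distributions.jl` interface. A short sketch reusing the variables defined above (`Plots` is assumed to be available):

```julia
using Plots

n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)   # draw samples from the trained flow
samples_true = rand(target, n_samples)            # draw samples from the target for comparison
logpdf(flow_trained, samples_true[:, 1])          # evaluate the learned log-density at one point

scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained NSF", ms=2, alpha=0.5)
```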

docs/src/PlanarFlow.md

Lines changed: 135 additions & 0 deletions
# Planar Flow on a 2D Banana Distribution

This example demonstrates learning a synthetic 2D banana distribution with a planar normalizing flow [^RM2015] by maximizing the Evidence Lower Bound (ELBO).

The two required ingredients are:

- A log-density function `logp` for the target distribution.
- A parametrised invertible transformation (the planar flow) applied to a simple base distribution.

## Target Distribution

The banana target used here is defined in `example/targets/banana.jl` (see the source file for details):

```julia
using Random, Distributions
Random.seed!(123)

target = Banana(2, 1.0, 10.0) # (dimension, nonlinearity, scale)
logp = Base.Fix1(logpdf, target)
```

You can visualise its contours and samples (the figure is shipped as `banana.png`).

![Banana](banana.png)

## Planar Flow

A planar flow of length N applies a sequence of planar layers to a base distribution q₀:

```math
T_{n,\theta_n}(x) = x + u_n \tanh(w_n^T x + b_n), \qquad n = 1,\ldots,N.
```

The parameters θₙ = (uₙ, wₙ, bₙ) of each layer are learned during training. `Bijectors.jl` provides this transformation as `PlanarLayer`.
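
To make the connection with the formula concrete, here is a minimal sketch (purely illustrative, with randomly initialised parameters) of a single `PlanarLayer` applied to one point:

```julia
using Bijectors

layer = PlanarLayer(2)                                    # one planar layer T_θ acting on R²
x = randn(2)
y = Bijectors.transform(layer, x)                         # computes x + u * tanh(wᵀx + b)
y, logjac = Bijectors.with_logabsdet_jacobian(layer, x)   # transform together with log|det Jacobian|
```

Stacking and composing several such layers gives the full flow:
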
```julia
using Bijectors
using Functors # for @leaf

function create_planar_flow(n_layers::Int, q₀)
    d = length(q₀)
    Ls = [PlanarLayer(d) for _ in 1:n_layers]
    ts = reduce(∘, Ls) # alternatively: FunctionChains.fchain(Ls)
    return transformed(q₀, ts)
end

@leaf MvNormal # prevent updating the base distribution's parameters
q₀ = MvNormal(zeros(2), ones(2))
flow = create_planar_flow(10, q₀)
flow_untrained = deepcopy(flow) # keep a copy for comparison
```

If you build *many* layers (e.g., more than ~30), you may reduce compilation time by using `FunctionChains.jl`:

```julia
# uncomment the following lines to use FunctionChains
# using FunctionChains
# ts = fchain([PlanarLayer(d) for _ in 1:n_layers])
```

See [this comment](https://github.com/TuringLang/NormalizingFlows.jl/blob/8f4371d48228adf368d851e221af076ff929f1cf/src/NormalizingFlows.jl#L52)
for a discussion of when compilation time becomes a concern.

## Training the Flow

We maximize the ELBO (here using the minibatch estimator `elbo_batch`) with the generic `train_flow` interface.

```julia
using NormalizingFlows
using ADTypes, Optimisers
using Mooncake

sample_per_iter = 32
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config()) # try AutoZygote() / AutoForwardDiff() / etc.
# optional: callback function to track the batch size per iteration and the AD backend used
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# optional: stopping criterion that halts training once the gradient norm falls below 1e-3
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3

flow_trained, stats, _ = train_flow(
    elbo_batch,
    flow,
    logp,
    sample_per_iter;
    max_iters = 20_000,
    optimiser = Optimisers.Adam(1e-2),
    ADbackend = adtype,
    callback = cb,
    hasconverged = checkconv,
    show_progress = false,
)

losses = map(x -> x.loss, stats)
```

Plot the losses (negative ELBO):

```julia
using Plots
plot(losses; xlabel = "iteration", ylabel = "negative ELBO", label = "", lw = 2)
```

![elbo](elbo.png)

## Evaluating the Trained Flow

The trained flow is a `Bijectors.TransformedDistribution`, so we can call `rand` to draw i.i.d. samples and `logpdf` to evaluate its log-density.
See the [documentation of `Bijectors.jl`](https://turinglang.org/Bijectors.jl/dev/distributions/) for details.

```julia
n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)
samples_untrained = rand(flow_untrained, n_samples)
samples_true = rand(target, n_samples)
```
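
Density evaluation works the same way; for instance, a one-line sketch evaluating the trained flow at one of the target samples:

```julia
logpdf(flow_trained, samples_true[:, 1])  # log-density of the trained flow at a single point
```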

A simple visual comparison:

```julia
using Plots
scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_untrained[1, :], samples_untrained[2, :]; label="Untrained", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained", ms=2, alpha=0.5)
plot!(title = "Planar Flow: Before vs After Training", xlabel = "x₁", ylabel = "x₂", legend = :topleft)
```

![compare](comparison.png)

## Notes

- Use `elbo` instead of `elbo_batch` for a single-sample estimator.
- Switch AD backends by changing `adtype` (see `ADTypes.jl`); both switches are sketched below.
- Marking the base distribution with `@leaf` prevents its parameters from being updated during training.
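
A minimal sketch of both switches (single-sample `elbo` and the Zygote backend), reusing the setup above; the hyperparameters are illustrative:

```julia
using Zygote, ADTypes

adtype = ADTypes.AutoZygote()    # alternative AD backend
flow_trained2, stats2, _ = train_flow(
    elbo,                        # single-sample ELBO estimator
    flow,
    logp,
    sample_per_iter;
    max_iters = 20_000,
    optimiser = Optimisers.Adam(1e-2),
    ADbackend = adtype,
)
```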

## Reference

[^RM2015]: Rezende, D. & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML.

docs/src/RealNVP.md

Lines changed: 54 additions & 0 deletions
# Demo of RealNVP on 2D Banana Distribution

```julia
using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows

T = Float32  # element type for the example (assumed here; Float64 also works)
rng = Random.default_rng()  # RNG passed to `train_flow`

# the 2D banana target is defined in `example/targets/banana.jl`
target = Banana(2, one(T), 100one(T))
logp = Base.Fix1(logpdf, target)

######################################
# set up the RealNVP
######################################
@leaf MvNormal  # keep the base distribution's parameters fixed during training
q0 = MvNormal(zeros(T, 2), I)

d = 2
hdims = [16, 16]
nlayers = 3

# use NormalizingFlows.realnvp to create a RealNVP flow
flow = realnvp(q0, hdims, nlayers; paramtype=T)
flow_untrained = deepcopy(flow)

######################################
# start training
######################################
sample_per_iter = 16

adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# convergence check: stop once the gradient norm drops below 1e-3
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000

flow_trained, stats, _ = train_flow(
    rng,
    elbo, # using elbo_batch instead of elbo achieves a 4-5x speedup
    flow,
    logp,
    sample_per_iter;
    max_iters=10, # change to a larger number of iterations (e.g., 50_000) for better results
    optimiser=Optimisers.Adam(5e-4),
    ADbackend=adtype,
    show_progress=true,
    callback=cb,
    hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)
```
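
As with the other examples, the trained flow supports `rand` and `logpdf`; a brief sketch for a visual check against the target (assumes `Plots` is available):

```julia
using Plots

n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)
samples_true = rand(target, n_samples)

scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained RealNVP", ms=2, alpha=0.5)
```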

docs/src/api.md

Lines changed: 35 additions & 47 deletions
@@ -3,53 +3,6 @@
 ```@index
 ```
 
-## Main Function
-
-```@docs
-NormalizingFlows.train_flow
-```
-
-The flow object can be constructed by the `transformed` function in `Bijectors.jl`.
-For example, for Gaussian VI, we can construct the flow as follows:
-
-```julia
-using Distributions, Bijectors
-T = Float32
-@leaf MvNormal # to prevent params in q₀ from being optimized
-q₀ = MvNormal(zeros(T, 2), ones(T, 2))
-flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))
-```
-
-To train the Gaussian VI targeting distribution `p` via ELBO maximization, run:
-
-```julia
-using NormalizingFlows, Optimisers
-
-sample_per_iter = 10
-flow_trained, stats, _ = train_flow(
-    elbo,
-    flow,
-    logp,
-    sample_per_iter;
-    max_iters = 2_000,
-    optimiser = Optimisers.ADAM(0.01 * one(T)),
-)
-```
-
-## Coupling-based flows (default constructors)
-
-These helpers construct commonly used coupling-based flows with sensible defaults.
-
-```@docs
-NormalizingFlows.realnvp
-NormalizingFlows.nsf
-NormalizingFlows.RealNVP_layer
-NormalizingFlows.NSF_layer
-NormalizingFlows.AffineCoupling
-NormalizingFlows.NeuralSplineCoupling
-NormalizingFlows.create_flow
-```
-
 ## Variational Objectives
 
 We provide ELBO (reverse KL) and expected log-likelihood (forward KL). You can also

@@ -100,3 +53,38 @@ NormalizingFlows.loglikelihood
 ```@docs
 NormalizingFlows.optimize
 ```
+
+
+## Available Flows
+
+`NormalizingFlows.jl` provides two commonly used normalizing flows: `RealNVP` and
+`Neural Spline Flow (NSF)`.
+
+### RealNVP (Affine Coupling Flow)
+
+These helpers construct commonly used coupling-based flows with sensible defaults.
+
+```@docs
+NormalizingFlows.realnvp
+NormalizingFlows.RealNVP_layer
+NormalizingFlows.AffineCoupling
+```
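
For orientation, a minimal construction sketch using these defaults, mirroring `docs/src/RealNVP.md` above (the hidden-layer sizes and layer count are illustrative):

```julia
using Distributions, LinearAlgebra, Functors, NormalizingFlows

T = Float32
@leaf MvNormal                                # keep the base distribution's parameters fixed
q0 = MvNormal(zeros(T, 2), I)                 # 2D standard normal base distribution
flow = realnvp(q0, [16, 16], 3; paramtype=T)  # hidden dims per coupling network, 3 coupling layers
```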
+
+### Neural Spline Flow (NSF)
+
+```@docs
+NormalizingFlows.nsf
+NormalizingFlows.NSF_layer
+NormalizingFlows.NeuralSplineCoupling
+```
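
Likewise, a minimal sketch of the NSF constructor with its default settings, as in `docs/src/NSF.md` above:

```julia
using Distributions, LinearAlgebra, Functors, NormalizingFlows

T = Float32
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), I)
flow = nsf(q0; paramtype=T)   # Neural Spline Flow with default settings
```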
+
+## Utility Functions
+
+```@docs
+NormalizingFlows.create_flow
+```
+
+```@docs
+NormalizingFlows.fnn
+```
+

docs/src/customized_layer.md

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,11 @@ for more details.
 
 
 In this tutorial, we demonstrate how to define a customized normalizing flow
-layer -- an `Affine Coupling Layer` (Dinh *et al.*, 2016) -- using `Bijectors.jl` and `Flux.jl`.
+layer -- an `Affine Coupling Layer`, the building block of the RealNVP flow [^LJS2017] --
+using `Bijectors.jl` and `Flux.jl`.
+It is worth mentioning that the [`realnvp`](@ref) implemented in `NormalizingFlows.jl`
+differs slightly from the version in this tutorial, with some optimizations for training
+stability and performance.
 
 ## Affine Coupling Flow
 
@@ -176,5 +180,4 @@ logpdf(flow, x[:,1])
 
 
 ## Reference
-Dinh, L., Sohl-Dickstein, J. and Bengio, S., 2016. *Density estimation using real nvp.*
-arXiv:1605.08803.
+[^LJS2017]: Dinh, L., Sohl-Dickstein, J. and Bengio, S. (2017). Density estimation using Real NVP. In *ICLR*.
