Commit 3504009

make realnvp and nsf layers as part of the pkg (#53)
* make realnvp and nsf layers as part of the pkg
* import Functors
* add fully connected NN constructor
* update realnvp default constructor
* minor typo fix
* minor update of realnvp constructor and add some doc
* fixing bug in Distribution types
* exclude nsf for now
* minor ed in nsf
* fix typo in realnvp
* add realnvp test
* export nsf layer
* update demos, debugging mooncake with elbo
* add AD tests for realnvp elbo
* wip debug mooncake on coupling layers
* found that bug revealed by mooncake 0.4.124
* add compat mooncake v0.4.142, fixed the autograd error on nested struct
* add mooncake compat >= v0.4.142
* add nsf interface
* fix a typo in elbo_batch signature
* rm redundant comments
* make target adapt to the chosen floating-point type automatically
* rm redundant flux dependencies
* add new nsf implementation and demo; much faster than the original nsf
* rm redundant flux from realnvp demo
* dump the previous nsf implementation
* add test for nsf
* add AD test for nsf
* fix typo in nsf test
* fix nsf test error regarding rand()
* relax rtol for nsf invertibility error in Float32
* update doc
* wip doc build error
* updating docs
* update gha perms
* minor ed on docs
* update docs
* update readme
* incorporate comments from red-portal and sunxd3
* fix test error
* add planar and radial flow; updating docs
* fixed error in doc for create_flow
* rm redundant comments
* change Any[] to Flux.Dense[]
* minor comment update
1 parent e0eb8ab commit 3504009

29 files changed: +1421 −504 lines
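
The headline change is that the RealNVP and neural spline flow (NSF) constructors now ship with the package itself. A minimal sketch of the new entry points, with arguments taken from the demos added below (`docs/src/RealNVP.md` and `docs/src/NSF.md`); `Float32` is an assumed choice here:

```julia
using Distributions, LinearAlgebra
using NormalizingFlows

T = Float32
q0 = MvNormal(zeros(T, 2), I)  # base distribution

# constructors made part of the package by this commit
flow_realnvp = realnvp(q0, [16, 16], 3; paramtype=T)  # hidden dims, num. of coupling layers
flow_nsf     = nsf(q0; paramtype=T)                   # neural spline flow, default settings
```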

.github/workflows/Docs.yml

Lines changed: 8 additions & 7 deletions
```diff
@@ -15,7 +15,7 @@ concurrency:
 
 permissions:
   contents: write
-  pull-requests: read
+  pull-requests: write
 
 jobs:
   docs:
@@ -25,9 +25,10 @@ jobs:
       - name: Build and deploy Documenter.jl docs
         uses: TuringLang/actions/DocsDocumenter@main
 
-      - run: |
-          julia --project=docs -e '
-            using Documenter: DocMeta, doctest
-            using NormalizingFlows
-            DocMeta.setdocmeta!(NormalizingFlows, :DocTestSetup, :(using NormalizingFlows); recursive=true)
-            doctest(NormalizingFlows)'
+      - name: Run doctests
+        shell: julia --project=docs --color=yes {0}
+        run: |
+          using Documenter: DocMeta, doctest
+          using NormalizingFlows
+          DocMeta.setdocmeta!(NormalizingFlows, :DocTestSetup, :(using NormalizingFlows); recursive=true)
+          doctest(NormalizingFlows)
```

Project.toml

Lines changed: 8 additions & 2 deletions
```diff
@@ -1,14 +1,17 @@
 name = "NormalizingFlows"
 uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
-version = "0.2.1"
+version = "0.2.2"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -27,7 +30,10 @@ CUDA = "5"
 DifferentiationInterface = "0.6, 0.7"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
+Flux = "0.16"
+Functors = "0.5.2"
+MonotonicSplines = "0.3.3"
 Optimisers = "0.2.16, 0.3, 0.4"
 ProgressMeter = "1.0.0"
 StatsBase = "0.33, 0.34"
-julia = "1.10"
+julia = "1.10.8"
```

README.md

Lines changed: 4 additions & 1 deletion
````diff
@@ -4,7 +4,7 @@
 [![Build Status](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
 
 
-**Last updated: 2025-Mar-04**
+**Last updated: 2025-Aug-08**
 
 A normalizing flow library for Julia.
 
@@ -17,6 +17,8 @@ without being tied to specific probabilistic programming frameworks or applicati
 
 See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for more.
 
+We also provide several demos and examples in [example](https://github.com/TuringLang/NormalizingFlows.jl/tree/main/example).
+
 ## Installation
 To install the package, run the following command in the Julia REPL:
 ```julia
@@ -90,3 +92,4 @@ where one wants to learn the underlying distribution of some data.
 - [Flux.jl](https://fluxml.ai/Flux.jl/stable/)
 - [Optimisers.jl](https://github.com/FluxML/Optimisers.jl)
 - [AdvancedVI.jl](https://github.com/TuringLang/AdvancedVI.jl)
+- [MonotonicSplines.jl](https://github.com/bat/MonotonicSplines.jl)
````

docs/Project.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,5 +5,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
 NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
```

docs/make.jl

Lines changed: 11 additions & 3 deletions
```diff
@@ -1,19 +1,27 @@
 using NormalizingFlows
 using Documenter
 
+using Random
+using Distributions
+
 DocMeta.setdocmeta!(
     NormalizingFlows, :DocTestSetup, :(using NormalizingFlows); recursive=true
 )
 
 makedocs(;
     modules=[NormalizingFlows],
-    repo="https://github.com/TuringLang/NormalizingFlows.jl/blob/{commit}{path}#{line}",
     sitename="NormalizingFlows.jl",
-    format=Documenter.HTML(),
+    repo="https://github.com/TuringLang/NormalizingFlows.jl/blob/{commit}{path}#{line}",
+    format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"),
     pages=[
         "Home" => "index.md",
+        "General usage" => "usage.md",
         "API" => "api.md",
-        "Example" => "example.md",
+        "Example" => [
+            "Planar Flow" => "PlanarFlow.md",
+            "RealNVP" => "RealNVP.md",
+            "Neural Spline Flow" => "NSF.md",
+        ],
         "Customize your own flow layer" => "customized_layer.md",
     ],
     checkdocs=:exports,
```

docs/src/NSF.md

Lines changed: 47 additions & 0 deletions
# Demo of NSF on 2D Banana Distribution

```julia
using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Zygote
using NormalizingFlows

T = Float32  # floating-point type used throughout (assumed; the original demo leaves T to the reader)

# `Banana` is defined in `example/targets/banana.jl`
target = Banana(2, one(T), 100one(T))
logp = Base.Fix1(logpdf, target)

######################################
# learn the target using Neural Spline Flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), I)

flow = nsf(q0; paramtype=T)
flow_untrained = deepcopy(flow)

######################################
# start training
######################################
sample_per_iter = 64

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# nsf only supports AutoZygote
adtype = ADTypes.AutoZygote()
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
    elbo_batch,
    flow,
    logp,
    sample_per_iter;
    max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results
    optimiser=Optimisers.Adam(1e-4),
    ADbackend=adtype,
    show_progress=true,
    callback=cb,
    hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)
```
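
The demo stops at the loss trace; as in the planar-flow demo below, the trained flow is a `Bijectors.TransformedDistribution`, so sampling and density evaluation follow the standard interface. A minimal sketch, continuing from the variables above:

```julia
# draw iid samples from the trained flow and evaluate its log-density
samples = rand(flow_trained, 1_000)       # d × n matrix of draws
lp = logpdf(flow_trained, samples[:, 1])  # log-density at a single draw

# `re(θ)` rebuilds the flow from the flattened parameter vector returned by destructure
flow_rebuilt = re(θ)
```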

docs/src/PlanarFlow.md

Lines changed: 135 additions & 0 deletions
# Planar Flow on a 2D Banana Distribution

This example demonstrates learning a synthetic 2D banana distribution with a planar normalizing flow [^RM2015] by maximizing the Evidence Lower BOund (ELBO).

The two required ingredients are:

- A log-density function `logp` for the target distribution.
- A parametrised invertible transformation (the planar flow) applied to a simple base distribution.

## Target Distribution

The banana target used here is defined in `example/targets/banana.jl` (see source for details):

```julia
using Random, Distributions
Random.seed!(123)

target = Banana(2, 1.0, 10.0) # (dimension, nonlinearity, scale)
logp = Base.Fix1(logpdf, target)
```

You can visualise its contour and samples (figure shipped as `banana.png`).

![Banana](banana.png)

## Planar Flow

A planar flow of length N applies a sequence of planar layers to a base distribution q₀:

```math
T_{n,\theta_n}(x) = x + u_n \tanh(w_n^T x + b_n), \qquad n = 1,\ldots,N.
```

The parameters θₙ = (uₙ, wₙ, bₙ) are learned. `Bijectors.jl` provides `PlanarLayer`.

```julia
using Bijectors
using Functors # for @leaf

function create_planar_flow(n_layers::Int, q₀)
    d = length(q₀)
    Ls = [PlanarLayer(d) for _ in 1:n_layers]
    ts = reduce(∘, Ls) # alternatively: FunctionChains.fchain(Ls)
    return transformed(q₀, ts)
end

@leaf MvNormal # prevent updating base distribution parameters
q₀ = MvNormal(zeros(2), ones(2))
flow = create_planar_flow(10, q₀)
flow_untrained = deepcopy(flow) # keep a copy for comparison
```

If you build *many* layers (e.g. more than ~30), you may reduce compilation time by using `FunctionChains.jl`:

```julia
# uncomment the following lines to use FunctionChains
# using FunctionChains
# ts = fchain([PlanarLayer(d) for _ in 1:n_layers])
```

See [this comment](https://github.com/TuringLang/NormalizingFlows.jl/blob/8f4371d48228adf368d851e221af076ff929f1cf/src/NormalizingFlows.jl#L52)
for why compilation time can become a concern.

## Training the Flow

We maximize the ELBO (here using the minibatch estimator `elbo_batch`) with the generic `train_flow` interface.

```julia
using NormalizingFlows
using ADTypes, Optimisers
using Mooncake

sample_per_iter = 32
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config()) # try AutoZygote() / AutoForwardDiff() / etc.
# optional: callback function to track the batch size per iteration and the AD backend used
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# optional: stopping criterion that halts training once the gradient norm is less than 1e-3
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3

flow_trained, stats, _ = train_flow(
    elbo_batch,
    flow,
    logp,
    sample_per_iter;
    max_iters = 20_000,
    optimiser = Optimisers.Adam(1e-2),
    ADbackend = adtype,
    callback = cb,
    hasconverged = checkconv,
    show_progress = false,
)

losses = map(x -> x.loss, stats)
```

Plot the losses (negative ELBO):

```julia
using Plots
plot(losses; xlabel = "iteration", ylabel = "negative ELBO", label = "", lw = 2)
```

![elbo](elbo.png)

## Evaluating the Trained Flow

The trained flow is a `Bijectors.TransformedDistribution`, so we can call `rand` to draw iid samples and `logpdf` to evaluate the flow's log-density.
See the [`Bijectors.jl` documentation](https://turinglang.org/Bijectors.jl/dev/distributions/) for details.

```julia
n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)
samples_untrained = rand(flow_untrained, n_samples)
samples_true = rand(target, n_samples)
```
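
The prose above mentions `logpdf`, but the snippet only shows sampling; a short sketch of density evaluation, continuing with the variables above:

```julia
# evaluate the flow's log-density at the drawn samples, column by column
lp_flow = [logpdf(flow_trained, x) for x in eachcol(samples_trained)]
# compare with the target's log-density at the same points
lp_target = [logp(x) for x in eachcol(samples_trained)]
```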

Simple visual comparison:

```julia
using Plots
scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_untrained[1, :], samples_untrained[2, :]; label="Untrained", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained", ms=2, alpha=0.5)
plot!(title = "Planar Flow: Before vs After Training", xlabel = "x₁", ylabel = "x₂", legend = :topleft)
```

![compare](comparison.png)

## Notes

- Use `elbo` instead of `elbo_batch` for a single-sample estimator.
- Switch AD backends by changing `adtype` (see `ADTypes.jl`).
- Marking the base distribution with `@leaf` prevents its parameters from being updated during training.

## Reference

[^RM2015]: Rezende, D. & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML.

docs/src/RealNVP.md

Lines changed: 54 additions & 0 deletions
# Demo of RealNVP on 2D Banana Distribution

```julia
using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows

T = Float32  # floating-point type used throughout (assumed; the original demo leaves T to the reader)
rng = Random.default_rng()  # assumed; `rng` is passed to `train_flow` below but not defined in the original demo

# `Banana` is defined in `example/targets/banana.jl`
target = Banana(2, one(T), 100one(T))
logp = Base.Fix1(logpdf, target)

######################################
# set up the RealNVP
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), I)

d = 2
hdims = [16, 16]
nlayers = 3

# use NormalizingFlows.realnvp to create a RealNVP flow
flow = realnvp(q0, hdims, nlayers; paramtype=T)
flow_untrained = deepcopy(flow)

######################################
# start training
######################################
sample_per_iter = 16

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())

checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
    rng,
    elbo, # using elbo_batch instead of elbo achieves 4-5 times speedup
    flow,
    logp,
    sample_per_iter;
    max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results
    optimiser=Optimisers.Adam(5e-4),
    ADbackend=adtype,
    show_progress=true,
    callback=cb,
    hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)
```
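
The inline comment above notes that `elbo_batch` achieves a 4-5x speedup over `elbo`; the swap is a one-argument change. A sketch, assuming `elbo_batch` accepts the same positional signature here as in the NSF demo:

```julia
# identical training call, but with the batched ELBO estimator
flow_trained2, stats2, _ = train_flow(
    rng,
    elbo_batch,
    flow,
    logp,
    sample_per_iter;
    max_iters=10,
    optimiser=Optimisers.Adam(5e-4),
    ADbackend=adtype,
)
```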
