-
Notifications
You must be signed in to change notification settings - Fork 5
make realnvp and nsf layers as part of the pkg #53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 41 commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
8db828f
make realnvp and nsf layers as part of the pkg
zuhengxu 574e257
import Functors
zuhengxu cd63565
add fully connect nn constructor
zuhengxu e03213c
update realnvp default constructor
zuhengxu 34a964e
minor typo fix
zuhengxu 1c4118b
minor update of realnnvp constructor and add some doc
zuhengxu d6b86cb
fixing bug in Distribution types
zuhengxu aa2adeb
exclude nsf for now
zuhengxu 00ca29c
minor ed in nsf
zuhengxu 731e657
fix typo in realnvp
zuhengxu e4fa67b
add realnvp test
zuhengxu 1dfebd1
export nsf layer
zuhengxu 55fb607
update demos, debugging mooncake with elbo
zuhengxu a2f6fbe
add AD tests for realnvp elbo
zuhengxu e39b8a8
wip debug mooncake on coupling layers
zuhengxu 84ce45f
found that bug revealed by mooncake 0.4.124
zuhengxu 9ba0a3b
add compat mooncake v0.4.142, fixed the autograd error on nested struct
zuhengxu 81000b9
add mooncake compat >= v0.4.142
zuhengxu 4caae49
add nsf interface
zuhengxu cf2e674
fix a typo in elbo_batch signiture
zuhengxu 7f9c382
rm redundant comments
zuhengxu 9f0cbad
making target adapting to the chosen Floating type automatically
zuhengxu 99a0fed
rm redundant flux dependencies
zuhengxu c4128fa
add new nsf implementation and demo; much faster than the original nsf
zuhengxu 1f30b33
rm redundant flux from realnvp demo
zuhengxu 2903b83
dump the previous nsf implementation
zuhengxu 48bc3d3
add test for nsf
zuhengxu 9494de1
add ad test for nsf
zuhengxu 8f61fc9
fix typo in nsf test
zuhengxu 977caaf
fix nsf test error regarding rand()
zuhengxu 0b9e656
relax rtol for nsf invertibility error in FLoat32
zuhengxu 48829ad
update doc
zuhengxu 4e6bfbe
wip doc build erro
zuhengxu eb19664
updating docs
zuhengxu ae53100
update gha perms
zuhengxu b8b229b
minor ed on docs
zuhengxu 2d5b9c5
update docs
zuhengxu 8a04147
update readme
zuhengxu 2431e57
incorpoerate comments from red-portal and sunxd3
zuhengxu 4dec51a
fix test error
zuhengxu fb6b80b
add planar and radial flow; updating docs
zuhengxu 862c6dc
fixed error in doc for creat_flow
zuhengxu dc73605
rm redundant comments
zuhengxu 5753150
change Any[] to Flux.Dense[]
zuhengxu d9525e7
minor comment update
zuhengxu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Demo of NSF on 2D Banana Distribution | ||
|
||
```julia | ||
using Random, Distributions, LinearAlgebra | ||
using Functors | ||
using Optimisers, ADTypes | ||
using Zygote | ||
using NormalizingFlows | ||
|
||
|
||
target = Banana(2, one(T), 100one(T)) | ||
logp = Base.Fix1(logpdf, target) | ||
|
||
###################################### | ||
# learn the target using Neural Spline Flow | ||
###################################### | ||
@leaf MvNormal | ||
q0 = MvNormal(zeros(T, 2), I) | ||
|
||
|
||
flow = nsf(q0; paramtype=T) | ||
flow_untrained = deepcopy(flow) | ||
###################################### | ||
# start training | ||
###################################### | ||
sample_per_iter = 64 | ||
|
||
# callback function to log training progress | ||
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) | ||
# nsf only supports AutoZygote | ||
adtype = ADTypes.AutoZygote() | ||
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 | ||
flow_trained, stats, _ = train_flow( | ||
elbo_batch, | ||
flow, | ||
logp, | ||
sample_per_iter; | ||
max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results | ||
optimiser=Optimisers.Adam(1e-4), | ||
ADbackend=adtype, | ||
show_progress=true, | ||
callback=cb, | ||
hasconverged=checkconv, | ||
) | ||
θ, re = Optimisers.destructure(flow_trained) | ||
losses = map(x -> x.loss, stats) | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Planar Flow on a 2D Banana Distribution | ||
|
||
This example demonstrates learning a synthetic 2D banana distribution with a planar normalizing flow [^RM2015] by maximizing the Evidence Lower BOund (ELBO). | ||
|
||
The two required ingredients are: | ||
|
||
- A log-density function `logp` for the target distribution. | ||
- A parametrised invertible transformation (the planar flow) applied to a simple base distribution. | ||
|
||
## Target Distribution | ||
|
||
The banana target used here is defined in `example/targets/banana.jl` (see source for details): | ||
|
||
```julia | ||
using Random, Distributions | ||
Random.seed!(123) | ||
|
||
target = Banana(2, 1.0, 10.0) # (dimension, nonlinearity, scale) | ||
logp = Base.Fix1(logpdf, target) | ||
``` | ||
|
||
You can visualise its contour and samples (figure shipped as `banana.png`). | ||
|
||
 | ||
|
||
## Planar Flow | ||
|
||
A planar flow of length N applies a sequence of planar layers to a base distribution q₀: | ||
|
||
```math | ||
T_{n,\theta_n}(x) = x + u_n \tanh(w_n^T x + b_n), \qquad n = 1,\ldots,N. | ||
``` | ||
|
||
Parameters θₙ = (uₙ, wₙ, bₙ) are learned. `Bijectors.jl` provides `PlanarLayer`. | ||
|
||
```julia | ||
using Bijectors | ||
using Functors # for @leaf | ||
|
||
function create_planar_flow(n_layers::Int, q₀) | ||
d = length(q₀) | ||
Ls = [PlanarLayer(d) for _ in 1:n_layers] | ||
ts = reduce(∘, Ls) # alternatively: FunctionChains.fchain(Ls) | ||
return transformed(q₀, ts) | ||
end | ||
|
||
@leaf MvNormal # prevent updating base distribution parameters | ||
q₀ = MvNormal(zeros(2), ones(2)) | ||
flow = create_planar_flow(10, q₀) | ||
flow_untrained = deepcopy(flow) # keep copy for comparison | ||
``` | ||
|
||
If you build *many* layers (e.g. > ~30) you may reduce compilation time by using `FunctionChains.jl`: | ||
|
||
```julia | ||
# uncomment the following lines to use FunctionChains | ||
# using FunctionChains | ||
# ts = fchain([PlanarLayer(d) for _ in 1:n_layers]) | ||
``` | ||
See [this comment](https://github.com/TuringLang/NormalizingFlows.jl/blob/8f4371d48228adf368d851e221af076ff929f1cf/src/NormalizingFlows.jl#L52) | ||
for how the compilation time might be a concern. | ||
|
||
## Training the Flow | ||
|
||
We maximize the ELBO (here using the minibatch estimator `elbo_batch`) with the generic `train_flow` interface. | ||
|
||
```julia | ||
using NormalizingFlows | ||
using ADTypes, Optimisers | ||
using Mooncake | ||
|
||
sample_per_iter = 32 | ||
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config()) # try AutoZygote() / AutoForwardDiff() / etc. | ||
# optional: callback function to track the batch size per iteration and the AD backend used | ||
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype) | ||
# optional: defined stopping criteria when the gradient norm is less than 1e-3 | ||
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3 | ||
|
||
flow_trained, stats, _ = train_flow( | ||
elbo_batch, | ||
flow, | ||
logp, | ||
sample_per_iter; | ||
max_iters = 20_000, | ||
optimiser = Optimisers.Adam(1e-2), | ||
ADbackend = adtype, | ||
callback = cb, | ||
hasconverged = checkconv, | ||
show_progress = false, | ||
) | ||
|
||
losses = map(x -> x.loss, stats) | ||
``` | ||
|
||
Plot the losses (negative ELBO): | ||
|
||
```julia | ||
using Plots | ||
plot(losses; xlabel = "iteration", ylabel = "negative ELBO", label = "", lw = 2) | ||
``` | ||
|
||
 | ||
|
||
## Evaluating the Trained Flow | ||
|
||
The trained flow is a `Bijectors.TransformedDistribution`, so we can call `rand` to draw iid samples and call `logpdf` to evaluate the log-density function of the flow. | ||
See [documentation of `Bijectors.jl`](https://turinglang.org/Bijectors.jl/dev/distributions/) for details. | ||
```julia | ||
n_samples = 1_000 | ||
samples_trained = rand(flow_trained, n_samples) | ||
samples_untrained = rand(flow_untrained, n_samples) | ||
samples_true = rand(target, n_samples) | ||
``` | ||
|
||
Simple visual comparison: | ||
|
||
```julia | ||
using Plots | ||
scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5) | ||
scatter!(samples_untrained[1, :], samples_untrained[2, :]; label="Untrained", ms=2, alpha=0.5) | ||
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained", ms=2, alpha=0.5) | ||
plot!(title = "Planar Flow: Before vs After Training", xlabel = "x₁", ylabel = "x₂", legend = :topleft) | ||
``` | ||
|
||
 | ||
|
||
## Notes | ||
|
||
- Use `elbo` instead of `elbo_batch` for a single-sample estimator. | ||
- Switch AD backends by changing `adtype` (see `ADTypes.jl`). | ||
- Marking the base distribution with `@leaf` prevents its parameters from being updated during training. | ||
|
||
## Reference | ||
|
||
[^RM2015]: Rezende, D. & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Demo of RealNVP on 2D Banana Distribution | ||
|
||
```julia | ||
using Random, Distributions, LinearAlgebra | ||
using Functors | ||
using Optimisers, ADTypes | ||
using Mooncake | ||
using NormalizingFlows | ||
|
||
|
||
target = Banana(2, one(T), 100one(T)) | ||
logp = Base.Fix1(logpdf, target) | ||
|
||
###################################### | ||
# set up the RealNVP | ||
###################################### | ||
@leaf MvNormal | ||
q0 = MvNormal(zeros(T, 2), I) | ||
|
||
d = 2 | ||
hdims = [16, 16] | ||
nlayers = 3 | ||
|
||
# use NormalizingFlows.realnvp to create a RealNVP flow | ||
flow = realnvp(q0, hdims, nlayers; paramtype=T) | ||
flow_untrained = deepcopy(flow) | ||
|
||
|
||
###################################### | ||
# start training | ||
###################################### | ||
sample_per_iter = 16 | ||
|
||
# callback function to log training progress | ||
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) | ||
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) | ||
|
||
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 | ||
flow_trained, stats, _ = train_flow( | ||
rng, | ||
elbo, # using elbo_batch instead of elbo achieves 4-5 times speedup | ||
flow, | ||
logp, | ||
sample_per_iter; | ||
max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results | ||
optimiser=Optimisers.Adam(5e-4), | ||
ADbackend=adtype, | ||
show_progress=true, | ||
callback=cb, | ||
hasconverged=checkconv, | ||
) | ||
θ, re = Optimisers.destructure(flow_trained) | ||
losses = map(x -> x.loss, stats) | ||
``` |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.