Merged
45 commits
8db828f
make realnvp and nsf layers part of the pkg
zuhengxu Jul 5, 2025
574e257
import Functors
zuhengxu Jul 5, 2025
cd63565
add fully connected nn constructor
zuhengxu Jul 13, 2025
e03213c
update realnvp default constructor
zuhengxu Jul 13, 2025
34a964e
minor typo fix
zuhengxu Jul 13, 2025
1c4118b
minor update of realnvp constructor and add some docs
zuhengxu Jul 14, 2025
d6b86cb
fixing bug in Distribution types
zuhengxu Jul 14, 2025
aa2adeb
exclude nsf for now
zuhengxu Jul 14, 2025
00ca29c
minor edit in nsf
zuhengxu Jul 14, 2025
731e657
fix typo in realnvp
zuhengxu Jul 14, 2025
e4fa67b
add realnvp test
zuhengxu Jul 14, 2025
1dfebd1
export nsf layer
zuhengxu Jul 14, 2025
55fb607
update demos, debugging mooncake with elbo
zuhengxu Jul 14, 2025
a2f6fbe
add AD tests for realnvp elbo
zuhengxu Jul 14, 2025
e39b8a8
wip debug mooncake on coupling layers
zuhengxu Jul 23, 2025
84ce45f
found the bug revealed by mooncake 0.4.124
zuhengxu Jul 25, 2025
9ba0a3b
add compat mooncake v0.4.142, fixing the autograd error on nested structs
zuhengxu Jul 31, 2025
81000b9
add mooncake compat >= v0.4.142
zuhengxu Aug 2, 2025
4caae49
add nsf interface
zuhengxu Aug 3, 2025
cf2e674
fix a typo in elbo_batch signature
zuhengxu Aug 3, 2025
7f9c382
rm redundant comments
zuhengxu Aug 3, 2025
9f0cbad
make target adapt to the chosen floating-point type automatically
zuhengxu Aug 3, 2025
99a0fed
rm redundant flux dependencies
zuhengxu Aug 3, 2025
c4128fa
add new nsf implementation and demo; much faster than the original nsf
zuhengxu Aug 3, 2025
1f30b33
rm redundant flux from realnvp demo
zuhengxu Aug 4, 2025
2903b83
dump the previous nsf implementation
zuhengxu Aug 4, 2025
48bc3d3
add test for nsf
zuhengxu Aug 4, 2025
9494de1
add ad test for nsf
zuhengxu Aug 4, 2025
8f61fc9
fix typo in nsf test
zuhengxu Aug 4, 2025
977caaf
fix nsf test error regarding rand()
zuhengxu Aug 4, 2025
0b9e656
relax rtol for nsf invertibility error in Float32
zuhengxu Aug 4, 2025
48829ad
update doc
zuhengxu Aug 8, 2025
4e6bfbe
wip doc build error
zuhengxu Aug 8, 2025
eb19664
updating docs
zuhengxu Aug 8, 2025
ae53100
update gha perms
zuhengxu Aug 8, 2025
b8b229b
minor edits on docs
zuhengxu Aug 8, 2025
2d5b9c5
update docs
zuhengxu Aug 9, 2025
8a04147
update readme
zuhengxu Aug 9, 2025
2431e57
incorporate comments from red-portal and sunxd3
zuhengxu Aug 20, 2025
4dec51a
fix test error
zuhengxu Aug 20, 2025
fb6b80b
add planar and radial flow; updating docs
zuhengxu Aug 20, 2025
862c6dc
fixed error in doc for create_flow
zuhengxu Aug 20, 2025
dc73605
rm redundant comments
zuhengxu Aug 20, 2025
5753150
change Any[] to Flux.Dense[]
zuhengxu Aug 20, 2025
d9525e7
minor comment update
zuhengxu Aug 20, 2025
15 changes: 8 additions & 7 deletions .github/workflows/Docs.yml
@@ -15,7 +15,7 @@ concurrency:
 
 permissions:
   contents: write
-  pull-requests: read
+  pull-requests: write
 
 jobs:
   docs:
@@ -25,9 +25,10 @@ jobs:
       - name: Build and deploy Documenter.jl docs
        uses: TuringLang/actions/DocsDocumenter@main
 
-      - run: |
-          julia --project=docs -e '
-            using Documenter: DocMeta, doctest
-            using NormalizingFlows
-            DocMeta.setdocmeta!(NormalizingFlows, :DocTestSetup, :(using NormalizingFlows); recursive=true)
-            doctest(NormalizingFlows)'
+      - name: Run doctests
+        shell: julia --project=docs --color=yes {0}
+        run: |
+          using Documenter: DocMeta, doctest
+          using NormalizingFlows
+          DocMeta.setdocmeta!(NormalizingFlows, :DocTestSetup, :(using NormalizingFlows); recursive=true)
+          doctest(NormalizingFlows)
10 changes: 8 additions & 2 deletions Project.toml
@@ -1,14 +1,17 @@
 name = "NormalizingFlows"
 uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
-version = "0.2.1"
+version = "0.2.2"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -27,7 +30,10 @@ CUDA = "5"
 DifferentiationInterface = "0.6, 0.7"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
+Flux = "0.16"
+Functors = "0.5.2"
+MonotonicSplines = "0.3.3"
 Optimisers = "0.2.16, 0.3, 0.4"
 ProgressMeter = "1.0.0"
 StatsBase = "0.33, 0.34"
-julia = "1.10"
+julia = "1.10.8"
5 changes: 4 additions & 1 deletion README.md
@@ -4,7 +4,7 @@
 [![Build Status](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
 
 
-**Last updated: 2025-Mar-04**
+**Last updated: 2025-Aug-08**
 
 A normalizing flow library for Julia.
 
@@ -17,6 +17,8 @@ without being tied to specific probabilistic programming frameworks or applications
 
 See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for more.
 
+We also provide several demos and examples in [example](https://github.com/TuringLang/NormalizingFlows.jl/tree/main/example).
+
 ## Installation
 To install the package, run the following command in the Julia REPL:
 ```julia
@@ -90,3 +92,4 @@ where one wants to learn the underlying distribution of some data.
 - [Flux.jl](https://fluxml.ai/Flux.jl/stable/)
 - [Optimisers.jl](https://github.com/FluxML/Optimisers.jl)
 - [AdvancedVI.jl](https://github.com/TuringLang/AdvancedVI.jl)
+- [MonotonicSplines.jl](https://github.com/bat/MonotonicSplines.jl)
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -5,5 +5,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
 NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14 changes: 11 additions & 3 deletions docs/make.jl
@@ -1,19 +1,27 @@
 using NormalizingFlows
 using Documenter
 
+using Random
+using Distributions
+
 DocMeta.setdocmeta!(
     NormalizingFlows, :DocTestSetup, :(using NormalizingFlows); recursive=true
 )
 
 makedocs(;
     modules=[NormalizingFlows],
-    repo="https://github.com/TuringLang/NormalizingFlows.jl/blob/{commit}{path}#{line}",
     sitename="NormalizingFlows.jl",
-    format=Documenter.HTML(),
+    repo="https://github.com/TuringLang/NormalizingFlows.jl/blob/{commit}{path}#{line}",
+    format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"),
     pages=[
         "Home" => "index.md",
         "General usage" => "usage.md",
         "API" => "api.md",
-        "Example" => "example.md",
+        "Example" => [
+            "Planar Flow" => "PlanarFlow.md",
+            "RealNVP" => "RealNVP.md",
+            "Neural Spline Flow" => "NSF.md",
+        ],
         "Customize your own flow layer" => "customized_layer.md",
     ],
     checkdocs=:exports,
47 changes: 47 additions & 0 deletions docs/src/NSF.md
@@ -0,0 +1,47 @@
# Demo of NSF on 2D Banana Distribution

This demo fits a Neural Spline Flow (NSF) to a synthetic 2D banana distribution by maximizing the ELBO.

```julia
using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Zygote
using NormalizingFlows

T = Float32 # floating-point type for the flow parameters
target = Banana(2, one(T), 100one(T)) # Banana is defined in example/targets/banana.jl
logp = Base.Fix1(logpdf, target)

######################################
# learn the target using Neural Spline Flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), I)


flow = nsf(q0; paramtype=T)
flow_untrained = deepcopy(flow)
######################################
# start training
######################################
sample_per_iter = 64

# nsf only supports AutoZygote as the AD backend
adtype = ADTypes.AutoZygote()
# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# stopping criterion: converged once the gradient norm falls below 1/1000
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
elbo_batch,
flow,
logp,
sample_per_iter;
max_iters=10, # change to a larger number of iterations (e.g., 50_000) for better results
optimiser=Optimisers.Adam(1e-4),
ADbackend=adtype,
show_progress=true,
callback=cb,
hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)
```
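
As a quick check, here is a minimal sketch (assuming the code above has run and `Plots.jl` is available) that compares samples from the trained and untrained flows against the target, mirroring the evaluation in the planar-flow demo:

```julia
using Plots

n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)      # samples from the trained flow
samples_untrained = rand(flow_untrained, n_samples)  # samples from the untrained flow
samples_true = rand(target, n_samples)               # ground-truth samples

scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_untrained[1, :], samples_untrained[2, :]; label="Untrained", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained", ms=2, alpha=0.5)
```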
135 changes: 135 additions & 0 deletions docs/src/PlanarFlow.md
@@ -0,0 +1,135 @@
# Planar Flow on a 2D Banana Distribution

This example demonstrates learning a synthetic 2D banana distribution with a planar normalizing flow [^RM2015] by maximizing the Evidence Lower BOund (ELBO).

The two required ingredients are:

- A log-density function `logp` for the target distribution.
- A parametrised invertible transformation (the planar flow) applied to a simple base distribution.
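
Training maximizes the standard ELBO objective, stated here for reference, where q_θ is the flow distribution and p the target density:

```math
\mathrm{ELBO}(\theta) = \mathbb{E}_{x \sim q_\theta}\left[\log p(x) - \log q_\theta(x)\right].
```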

## Target Distribution

The banana target used here is defined in `example/targets/banana.jl` (see source for details):

```julia
using Random, Distributions
Random.seed!(123)

target = Banana(2, 1.0, 10.0) # (dimension, nonlinearity, scale)
logp = Base.Fix1(logpdf, target)
```

You can visualise its contour and samples (figure shipped as `banana.png`).

![Banana](banana.png)

## Planar Flow

A planar flow of length N applies a sequence of planar layers to a base distribution q₀:

```math
T_{n,\theta_n}(x) = x + u_n \tanh(w_n^T x + b_n), \qquad n = 1,\ldots,N.
```

The parameters θₙ = (uₙ, wₙ, bₙ) of each layer are learned during training; `Bijectors.jl` provides this transformation as `PlanarLayer`.

```julia
using Bijectors
using Functors # for @leaf
using LinearAlgebra # for the identity matrix I

function create_planar_flow(n_layers::Int, q₀)
d = length(q₀)
Ls = [PlanarLayer(d) for _ in 1:n_layers]
ts = reduce(∘, Ls) # alternatively: FunctionChains.fchain(Ls)
return transformed(q₀, ts)
end

@leaf MvNormal # prevent updating base distribution parameters
q₀ = MvNormal(zeros(2), I) # standard 2D Gaussian base distribution
flow = create_planar_flow(10, q₀)
flow_untrained = deepcopy(flow) # keep copy for comparison
```

If you compose *many* layers (e.g., more than ~30), you can reduce compilation time by using `FunctionChains.jl`:

```julia
# uncomment the following lines to use FunctionChains
# using FunctionChains
# ts = fchain([PlanarLayer(d) for _ in 1:n_layers])
```
See [this comment](https://github.com/TuringLang/NormalizingFlows.jl/blob/8f4371d48228adf368d851e221af076ff929f1cf/src/NormalizingFlows.jl#L52)
for why compilation time can become a concern.

## Training the Flow

We maximize the ELBO (here using the minibatch estimator `elbo_batch`) with the generic `train_flow` interface.

```julia
using NormalizingFlows
using ADTypes, Optimisers
using Mooncake

sample_per_iter = 32
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config()) # try AutoZygote() / AutoForwardDiff() / etc.
# optional: callback function to track the batch size per iteration and the AD backend used
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# optional: stopping criterion that halts training once the gradient norm is less than 1e-3
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3

flow_trained, stats, _ = train_flow(
elbo_batch,
flow,
logp,
sample_per_iter;
max_iters = 20_000,
optimiser = Optimisers.Adam(1e-2),
ADbackend = adtype,
callback = cb,
hasconverged = checkconv,
show_progress = false,
)

losses = map(x -> x.loss, stats)
```

Plot the losses (negative ELBO):

```julia
using Plots
plot(losses; xlabel = "iteration", ylabel = "negative ELBO", label = "", lw = 2)
```

![elbo](elbo.png)

## Evaluating the Trained Flow

The trained flow is a `Bijectors.TransformedDistribution`, so we can call `rand` to draw iid samples and `logpdf` to evaluate the flow's log-density.
See the [documentation of `Bijectors.jl`](https://turinglang.org/Bijectors.jl/dev/distributions/) for details.
```julia
n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)
samples_untrained = rand(flow_untrained, n_samples)
samples_true = rand(target, n_samples)
```

Simple visual comparison:

```julia
using Plots
scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_untrained[1, :], samples_untrained[2, :]; label="Untrained", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained", ms=2, alpha=0.5)
plot!(title = "Planar Flow: Before vs After Training", xlabel = "x₁", ylabel = "x₂", legend = :topleft)
```

![compare](comparison.png)

## Notes

- Use `elbo` instead of `elbo_batch` for a single-sample estimator (see the sketch after this list).
- Switch AD backends by changing `adtype` (see `ADTypes.jl`).
- Marking the base distribution with `@leaf` prevents its parameters from being updated during training.
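
For instance, a minimal sketch of the single-sample variant with a swapped-in AD backend (same setup as above; this assumes `callback` and `hasconverged` are optional keyword arguments of `train_flow`, and that `Zygote` is loaded for `AutoZygote`):

```julia
using Zygote

flow_trained2, stats2, _ = train_flow(
    elbo,                  # single-sample ELBO estimator
    flow,
    logp,
    sample_per_iter;
    max_iters = 20_000,
    optimiser = Optimisers.Adam(1e-2),
    ADbackend = ADTypes.AutoZygote(),  # swapped-in AD backend
)
```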

## Reference

[^RM2015]: Rezende, D. & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML.
54 changes: 54 additions & 0 deletions docs/src/RealNVP.md
@@ -0,0 +1,54 @@
# Demo of RealNVP on 2D Banana Distribution

This demo fits a RealNVP flow to a synthetic 2D banana distribution by maximizing the ELBO.

```julia
using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows

T = Float32                # floating-point type for the flow parameters
rng = Random.default_rng() # RNG passed to train_flow below

target = Banana(2, one(T), 100one(T)) # Banana is defined in example/targets/banana.jl
logp = Base.Fix1(logpdf, target)

######################################
# set up the RealNVP
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), I)

d = 2            # input dimension
hdims = [16, 16] # hidden layer sizes of the conditioner network
nlayers = 3      # number of coupling layers

# use NormalizingFlows.realnvp to create a RealNVP flow
flow = realnvp(q0, hdims, nlayers; paramtype=T)
flow_untrained = deepcopy(flow)


######################################
# start training
######################################
sample_per_iter = 16

adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# stopping criterion: converged once the gradient norm falls below 1/1000
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
rng,
elbo, # using elbo_batch instead of elbo achieves a 4-5x speedup
flow,
logp,
sample_per_iter;
max_iters=10, # change to a larger number of iterations (e.g., 50_000) for better results
optimiser=Optimisers.Adam(5e-4),
ADbackend=adtype,
show_progress=true,
callback=cb,
hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)
```
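
Since the demo ends by destructuring the trained flow into a flat parameter vector, here is a minimal sketch (assuming the code above has run) of rebuilding the flow from `θ` and using it:

```julia
flow_reconstructed = re(θ)  # rebuild the flow from the flat parameter vector
samples = rand(rng, flow_reconstructed, 1_000)    # draw 1000 iid samples
logq = logpdf(flow_reconstructed, samples[:, 1])  # log-density at one sample
```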