
Commit d336b8d

zuhengxu and sunxd3 authored
more NF examples (#11)
* add support for hasconverged
* fix test error
* rm example/Manifest.toml
* minor bug fix for training loop
* test new stopping criterion
* test convergent condition / rm unready examples
* rm julia test from CI
* Revert "rm julia test from CI" (this reverts commit 1c1c88a)
* make autodiff pkgs as extension + require for bwd compat
* debugging Ext
* keep debugging
* Fix AD package extension loading issues
* Applying @devmotion's comment
* patch last commit
* patch for julia 1.6
* loading dep pkgs from main pkg instead of functions for explicitness
* fixing test err
* rm unready examples
* update realnvp
* minor ed
* removing unnecessary import
* refactor affinecoupling and example/
* debug affinecoupling flow
* adapt to the updated autoforwarddiff to resolve test err
* fix test err
* add new implementation of affcoupling using Bijectors.Coupling
* implement ham flow
* finish hamflow implementation
* minor update
* rename hamflow.jl to hamiltonian_layer.jl
* updating readme
* rm hamflow.jl
* fix minor bugs in affine coupling layer
* test affine coupling flow on banana
* rename simple flow run files
* update loglikelihood to fit in optimize interface
* fix minor bugs in nsf_layer
* rm unused data in nsf flow
* rm @view to avoid zygote mutation error
* update optimize for direct minibatch training
* add BatchNorm for real nvp and make real nvp compatible with batching
* testing
* update resblock
* update resnet arch
* minor update
* update
* fix affine bugs
* add invertiblemlp
* start testing invertiblenetworks
* minor update
* minor bug fix
* update MLP to make it work
* fix type instability in invertible MLP
* update deep mlp
* update deep hamflow result
* update flows
* add shadowing window computation
* minor ed
* refactoring training and setup files
* minor setup change to allow convenient precision switch
* get MLP figs and res
* obtain figs for ham
* rm duplication
* add some MLP result
* add shf res
* add mom normalization layer implementation
* update shf res
* add script for deep shf
* update shf_big res
* minor update
* update log reg and new shf res with q0 being std normal
* minor bug update
* update shf banana figs
* add some lip constant figs for shf banana
* add some lip constant reg
* update shf banana/cross
* update banana plots
* update banana res
* update cross res
* update
* update cross
* update
* update cross elbo res
* update shadowing res
* update banana and cross res
* update banana figs
* update cross res
* update cross figs and res
* update cross lp scaling figs
* minor update
* new cross figs
* rm readme
* update
* rm many unrelated code
* keep cleaning
* fix warpgaussian logpdf error / make neal funnel logpdf work with mooncake
* add easier model loading code
* restructure example folder and refactor planar and radial examples
* clean demos for realnvp/planar/radial / fix a bug in nsf
* rm enzyme dependency
* tune nsf flow (enlarge B) to make it work
* rm some useless files from HamVI
* rename MLP_3layer to mlp3 for convenience
* rename common.jl to utils
* minor update to file naming, and better nsf implementation
* minor ed
* rm redundant nsf file
* shrink stepsize for nsf
* cleaned hamiltonian flow
* minor ed
* rm redundant pkgs
* MLutils and itertools from dependencies
* rm useless example utils functions
* add elbo_batch implementation; much faster for invertible NN based flows
* rm simple unpack from ham flows
* rm simple unpack dependency from example env
* bump version
* add test for elbo_batch
* use elbo_batch for real_nvp; achieved 4+ times speed up
* rm load_model and use constructor directly
* fix doc building error
* update DI bounds
* fix cuda test error
* reduce the iter number
* add CI group for examples
* update examples CI
* fixing CI bugs

---------

Co-authored-by: Xianda Sun <[email protected]>
1 parent 323ebb0 · commit d336b8d

25 files changed: +835 -266 lines changed

.github/workflows/Examples.yml

Lines changed: 42 additions & 0 deletions
```yaml
name: NF Examples

on:
  push:
    branches:
      - main
    tags: ['*']
  pull_request:

concurrency:
  # Skip intermediate builds: always.
  # Cancel intermediate builds: only if it is a pull request build.
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
  run-examples:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v2
        with:
          version: '1'
          arch: x64
      - uses: julia-actions/cache@v2
      - name: Run NF examples
        run: |
          cd example
          julia --project=. --color=yes -e '
            using Pkg;
            Pkg.develop(PackageSpec(path=joinpath(pwd(), "..")));
            Pkg.instantiate();
            @info "Running planar flow demo";
            include("demo_planar_flow.jl");
            @info "Running radial flow demo";
            include("demo_radial_flow.jl");
            @info "Running Real NVP demo";
            include("demo_RealNVP.jl");
            @info "Running neural spline flow demo";
            include("demo_neural_spline_flow.jl");
            @info "Running Hamiltonian flow demo";
            include("demo_hamiltonian_flow.jl");'
```

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -3,6 +3,7 @@
 /docs/build/
 test/Manifest.toml
 example/Manifest.toml
+example/LocalPreferences.toml

 # Files generated by invoking Julia with --code-coverage
 *.jl.cov
```

Project.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 name = "NormalizingFlows"
 uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
-version = "0.2.0"
+version = "0.2.1"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -24,7 +24,7 @@ NormalizingFlowsCUDAExt = "CUDA"
 ADTypes = "1"
 Bijectors = "0.12.6, 0.13, 0.14, 0.15"
 CUDA = "5"
-DifferentiationInterface = "0.6.42"
+DifferentiationInterface = "0.6, 0.7"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
 Optimisers = "0.2.16, 0.3, 0.4"
```

docs/src/api.md

Lines changed: 5 additions & 0 deletions
````diff
@@ -61,6 +61,11 @@ and hope to generate approximate samples from it.
 ```@docs
 NormalizingFlows.elbo
 ```
+
+```@docs
+NormalizingFlows.elbo_batch
+```
+
 #### Log-likelihood

 By maximizing the log-likelihood, it is equivalent to minimizing the forward KL divergence between $q_\theta$ and $p$, i.e.,
````
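
Per the commit message, `elbo_batch` is a batched variant of `elbo` that is much faster for invertible-NN-based flows: it pushes all Monte Carlo samples through the flow as one matrix rather than one column at a time. A minimal usage sketch, assuming `elbo_batch` takes the arguments that `train_flow` passes to it in `demo_RealNVP.jl` below, and that `flow` and `logp` are already defined:

```julia
using Random
using NormalizingFlows

rng = Random.default_rng()
n_samples = 64

# Same estimator, two evaluation strategies (signature assumed from the demo's
# train_flow(rng, elbo_batch, flow, logp, sample_per_iter; ...) call):
est_batched = elbo_batch(rng, flow, logp, n_samples) # one batched pass through the flow
est_looped  = elbo(rng, flow, logp, n_samples)       # per-sample passes; slower for NN flows
```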

docs/src/index.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -21,7 +21,7 @@ See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for mor
 To install the package, run the following command in the Julia REPL:
 ```
 ] # enter Pkg mode
-(@v1.9) pkg> add [email protected]:TuringLang/NormalizingFlows.jl.git
+(@v1.11) pkg> add [email protected]:TuringLang/NormalizingFlows.jl.git
 ```
 Then simply run the following command to use the package:
 ```julia
````

example/Project.toml

Lines changed: 9 additions & 4 deletions
```diff
@@ -2,16 +2,21 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+
+[extras]
+CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
```

example/README.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -12,7 +12,7 @@ normalizing flow to approximate the target distribution using `NormalizingFlows.
 Currently, all examples share the same [Julia project](https://pkgdocs.julialang.org/v1/environments/#Using-someone-else's-project). To run the examples, first activate the project environment:

 ```julia
-# pwd() = "NormalizingFlows.jl/"
-using Pkg; Pkg.activate("example"); Pkg.instantiate()
+# pwd() = "NormalizingFlows.jl/example"
+using Pkg; Pkg.activate("."); Pkg.instantiate()
 ```
-This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with include("<example-to-run>.jl"), or by running the example script line-by-line.
+This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with `include("<example-to-run>.jl")`, or by running the example script line-by-line.
````

example/SyntheticTargets.jl

Lines changed: 19 additions & 0 deletions
```julia
using DocStringExtensions
using Distributions, Random, LinearAlgebra
using IrrationalConstants
using Plots


include("targets/banana.jl")
include("targets/cross.jl")
include("targets/neal_funnel.jl")
include("targets/warped_gaussian.jl")

function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000))
    xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100)
    yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100)
    z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange]
    fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2)
    scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright)
    return fig
end
```
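
A quick usage sketch for `visualize` (a hypothetical snippet, not part of the commit; `Banana` and its arguments are taken from `demo_RealNVP.jl` below, and `savefig` comes from Plots, which this file loads):

```julia
# assumes pwd() = "NormalizingFlows.jl/example"
include("SyntheticTargets.jl")

target = Banana(2, 1.0f0, 100.0f0) # 2-D banana target, as used in demo_RealNVP.jl
fig = visualize(target)            # pdf contour overlaid with 1000 samples from the target
savefig(fig, "banana_target.png")  # hypothetical output file name
```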

example/demo_RealNVP.jl

Lines changed: 180 additions & 0 deletions
```julia
using Flux
using Bijectors
using Bijectors: partition, combine, PartitionMask

using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows

include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define affine coupling layer using Bijectors.jl interface
#################################
struct AffineCoupling <: Bijectors.Bijector
    dim::Int
    mask::Bijectors.PartitionMask
    s::Flux.Chain
    t::Flux.Chain
end

# let params track fields s and t
@functor AffineCoupling (s, t)

function AffineCoupling(
    dim::Int,                 # dimension of input
    hdims::Int,               # dimension of hidden units for s and t
    mask_idx::AbstractVector, # indices of the dimensions to apply the transformation to
)
    cdims = length(mask_idx) # dimension of parts used to construct coupling law
    s = mlp3(cdims, hdims, cdims)
    t = mlp3(cdims, hdims, cdims)
    mask = PartitionMask(dim, mask_idx)
    return AffineCoupling(dim, mask, s, t)
end

function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
    # partition vector using `af.mask::PartitionMask`
    x₁, x₂, x₃ = partition(af.mask, x)
    y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
    return combine(af.mask, y₁, x₂, x₃)
end

function (af::AffineCoupling)(x::AbstractArray)
    return transform(af, x)
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
    y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
    logjac = sum(log ∘ abs, af.s(x_2)) # this is a scalar
    return combine(af.mask, y_1, x_2, x_3), logjac
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
    y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
    logjac = sum(log ∘ abs, af.s(x_2); dims=1) # 1 × size(x, 2)
    return combine(af.mask, y_1, x_2, x_3), vec(logjac)
end


function Bijectors.with_logabsdet_jacobian(
    iaf::Inverse{<:AffineCoupling}, y::AbstractVector
)
    af = iaf.orig
    # partition vector using `af.mask::PartitionMask`
    y_1, y_2, y_3 = partition(af.mask, y)
    # inverse transformation
    x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
    logjac = -sum(log ∘ abs, af.s(y_2))
    return combine(af.mask, x_1, y_2, y_3), logjac
end

function Bijectors.with_logabsdet_jacobian(
    iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
)
    af = iaf.orig
    # partition vector using `af.mask::PartitionMask`
    y_1, y_2, y_3 = partition(af.mask, y)
    # inverse transformation
    x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
    logjac = -sum(log ∘ abs, af.s(y_2); dims=1)
    return combine(af.mask, x_1, y_2, y_3), vec(logjac)
end

###################
# an equivalent definition of AffineCoupling using Bijectors.Coupling
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
###################

# struct AffineCoupling <: Bijectors.Bijector
#     dim::Int
#     mask::Bijectors.PartitionMask
#     s::Flux.Chain
#     t::Flux.Chain
# end

# # let params track fields s and t
# @functor AffineCoupling (s, t)

# function AffineCoupling(dim, mask, s, t)
#     return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
# end

# function AffineCoupling(
#     dim::Int,                 # dimension of input
#     hdims::Int,               # dimension of hidden units for s and t
#     mask_idx::AbstractVector, # indices of the dimensions to apply the transformation to
# )
#     cdims = length(mask_idx) # dimension of parts used to construct coupling law
#     s = mlp3(cdims, hdims, cdims)
#     t = mlp3(cdims, hdims, cdims)
#     mask = PartitionMask(dim, mask_idx)
#     return AffineCoupling(dim, mask, s, t)
# end



##################################
# start demo
#################################
Random.seed!(123)
rng = Random.default_rng()
T = Float32

######################################
# a difficult banana target
######################################
target = Banana(2, 1.0f0, 100.0f0)
logp = Base.Fix1(logpdf, target)

######################################
# learn the target using affine coupling flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), ones(T, 2))

d = 2
hdims = 32

# alternate the coupling layers
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]

flow = create_flow(Ls, q0)
flow_untrained = deepcopy(flow)


######################################
# start training
######################################
sample_per_iter = 64

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config())
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
    rng,
    elbo_batch, # using elbo_batch instead of elbo achieves a 4-5x speedup
    flow,
    logp,
    sample_per_iter;
    max_iters=100, # change to a larger number of iterations (e.g., 50_000) for better results
    optimiser=Optimisers.Adam(5e-4),
    ADbackend=adtype,
    show_progress=true,
    callback=cb,
    hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)

######################################
# evaluate trained flow
######################################
plot(losses; label="Loss", linewidth=2) # plot the loss
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
```
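
Before training, a round-trip sanity check of the coupling layer can catch sign or partition mistakes early. A hypothetical check, not part of the commit; it assumes the definitions above (including `mlp3` from `utils.jl`) are loaded, and the tolerances are arbitrary:

```julia
# Verify that inverse(layer) undoes the layer and that the two
# log-abs-det Jacobians cancel (both are computed from the same s(x_2)).
layer = AffineCoupling(2, 32, [1])
x = randn(Float32, 2)

y, logjac_fwd = Bijectors.with_logabsdet_jacobian(layer, x)
x_rec, logjac_bwd = Bijectors.with_logabsdet_jacobian(inverse(layer), y)

@assert isapprox(x_rec, x; atol=1f-4)                # x == T⁻¹(T(x))
@assert isapprox(logjac_fwd, -logjac_bwd; atol=1f-4) # log|det J| terms cancel
```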
