Skip to content

Commit d84fce4

Browse files
committed
restructure example folder and refactor planar and radial examples
1 parent d916b64 commit d84fce4

File tree

7 files changed

+102
-138
lines changed

7 files changed

+102
-138
lines changed

example/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
1111
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1212
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1313
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
14+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1415
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1516
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
1617
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -19,6 +20,7 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1920
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
2021
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
2122
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
23+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2224
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
2325
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2426
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

example/SyntheticTargets.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
using DocStringExtensions
12
using Distributions, Random, LinearAlgebra
2-
using LogDensityProblems
33
using IrrationalConstants
44
using Plots
55

@@ -23,7 +23,3 @@ function load_model(name::String)
2323
error("Model not defined")
2424
end
2525
end
26-
27-
LogDensityProblems.dimension(dist::ContinuousDistribution) = length(dist)
28-
LogDensityProblems.logdensity(dist::ContinuousDistribution, x) = logpdf(dist, x)
29-

example/common.jl

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Random, Distributions, LinearAlgebra, Bijectors
2-
include("util.jl")
32

43
function compare_trained_and_untrained_flow(
54
flow_trained::Bijectors.MultivariateTransformed,
@@ -80,42 +79,16 @@ end
8079
# return p
8180
# end
8281

83-
function create_flow(Ls, q₀)
84-
ts = fchain(Ls)
85-
return transformed(q₀, ts)
86-
end
87-
88-
#######################
89-
# training function for InvertibleNetworks
90-
########################
91-
92-
# function pm_next!(pm, stats::NamedTuple)
93-
# return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
82+
# function create_flow(Ls, q₀)
83+
# ts = fchain(Ls)
84+
# return transformed(q₀, ts)
9485
# end
9586

96-
# function train_invertible_networks!(G, loss, data_loader, n_epoch, opt)
97-
# max_iters = n_epoch * length(data_loader)
98-
# prog = ProgressMeter.Progress(
99-
# max_iters; desc="Training", barlen=31, showspeed=true, enabled=true
100-
# )
101-
102-
# nnls = []
103-
104-
# # training loop
105-
# time_elapsed = @elapsed for (i, xs) in enumerate(IterTools.ncycle(data_loader, n_epoch))
106-
# ls = loss(G, xs) #sets gradients of G
107-
108-
# push!(nnls, ls)
109-
110-
# grad_norm = 0
111-
# for p in get_params(G)
112-
# grad_norm += sum(abs2, p.grad)
113-
# Flux.update!(opt, p.data, p.grad)
114-
# end
115-
# grad_norm = sqrt(grad_norm)
116-
117-
# stat = (iteration=i, neg_log_llh=ls, gradient_norm=grad_norm)
118-
# pm_next!(prog, stat)
119-
# end
120-
# return nnls
121-
# end
87+
function visualize(p::Bijectors.MultivariateTransformed, samples=rand(p, 1000))
88+
xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100)
89+
yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100)
90+
z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange]
91+
fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2)
92+
scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright)
93+
return fig
94+
end

example/util.jl renamed to example/nn.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
1-
using LinearAlgebra
2-
using StatsBase
3-
using BlockBandedMatrices
41
using Flux, Bijectors
5-
using Base.Threads
6-
using DiffResults
7-
using ForwardDiff
8-
using ProgressMeter
9-
using TickTock
102

113
function MLP_3layer(input_dim::Int, hdims::Int, output_dim::Int; activation=Flux.leakyrelu)
124
return Chain(
@@ -55,12 +47,3 @@ end
5547
# mlp_layer = MLP_BN(inputdim, hdim, outputdim; activation=activation)
5648
# return Flux.SkipConnection(mlp_layer, +)
5749
# end
58-
59-
function rand_batch(rng::AbstractRNG, td::Bijectors.MvTransformed, num_samples::Int)
60-
samples = rand(rng, td.dist, num_samples)
61-
res = td.transform(samples)
62-
return res
63-
end
64-
function rand_batch(td::Bijectors.MvTransformed, num_samples::Int)
65-
return rand_batch(Random.default_rng(), td, num_samples)
66-
end
Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,64 @@
11
using Random, Distributions, LinearAlgebra, Bijectors
2-
using ADTypes
3-
using Optimisers
2+
using Functors
3+
using Optimisers, ADTypes, Mooncake
44
using NormalizingFlows
5-
using Mooncake
6-
using CUDA
7-
using Flux: f32
8-
using Flux
9-
using Plots
105
include("common.jl")
6+
include("SyntheticTargets.jl")
117

128
Random.seed!(123)
139
rng = Random.default_rng()
14-
T = Float32
10+
T = Float64
1511

1612
######################################
1713
# 2d Banana as the target distribution
1814
######################################
19-
include("targets/banana.jl")
15+
target = load_model("Banana")
16+
logp = Base.Fix1(logpdf, target)
2017

21-
# create target p
22-
p = Banana(2, 1.0f-1, 100.0f0)
23-
logp = Base.Fix1(logpdf, p)
2418

2519
######################################
26-
# learn the target using planar flow
20+
# setup planar flow
2721
######################################
2822
function create_planar_flow(n_layers::Int, q₀)
2923
d = length(q₀)
30-
Ls = [f32(PlanarLayer(d)) for _ in 1:n_layers]
24+
Ls = [PlanarLayer(d) for _ in 1:n_layers]
3125
ts = reduce(, Ls)
3226
return transformed(q₀, ts)
3327
end
3428

35-
# create a 10-layer planar flow
29+
@leaf MvNormal
3630
q0 = MvNormal(zeros(T, 2), ones(T, 2))
3731
flow = create_planar_flow(10, q0)
32+
flow_untrained = deepcopy(flow)
3833

3934

4035

41-
# train the flow
42-
sample_per_iter = 10
43-
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,)
44-
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3
36+
######################################
37+
# start training
38+
######################################
39+
sample_per_iter = 30
40+
41+
# callback function to log training progress
42+
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
43+
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
44+
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
4545
flow_trained, stats, _ = train_flow(
4646
elbo,
4747
flow,
4848
logp,
4949
sample_per_iter;
50-
max_iters=200_00,
51-
optimiser=Optimisers.Adam(),
50+
max_iters=10_000,
51+
optimiser=Optimisers.Adam(one(T)/100),
52+
ADbackend=adtype,
53+
show_progress=true,
5254
callback=cb,
53-
ADbackend=AutoZygote(),
5455
hasconverged=checkconv,
5556
)
57+
θ, re = Optimisers.destructure(flow_trained)
5658
losses = map(x -> x.loss, stats)
5759

5860
######################################
5961
# evaluate trained flow
6062
######################################
6163
plot(losses; label="Loss", linewidth=2) # plot the loss
62-
compare_trained_and_untrained_flow(flow_trained, flow_untrained, p, 1000)
64+
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)

example/planar_radial_flow/radial_flow_main.jl

Lines changed: 0 additions & 55 deletions
This file was deleted.

example/radial_flow_demo.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using Random, Distributions, LinearAlgebra, Bijectors
2+
using Functors
3+
using Optimisers, ADTypes, Mooncake
4+
using NormalizingFlows
5+
include("common.jl")
6+
include("SyntheticTargets.jl")
7+
8+
Random.seed!(123)
9+
rng = Random.default_rng()
10+
T = Float64
11+
12+
######################################
13+
# get target logp
14+
######################################
15+
target = load_model("WarpedGaussian")
16+
logp = Base.Fix1(logpdf, target)
17+
18+
######################################
19+
# setup radial flow
20+
######################################
21+
function create_radial_flow(n_layers::Int, q₀)
22+
d = length(q₀)
23+
Ls = [RadialLayer(d) for _ in 1:n_layers]
24+
ts = reduce(, Ls)
25+
return transformed(q₀, ts)
26+
end
27+
28+
# create a 10-layer radial flow
29+
@leaf MvNormal
30+
q0 = MvNormal(zeros(T, 2), ones(T, 2))
31+
flow = create_radial_flow(10, q0)
32+
33+
flow_untrained = deepcopy(flow)
34+
35+
######################################
36+
# start training
37+
######################################
38+
sample_per_iter = 30
39+
40+
# callback function to log training progress
41+
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
42+
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
43+
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
44+
flow_trained, stats, _ = train_flow(
45+
elbo,
46+
flow,
47+
logp,
48+
sample_per_iter;
49+
max_iters=10_000,
50+
optimiser=Optimisers.Adam(one(T)/100),
51+
ADbackend=adtype,
52+
show_progress=true,
53+
callback=cb,
54+
hasconverged=checkconv,
55+
)
56+
θ, re = Optimisers.destructure(flow_trained)
57+
losses = map(x -> x.loss, stats)
58+
59+
######################################
60+
# evaluate trained flow
61+
######################################
62+
plot(losses; label="Loss", linewidth=2) # plot the loss
63+
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)

0 commit comments

Comments
 (0)