Commit 6872760

clean demos for realnvp/planar/radial/ fix a bug in nsf

1 parent d84fce4 commit 6872760

File tree

10 files changed: +155 −310 lines

example/Project.toml

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"

example/RealNVP/main.jl

Lines changed: 0 additions & 55 deletions
This file was deleted.

example/common.jl

Lines changed: 4 additions & 4 deletions

@@ -79,10 +79,10 @@ end
 #     return p
 # end

-# function create_flow(Ls, q₀)
-#     ts = fchain(Ls)
-#     return transformed(q₀, ts)
-# end
+function create_flow(Ls, q₀)
+    ts = reduce(∘, Ls)
+    return transformed(q₀, ts)
+end

 function visualize(p::Bijectors.MultivariateTransformed, samples=rand(p, 1000))
     xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100)
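
The uncommented helper swaps FunctionChains' fchain(Ls) for plain function composition: reduce(∘, Ls) folds the layer list into a single bijector, and transformed wraps the base distribution into a Bijectors.jl transformed distribution. A minimal standalone sketch of the same pattern; the Shift/Scale layers here are illustrative stand-ins, not the commit's coupling layers:

using Bijectors, Distributions, LinearAlgebra

# two simple bijectors standing in for flow layers
Ls = [Bijectors.Shift([1.0, -1.0]), Bijectors.Scale([2.0, 0.5])]

ts = reduce(∘, Ls)             # Shift ∘ Scale: the right-most layer acts first
q₀ = MvNormal(zeros(2), I)     # standard-normal base distribution
flow = transformed(q₀, ts)     # a Bijectors.MultivariateTransformed

x = rand(flow)                 # sample by pushing a base draw through ts
lp = logpdf(flow, x)           # density via the change-of-variables formula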

example/RealNVP/AffineCoupling.jl renamed to example/demo_RealNVP.jl

Lines changed: 71 additions & 4 deletions

@@ -3,11 +3,19 @@ using Functors
 using Bijectors
 using Bijectors: partition, combine, PartitionMask

-include("../util.jl")
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using NormalizingFlows
+
+include("common.jl")
+include("SyntheticTargets.jl")
+include("nn.jl")

-"""
-Affinecoupling layer
-"""
+##################################
+# define affine coupling layer using Bijectors.jl interface
+#################################
 struct AffineCoupling <: Bijectors.Bijector
     dim::Int
     mask::Bijectors.PartitionMask

@@ -96,3 +104,62 @@ end
 #     mask = PartitionMask(dim, mask_idx)
 #     return AffineCoupling(dim, mask, s, t)
 # end
+
+##################################
+# start demo
+#################################
+Random.seed!(123)
+rng = Random.default_rng()
+T = Float32
+
+######################################
+# a difficult banana target
+######################################
+target = Banana(2, 1.0f0, 100.0f0)
+logp = Base.Fix1(logpdf, target)
+
+######################################
+# learn the target using affine coupling flow
+######################################
+@leaf MvNormal
+q0 = MvNormal(zeros(T, 2), ones(T, 2))
+
+d = 2
+hdims = 32
+Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]
+
+flow = create_flow(Ls, q0)
+flow_untrained = deepcopy(flow)
+
+######################################
+# start training
+######################################
+sample_per_iter = 64
+
+# callback function to log training progress
+cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
+adtype = ADTypes.AutoMooncake(; config=Mooncake.Config())
+checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
+flow_trained, stats, _ = train_flow(
+    elbo,
+    flow,
+    logp,
+    sample_per_iter;
+    max_iters=50_000,
+    optimiser=Optimisers.Adam(5e-4),
+    ADbackend=adtype,
+    show_progress=true,
+    callback=cb,
+    hasconverged=checkconv,
+)
+θ, re = Optimisers.destructure(flow_trained)
+losses = map(x -> x.loss, stats)
+
+######################################
+# evaluate trained flow
+######################################
+plot(losses; label="Loss", linewidth=2) # plot the loss
+compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
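
The demo's coupling layer is built on the partition/combine machinery imported at the top of the file; the forward pass itself sits outside the hunks shown here. For readers new to RealNVP, a minimal sketch of the affine coupling step that machinery implements; s and t below are hypothetical stand-ins for the demo's MLP conditioners from nn.jl:

using Bijectors
using Bijectors: partition, combine, PartitionMask

# mask marking dimension 1 of a 2-d input as the transformed block
mask = PartitionMask(2, [1])

# hypothetical stand-ins for the conditioner networks
s(z) = 0.1 .* z      # log-scale conditioner
t(z) = z .+ 1.0      # shift conditioner

x = randn(2)
x₁, x₂, x₃ = partition(mask, x)   # transformed part, conditioning part, rest
y₁ = x₁ .* exp.(s(x₂)) .+ t(x₂)   # affine map of x₁, conditioned on x₂
logjac = sum(s(x₂))               # log|det J| of this coupling step
y = combine(mask, y₁, x₂, x₃)     # reassemble into the full 2-d vector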

example/neural_spline_flow/nsf_layer.jl renamed to example/demo_neural_spline_flow.jl

Lines changed: 79 additions & 40 deletions

@@ -1,52 +1,27 @@
 using Flux
 using Functors
 using Bijectors
-using Bijectors: partition, PartitionMask
+using Bijectors: partition, combine, PartitionMask

-include("../util.jl")
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using NormalizingFlows
+
+include("common.jl")
+include("SyntheticTargets.jl")
+include("nn.jl")
+
+##################################
+# define neural spline layer using Bijectors.jl interface
+#################################
 """
 Neural Rational quadratic Spline layer

 # References
 [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
 """
-# struct NeuralSplineLayer{T1,T2,A<:AbstractVecOrMat{T1}} <: Bijectors.Bijector
-#     dim::Int
-#     mask::Bijectors.PartitionMask
-#     w::A  # width
-#     h::A  # height
-#     d::A  # derivative of the knots
-#     B::T2 # bound of the knots
-# end
-
-# function NeuralSplineLayer(
-#     dim::Int,   # dimension of input
-#     hdims::Int, # dimension of hidden units for s and t
-#     K::Int,     # number of knots
-#     B::T2,      # bound of the knots
-#     mask_idx::AbstractVector{<:Int}, # index of dimensions that one wants to apply transformations on
-# ) where {T2<:Real}
-#     num_of_transformed_dims = length(mask_idx)
-#     input_dims = dim - num_of_transformed_dims
-#     w = fill(MLP_3layer(input_dims, hdims, K), num_of_transformed_dims)
-#     h = fill(MLP_3layer(input_dims, hdims, K), num_of_transformed_dims)
-#     d = fill(MLP_3layer(input_dims, hdims, K - 1), num_of_transformed_dims)
-#     mask = Bijectors.PartitionMask(dim, mask_idx)
-#     return NeuralSplineLayer(dim, mask, w, h, d, B)
-# end
-
-# @functor NeuralSplineLayer (w, h, d)
-
-# # define forward and inverse transformation
-# function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
-#     # instantiate rqs knots and derivatives
-#     ws = permutedims(reduce(hcat, [w(x) for w in nsl.w]))
-#     hs = permutedims(reduce(hcat, [h(x) for h in nsl.h]))
-#     ds = permutedims(reduce(hcat, [d(x) for d in nsl.d]))
-#     return Bijectors.RationalQuadraticSpline(ws, hs, ds, nsl.B)
-# end
-
-## Question: which one is better, the struct below or the struct above?
 struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector
     dim::Int
     K::Int
@@ -64,7 +39,7 @@ function NeuralSplineLayer(
 ) where {T1<:Int,T2<:Real}
     num_of_transformed_dims = length(mask_idx)
     input_dims = dim - num_of_transformed_dims
-    nn = fill(MLP_3layer(input_dims, hdims, 3K - 1), num_of_transformed_dims)
+    nn = [MLP_3layer(input_dims, hdims, 3K - 1) for _ in 1:num_of_transformed_dims]
    mask = Bijectors.PartitionMask(dim, mask_idx)
    return NeuralSplineLayer(dim, K, nn, B, mask)
 end
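
The switch from fill to an array comprehension is presumably the "fix a bug in nsf" from the commit message: fill(x, n) stores n references to the same object, so every transformed dimension would have shared one conditioner network and its weights, whereas the comprehension builds independent networks. A small generic illustration of the aliasing:

v_shared = fill([0.0], 3)        # three references to the SAME vector
v_shared[1][1] = 1.0
v_shared[2][1] == 1.0            # true: mutating one entry changes "all" of them

v_indep = [[0.0] for _ in 1:3]   # three independent vectors
v_indep[1][1] = 1.0
v_indep[2][1] == 0.0             # true: the other entries are unaffected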
@@ -124,3 +99,67 @@ function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
     y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
     return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
 end
+
+##################################
+# start demo
+#################################
+Random.seed!(123)
+rng = Random.default_rng()
+T = Float32
+
+######################################
+# a difficult banana target
+######################################
+target = Banana(2, 1.0f0, 100.0f0)
+logp = Base.Fix1(logpdf, target)
+
+######################################
+# learn the target using neural spline flow
+######################################
+@leaf MvNormal
+q0 = MvNormal(zeros(T, 2), ones(T, 2))
+
+d = 2
+hdims = 32
+K = 8
+B = 3
+Ls = [
+    NeuralSplineLayer(d, hdims, K, B, [1]) ∘ NeuralSplineLayer(d, hdims, K, B, [2]) for
+    i in 1:3
+]
+
+flow = create_flow(Ls, q0)
+flow_untrained = deepcopy(flow)
+
+######################################
+# start training
+######################################
+sample_per_iter = 64
+
+# callback function to log training progress
+cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
+adtype = ADTypes.AutoMooncake(; config=Mooncake.Config())
+checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
+flow_trained, stats, _ = train_flow(
+    elbo,
+    flow,
+    logp,
+    sample_per_iter;
+    max_iters=50_000,
+    optimiser=Optimisers.Adam(5e-4),
+    ADbackend=adtype,
+    show_progress=true,
+    callback=cb,
+    hasconverged=checkconv,
+)
+θ, re = Optimisers.destructure(flow_trained)
+losses = map(x -> x.loss, stats)
+
+######################################
+# evaluate trained flow
+######################################
+plot(losses; label="Loss", linewidth=2) # plot the loss
+compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
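
Each conditioner network outputs 3K - 1 values per transformed dimension, which matches what a rational quadratic spline needs: K unnormalized knot widths, K knot heights, and K - 1 interior derivatives, consumed by Bijectors.RationalQuadraticSpline together with the bound B. A sketch of that split; the helper name and exact slicing convention are assumptions inferred from the 3K - 1 output size, since the diff does not show this part of the layer:

# hypothetical helper: split a conditioner output of length 3K - 1
# into the parameters of a rational quadratic spline
function split_spline_params(θ::AbstractVector, K::Int)
    w = θ[1:K]                # unnormalized knot widths
    h = θ[(K + 1):(2K)]       # unnormalized knot heights
    d = θ[(2K + 1):(3K - 1)]  # derivatives at the K - 1 interior knots
    return w, h, d
end

θ = randn(3 * 8 - 1)          # e.g. K = 8, as in the demo
w, h, d = split_spline_params(θ, 8)
# Bijectors.RationalQuadraticSpline(w, h, d, B) then yields the monotone
# spline on [-B, B] used inside with_logabsdet_jacobian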
File renamed without changes.
File renamed without changes.
