Skip to content

Commit 55fb607

Browse files
committed
update demos, debugging mooncake with elbo
1 parent 1dfebd1 commit 55fb607

File tree

4 files changed

+10
-218
lines changed

4 files changed

+10
-218
lines changed

example/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1717
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1818
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
20+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2021

2122
[extras]
2223
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"

example/demo_RealNVP.jl

Lines changed: 8 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -11,114 +11,6 @@ using NormalizingFlows
1111
include("SyntheticTargets.jl")
1212
include("utils.jl")
1313

14-
##################################
15-
# define affine coupling layer using Bijectors.jl interface
16-
#################################
17-
struct AffineCoupling <: Bijectors.Bijector
18-
dim::Int
19-
mask::Bijectors.PartitionMask
20-
s::Flux.Chain
21-
t::Flux.Chain
22-
end
23-
24-
# let params track field s and t
25-
@functor AffineCoupling (s, t)
26-
27-
function AffineCoupling(
28-
dim::Int, # dimension of input
29-
hdims::Int, # dimension of hidden units for s and t
30-
mask_idx::AbstractVector, # indices of the dimensions that one wants to apply transformations on
31-
)
32-
cdims = length(mask_idx) # dimension of parts used to construct coupling law
33-
s = mlp3(cdims, hdims, cdims)
34-
t = mlp3(cdims, hdims, cdims)
35-
mask = PartitionMask(dim, mask_idx)
36-
return AffineCoupling(dim, mask, s, t)
37-
end
38-
39-
function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
40-
# partition vector using `af.mask::PartitionMask`
41-
x₁, x₂, x₃ = partition(af.mask, x)
42-
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
43-
return combine(af.mask, y₁, x₂, x₃)
44-
end
45-
46-
function (af::AffineCoupling)(x::AbstractArray)
47-
return transform(af, x)
48-
end
49-
50-
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
51-
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
52-
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
53-
logjac = sum(log ∘ abs, af.s(x_2)) # this is a scalar
54-
return combine(af.mask, y_1, x_2, x_3), logjac
55-
end
56-
57-
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
58-
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
59-
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
60-
logjac = sum(log ∘ abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
61-
return combine(af.mask, y_1, x_2, x_3), vec(logjac)
62-
end
63-
64-
65-
function Bijectors.with_logabsdet_jacobian(
66-
iaf::Inverse{<:AffineCoupling}, y::AbstractVector
67-
)
68-
af = iaf.orig
69-
# partition vector using `af.mask::PartitionMask`
70-
y_1, y_2, y_3 = partition(af.mask, y)
71-
# inverse transformation
72-
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
73-
logjac = -sum(log ∘ abs, af.s(y_2))
74-
return combine(af.mask, x_1, y_2, y_3), logjac
75-
end
76-
77-
function Bijectors.with_logabsdet_jacobian(
78-
iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
79-
)
80-
af = iaf.orig
81-
# partition vector using `af.mask::PartitionMask`
82-
y_1, y_2, y_3 = partition(af.mask, y)
83-
# inverse transformation
84-
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
85-
logjac = -sum(log ∘ abs, af.s(y_2); dims = 1)
86-
return combine(af.mask, x_1, y_2, y_3), vec(logjac)
87-
end
88-
89-
###################
90-
# an equivalent definition of AffineCoupling using Bijectors.Coupling
91-
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
92-
###################
93-
94-
# struct AffineCoupling <: Bijectors.Bijector
95-
# dim::Int
96-
# mask::Bijectors.PartitionMask
97-
# s::Flux.Chain
98-
# t::Flux.Chain
99-
# end
100-
101-
# # let params track field s and t
102-
# @functor AffineCoupling (s, t)
103-
104-
# function AffineCoupling(dim, mask, s, t)
105-
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
106-
# end
107-
108-
# function AffineCoupling(
109-
# dim::Int, # dimension of input
110-
# hdims::Int, # dimension of hidden units for s and t
111-
# mask_idx::AbstractVector, # indices of the dimensions that one wants to apply transformations on
112-
# )
113-
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
114-
# s = mlp3(cdims, hdims, cdims)
115-
# t = mlp3(cdims, hdims, cdims)
116-
# mask = PartitionMask(dim, mask_idx)
117-
# return AffineCoupling(dim, mask, s, t)
118-
# end
119-
120-
121-
12214
##################################
12315
# start demo
12416
#################################
@@ -132,29 +24,30 @@ T = Float32
13224
target = Banana(2, 1.0f0, 100.0f0)
13325
logp = Base.Fix1(logpdf, target)
13426

27+
13528
######################################
13629
# learn the target using Affine coupling flow
13730
######################################
13831
@leaf MvNormal
139-
q0 = MvNormal(zeros(T, 2), ones(T, 2))
32+
q0 = MvNormal(zeros(T, 2), I)
14033

14134
d = 2
142-
hdims = 32
143-
144-
# alternating the coupling layers
145-
Ls = [AffineCoupling(d, hdims, [1]) AffineCoupling(d, hdims, [2]) for i in 1:3]
35+
hdims = [16, 16]
36+
nlayers = 3
14637

147-
flow = create_flow(Ls, q0)
38+
# use NormalizingFlows.realnvp to create a RealNVP flow
39+
flow = realnvp(q0, hdims, nlayers; paramtype=T)
14840
flow_untrained = deepcopy(flow)
14941

15042

15143
######################################
15244
# start training
15345
######################################
154-
sample_per_iter = 64
46+
sample_per_iter = 16
15547

15648
# callback function to log training progress
15749
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
50+
# TODO: this example currently breaks with AutoMooncake, although AutoZygote works; needs debugging
15851
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
15952
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
16053
flow_trained, stats, _ = train_flow(

example/demo_neural_spline_flow.jl

Lines changed: 1 addition & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -11,104 +11,6 @@ using NormalizingFlows
1111
include("SyntheticTargets.jl")
1212
include("utils.jl")
1313

14-
##################################
15-
# define neural spline layer using Bijectors.jl interface
16-
#################################
17-
"""
18-
Neural Rational quadratic Spline layer
19-
20-
# References
21-
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
22-
"""
23-
struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector
24-
dim::Int # dimension of input
25-
K::Int # number of knots
26-
n_dims_transferred::Int # number of dimensions that are transformed
27-
nn::A # networks that parameterize the knots and derivatives
28-
B::T # bound of the knots
29-
mask::Bijectors.PartitionMask
30-
end
31-
32-
function NeuralSplineLayer(
33-
dim::T1, # dimension of input
34-
hdims::T1, # dimension of hidden units of the network `nn`
35-
K::T1, # number of knots
36-
B::T2, # bound of the knots
37-
mask_idx::AbstractVector{<:Int}, # indices of the dimensions that one wants to apply transformations on
38-
) where {T1<:Int,T2<:Real}
39-
num_of_transformed_dims = length(mask_idx)
40-
input_dims = dim - num_of_transformed_dims
41-
42-
# output dim of the NN
43-
output_dims = (3K - 1)*num_of_transformed_dims
44-
# one big mlp that outputs all the knots and derivatives for all the transformed dimensions
45-
nn = mlp3(input_dims, hdims, output_dims)
46-
47-
mask = Bijectors.PartitionMask(dim, mask_idx)
48-
return NeuralSplineLayer(dim, K, num_of_transformed_dims, nn, B, mask)
49-
end
50-
51-
@functor NeuralSplineLayer (nn,)
52-
53-
# define forward and inverse transformation
54-
"""
55-
Build a rational quadratic spline from the nn output
56-
Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline
57-
58-
we just need to map the nn output to the knots and derivatives of the RQS
59-
"""
60-
function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
61-
K, B = nsl.K, nsl.B
62-
nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :)
63-
ws = @view nnoutput[:, 1:K]
64-
hs = @view nnoutput[:, (K + 1):(2K)]
65-
ds = @view nnoutput[:, (2K + 1):(3K - 1)]
66-
return Bijectors.RationalQuadraticSpline(ws, hs, ds, B)
67-
end
68-
69-
function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector)
70-
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
71-
# instantiate rqs knots and derivatives
72-
rqs = instantiate_rqs(nsl, x_2)
73-
y_1 = Bijectors.transform(rqs, x_1)
74-
return Bijectors.combine(nsl.mask, y_1, x_2, x_3)
75-
end
76-
77-
function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
78-
nsl = insl.orig
79-
y1, y2, y3 = partition(nsl.mask, y)
80-
rqs = instantiate_rqs(nsl, y2)
81-
x1 = Bijectors.transform(Inverse(rqs), y1)
82-
return Bijectors.combine(nsl.mask, x1, y2, y3)
83-
end
84-
85-
function (nsl::NeuralSplineLayer)(x::AbstractVector)
86-
return Bijectors.transform(nsl, x)
87-
end
88-
89-
# define logabsdetjac
90-
function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector)
91-
x_1, x_2, _ = Bijectors.partition(nsl.mask, x)
92-
rqs = instantiate_rqs(nsl, x_2)
93-
logjac = logabsdetjac(rqs, x_1)
94-
return logjac
95-
end
96-
97-
function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
98-
nsl = insl.orig
99-
y1, y2, _ = partition(nsl.mask, y)
100-
rqs = instantiate_rqs(nsl, y2)
101-
logjac = logabsdetjac(Inverse(rqs), y1)
102-
return logjac
103-
end
104-
105-
function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
106-
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
107-
rqs = instantiate_rqs(nsl, x_2)
108-
y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
109-
return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
110-
end
111-
11214
##################################
11315
# start demo
11416
#################################
@@ -148,6 +50,7 @@ sample_per_iter = 64
14850

14951
# callback function to log training progress
15052
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
53+
# TODO: this example currently breaks with AutoMooncake, although AutoZygote works; needs debugging
15154
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
15255
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
15356
flow_trained, stats, _ = train_flow(

example/utils.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux
1313
)
1414
end
1515

16-
function create_flow(Ls, q₀)
17-
ts = reduce(∘, Ls)
18-
return transformed(q₀, ts)
19-
end
20-
2116
function compare_trained_and_untrained_flow(
2217
flow_trained::Bijectors.MultivariateTransformed,
2318
flow_untrained::Bijectors.MultivariateTransformed,

0 commit comments

Comments
 (0)