using Flux
using Functors
using Bijectors
-using Bijectors: partition, PartitionMask
+using Bijectors: partition, combine, PartitionMask

-include("../util.jl")
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using NormalizingFlows
+
+include("common.jl")
+include("SyntheticTargets.jl")
+include("nn.jl")
+
+#################################
+# define neural spline layer using Bijectors.jl interface
+#################################
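+# A coupling layer splits the input x via PartitionMask into (x_1, x_2, x_3):
+# x_1 is transformed by a rational quadratic spline whose parameters are
+# produced by conditioner networks that only see x_2, while x_2 and x_3 pass
+# through unchanged. This keeps the Jacobian triangular, so its
+# log-determinant is cheap to evaluate.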
"""
Neural rational quadratic spline layer

# References
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
"""
-# struct NeuralSplineLayer{T1,T2,A<:AbstractVecOrMat{T1}} <: Bijectors.Bijector
-#     dim::Int
-#     mask::Bijectors.PartitionMask
-#     w::A  # width
-#     h::A  # height
-#     d::A  # derivative of the knots
-#     B::T2 # bound of the knots
-# end
-
-# function NeuralSplineLayer(
-#     dim::Int,                        # dimension of input
-#     hdims::Int,                      # dimension of hidden units for s and t
-#     K::Int,                          # number of knots
-#     B::T2,                           # bound of the knots
-#     mask_idx::AbstractVector{<:Int}, # indices of the dimensions to transform
-# ) where {T2<:Real}
-#     num_of_transformed_dims = length(mask_idx)
-#     input_dims = dim - num_of_transformed_dims
-#     w = fill(MLP_3layer(input_dims, hdims, K), num_of_transformed_dims)
-#     h = fill(MLP_3layer(input_dims, hdims, K), num_of_transformed_dims)
-#     d = fill(MLP_3layer(input_dims, hdims, K - 1), num_of_transformed_dims)
-#     mask = Bijectors.PartitionMask(dim, mask_idx)
-#     return NeuralSplineLayer(dim, mask, w, h, d, B)
-# end
-
-# @functor NeuralSplineLayer (w, h, d)
-
-# # define forward and inverse transformation
-# function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
-#     # instantiate rqs knots and derivatives
-#     ws = permutedims(reduce(hcat, [w(x) for w in nsl.w]))
-#     hs = permutedims(reduce(hcat, [h(x) for h in nsl.h]))
-#     ds = permutedims(reduce(hcat, [d(x) for d in nsl.d]))
-#     return Bijectors.RationalQuadraticSpline(ws, hs, ds, nsl.B)
-# end
-
-# # Question: which one is better, the struct below or the struct above?
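# Each transformed dimension gets its own conditioner network emitting
# 3K - 1 values: K knot widths, K knot heights, and K - 1 interior knot
# derivatives of a rational quadratic spline on [-B, B] (outside this
# interval the transform is the identity).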
struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector
    dim::Int # dimension of the input
    K::Int   # number of knots
@@ -64,7 +39,7 @@ function NeuralSplineLayer(
) where {T1<:Int,T2<:Real}
    num_of_transformed_dims = length(mask_idx)
    input_dims = dim - num_of_transformed_dims
-    nn = fill(MLP_3layer(input_dims, hdims, 3K - 1), num_of_transformed_dims)
+    nn = [MLP_3layer(input_dims, hdims, 3K - 1) for _ in 1:num_of_transformed_dims]
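+    # the comprehension above builds an independent conditioner per dimension;
+    # fill would replicate one network object and silently tie their weights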
    mask = Bijectors.PartitionMask(dim, mask_idx)
    return NeuralSplineLayer(dim, K, nn, B, mask)
end
@@ -124,3 +99,67 @@ function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
    y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
    # reassemble: transformed block y_1 plus the untouched blocks x_2, x_3
    return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
end
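+# Illustrative sanity check (assumes the forward/inverse methods elided above):
+#   layer = NeuralSplineLayer(2, 16, 8, 3.0f0, [1])
+#   x = randn(Float32, 2)
+#   y, _ = Bijectors.with_logabsdet_jacobian(layer, x)
+#   inverse(layer)(y) ≈ x   # recovers x up to floating-point error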
+
+
+#################################
+# start demo
+#################################
+Random.seed!(123)
+rng = Random.default_rng()
+T = Float32

+#####################################
+# a difficult banana target
+#####################################
+target = Banana(2, 1.0f0, 100.0f0)
+logp = Base.Fix1(logpdf, target)
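+# (Base.Fix1 pins the first argument, so logp(x) == logpdf(target, x))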
+
+#####################################
+# learn the target using a neural spline flow
+#####################################
+@leaf MvNormal
+q0 = MvNormal(zeros(T, 2), ones(T, 2))
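+# @leaf stops Functors from recursing into MvNormal, so the base
+# distribution stays fixed instead of being treated as trainable parameters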
+
+d = 2      # dimension of the target
+hdims = 32 # hidden units in each conditioner MLP
+K = 8      # number of spline knots
+B = 3      # spline support [-B, B]
+Ls = [
+    NeuralSplineLayer(d, hdims, K, B, [1]) ∘ NeuralSplineLayer(d, hdims, K, B, [2]) for
+    i in 1:3
+]
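+# each pair above alternates the mask ([1], then [2]) so both coordinates
+# get transformed; three such pairs are composed for extra capacity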
+
+flow = create_flow(Ls, q0)
+flow_untrained = deepcopy(flow)

+#####################################
+# start training
+#####################################
+sample_per_iter = 64
+
+# callback function to log training progress
+cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
+adtype = ADTypes.AutoMooncake(; config=Mooncake.Config())
+# declare convergence once the gradient norm drops below 1e-3
+checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
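+# maximizing the ELBO over the flow's own samples minimizes the reverse
+# KL divergence KL(q_flow || p_target), up to the target's normalization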
+flow_trained, stats, _ = train_flow(
+    elbo,
+    flow,
+    logp,
+    sample_per_iter;
+    max_iters=50_000,
+    optimiser=Optimisers.Adam(5e-4),
+    ADbackend=adtype,
+    show_progress=true,
+    callback=cb,
+    hasconverged=checkconv,
+)
+θ, re = Optimisers.destructure(flow_trained)
+losses = map(x -> x.loss, stats)

+#####################################
+# evaluate trained flow
+#####################################
+plot(losses; label="Loss", linewidth=2) # plot the loss
+compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)