@@ -11,114 +11,6 @@ using NormalizingFlows
1111include (" SyntheticTargets.jl" )
1212include (" utils.jl" )
1313
14- # #################################
15- # define affine coupling layer using Bijectors.jl interface
16- # ################################
17- struct AffineCoupling <: Bijectors.Bijector
18- dim:: Int
19- mask:: Bijectors.PartitionMask
20- s:: Flux.Chain
21- t:: Flux.Chain
22- end
23-
24- # let params track field s and t
25- @functor AffineCoupling (s, t)
26-
27- function AffineCoupling (
28- dim:: Int , # dimension of input
29- hdims:: Int , # dimension of hidden units for s and t
30- mask_idx:: AbstractVector , # indices of the dimensions that one wants to apply transformations on
31- )
32- cdims = length (mask_idx) # dimension of parts used to construct coupling law
33- s = mlp3 (cdims, hdims, cdims)
34- t = mlp3 (cdims, hdims, cdims)
35- mask = PartitionMask (dim, mask_idx)
36- return AffineCoupling (dim, mask, s, t)
37- end
38-
39- function Bijectors. transform (af:: AffineCoupling , x:: AbstractVecOrMat )
40- # partition vector using `af.mask::PartitionMask`
41- x₁, x₂, x₃ = partition (af. mask, x)
42- y₁ = x₁ .* af. s (x₂) .+ af. t (x₂)
43- return combine (af. mask, y₁, x₂, x₃)
44- end
45-
46- function (af:: AffineCoupling )(x:: AbstractArray )
47- return transform (af, x)
48- end
49-
50- function Bijectors. with_logabsdet_jacobian (af:: AffineCoupling , x:: AbstractVector )
51- x_1, x_2, x_3 = Bijectors. partition (af. mask, x)
52- y_1 = af. s (x_2) .* x_1 .+ af. t (x_2)
53- logjac = sum (log ∘ abs, af. s (x_2)) # this is a scalar
54- return combine (af. mask, y_1, x_2, x_3), logjac
55- end
56-
57- function Bijectors. with_logabsdet_jacobian (af:: AffineCoupling , x:: AbstractMatrix )
58- x_1, x_2, x_3 = Bijectors. partition (af. mask, x)
59- y_1 = af. s (x_2) .* x_1 .+ af. t (x_2)
60- logjac = sum (log ∘ abs, af. s (x_2); dims = 1 ) # 1 × size(x, 2)
61- return combine (af. mask, y_1, x_2, x_3), vec (logjac)
62- end
63-
64-
65- function Bijectors. with_logabsdet_jacobian (
66- iaf:: Inverse{<:AffineCoupling} , y:: AbstractVector
67- )
68- af = iaf. orig
69- # partition vector using `af.mask::PartitionMask`
70- y_1, y_2, y_3 = partition (af. mask, y)
71- # inverse transformation
72- x_1 = (y_1 .- af. t (y_2)) ./ af. s (y_2)
73- logjac = - sum (log ∘ abs, af. s (y_2))
74- return combine (af. mask, x_1, y_2, y_3), logjac
75- end
76-
77- function Bijectors. with_logabsdet_jacobian (
78- iaf:: Inverse{<:AffineCoupling} , y:: AbstractMatrix
79- )
80- af = iaf. orig
81- # partition vector using `af.mask::PartitionMask`
82- y_1, y_2, y_3 = partition (af. mask, y)
83- # inverse transformation
84- x_1 = (y_1 .- af. t (y_2)) ./ af. s (y_2)
85- logjac = - sum (log ∘ abs, af. s (y_2); dims = 1 )
86- return combine (af. mask, x_1, y_2, y_3), vec (logjac)
87- end
88-
89- # ##################
90- # an equivalent definition of AffineCoupling using Bijectors.Coupling
91- # (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
92- # ##################
93-
94- # struct AffineCoupling <: Bijectors.Bijector
95- # dim::Int
96- # mask::Bijectors.PartitionMask
97- # s::Flux.Chain
98- # t::Flux.Chain
99- # end
100-
101- # # let params track field s and t
102- # @functor AffineCoupling (s, t)
103-
104- # function AffineCoupling(dim, mask, s, t)
105- # return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
106- # end
107-
108- # function AffineCoupling(
109- # dim::Int, # dimension of input
110- # hdims::Int, # dimension of hidden units for s and t
110- # mask_idx::AbstractVector, # indices of the dimensions that one wants to apply transformations on
112- # )
113- # cdims = length(mask_idx) # dimension of parts used to construct coupling law
114- # s = mlp3(cdims, hdims, cdims)
115- # t = mlp3(cdims, hdims, cdims)
116- # mask = PartitionMask(dim, mask_idx)
117- # return AffineCoupling(dim, mask, s, t)
118- # end
119-
120-
121-
12214# #################################
12315# start demo
12416# ################################
@@ -132,29 +24,30 @@ T = Float32
13224target = Banana (2 , 1.0f0 , 100.0f0 )
13325logp = Base. Fix1 (logpdf, target)
13426
27+
13528# #####################################
13629# learn the target using Affine coupling flow
13730# #####################################
13831@leaf MvNormal
139- q0 = MvNormal (zeros (T, 2 ), ones (T, 2 ) )
32+ q0 = MvNormal (zeros (T, 2 ), I )
14033
14134d = 2
142- hdims = 32
143-
144- # alternating the coupling layers
145- Ls = [AffineCoupling (d, hdims, [1 ]) ∘ AffineCoupling (d, hdims, [2 ]) for i in 1 : 3 ]
35+ hdims = [16 , 16 ]
36+ nlayers = 3
14637
147- flow = create_flow (Ls, q0)
38+ # use NormalizingFlows.realnvp to create a RealNVP flow
39+ flow = realnvp (q0, hdims, nlayers; paramtype= T)
14840flow_untrained = deepcopy (flow)
14941
15042
15143# #####################################
15244# start training
15345# #####################################
154- sample_per_iter = 64
46+ sample_per_iter = 16
15547
15648# callback function to log training progress
15749cb (iter, opt_stats, re, θ) = (sample_per_iter= sample_per_iter,ad= adtype)
50+ # TODO: this example currently breaks with AutoMooncake but works with AutoZygote; needs debugging
15851adtype = ADTypes. AutoMooncake (; config = Mooncake. Config ())
15952checkconv (iter, stat, re, θ, st) = stat. gradient_norm < one (T)/ 1000
16053flow_trained, stats, _ = train_flow (
0 commit comments