@@ -11,114 +11,6 @@ using NormalizingFlows
11
11
include (" SyntheticTargets.jl" )
12
12
include (" utils.jl" )
13
13
14
- # #################################
15
- # define affine coupling layer using Bijectors.jl interface
16
- # ################################
17
- struct AffineCoupling <: Bijectors.Bijector
18
- dim:: Int
19
- mask:: Bijectors.PartitionMask
20
- s:: Flux.Chain
21
- t:: Flux.Chain
22
- end
23
-
24
- # let params track field s and t
25
- @functor AffineCoupling (s, t)
26
-
27
- function AffineCoupling (
28
- dim:: Int , # dimension of input
29
- hdims:: Int , # dimension of hidden units for s and t
30
- mask_idx:: AbstractVector , # indices of the dimensions that the transformation is applied to
31
- )
32
- cdims = length (mask_idx) # dimension of parts used to construct coupling law
33
- s = mlp3 (cdims, hdims, cdims)
34
- t = mlp3 (cdims, hdims, cdims)
35
- mask = PartitionMask (dim, mask_idx)
36
- return AffineCoupling (dim, mask, s, t)
37
- end
38
-
39
- function Bijectors. transform (af:: AffineCoupling , x:: AbstractVecOrMat )
40
- # partition the input using `af.mask::PartitionMask`
41
- x₁, x₂, x₃ = partition (af. mask, x)
42
- y₁ = x₁ .* af. s (x₂) .+ af. t (x₂)
43
- return combine (af. mask, y₁, x₂, x₃)
44
- end
45
-
46
- function (af:: AffineCoupling )(x:: AbstractArray )
47
- return transform (af, x)
48
- end
49
-
50
- function Bijectors. with_logabsdet_jacobian (af:: AffineCoupling , x:: AbstractVector )
51
- x_1, x_2, x_3 = Bijectors. partition (af. mask, x)
52
- y_1 = af. s (x_2) .* x_1 .+ af. t (x_2)
53
- logjac = sum (log ∘ abs, af. s (x_2)) # this is a scalar
54
- return combine (af. mask, y_1, x_2, x_3), logjac
55
- end
56
-
57
- function Bijectors. with_logabsdet_jacobian (af:: AffineCoupling , x:: AbstractMatrix )
58
- x_1, x_2, x_3 = Bijectors. partition (af. mask, x)
59
- y_1 = af. s (x_2) .* x_1 .+ af. t (x_2)
60
- logjac = sum (log ∘ abs, af. s (x_2); dims = 1 ) # 1 × size(x, 2)
61
- return combine (af. mask, y_1, x_2, x_3), vec (logjac)
62
- end
63
-
64
-
65
- function Bijectors. with_logabsdet_jacobian (
66
- iaf:: Inverse{<:AffineCoupling} , y:: AbstractVector
67
- )
68
- af = iaf. orig
69
- # partition vector using `af.mask::PartitionMask`
70
- y_1, y_2, y_3 = partition (af. mask, y)
71
- # inverse transformation
72
- x_1 = (y_1 .- af. t (y_2)) ./ af. s (y_2)
73
- logjac = - sum (log ∘ abs, af. s (y_2))
74
- return combine (af. mask, x_1, y_2, y_3), logjac
75
- end
76
-
77
- function Bijectors. with_logabsdet_jacobian (
78
- iaf:: Inverse{<:AffineCoupling} , y:: AbstractMatrix
79
- )
80
- af = iaf. orig
81
- # partition the input matrix using `af.mask::PartitionMask`
82
- y_1, y_2, y_3 = partition (af. mask, y)
83
- # inverse transformation
84
- x_1 = (y_1 .- af. t (y_2)) ./ af. s (y_2)
85
- logjac = - sum (log ∘ abs, af. s (y_2); dims = 1 )
86
- return combine (af. mask, x_1, y_2, y_3), vec (logjac)
87
- end
88
-
89
- # ##################
90
- # an equivalent definition of AffineCoupling using Bijectors.Coupling
91
- # (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
92
- # ##################
93
-
94
- # struct AffineCoupling <: Bijectors.Bijector
95
- # dim::Int
96
- # mask::Bijectors.PartitionMask
97
- # s::Flux.Chain
98
- # t::Flux.Chain
99
- # end
100
-
101
- # # let params track field s and t
102
- # @functor AffineCoupling (s, t)
103
-
104
- # function AffineCoupling(dim, mask, s, t)
105
- # return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
106
- # end
107
-
108
- # function AffineCoupling(
109
- # dim::Int, # dimension of input
110
- # hdims::Int, # dimension of hidden units for s and t
111
- # mask_idx::AbstractVector, # indices of the dimensions that the transformation is applied to
112
- # )
113
- # cdims = length(mask_idx) # dimension of parts used to construct coupling law
114
- # s = mlp3(cdims, hdims, cdims)
115
- # t = mlp3(cdims, hdims, cdims)
116
- # mask = PartitionMask(dim, mask_idx)
117
- # return AffineCoupling(dim, mask, s, t)
118
- # end
119
-
120
-
121
-
122
14
# #################################
123
15
# start demo
124
16
# ################################
@@ -132,29 +24,30 @@ T = Float32
132
24
target = Banana (2 , 1.0f0 , 100.0f0 )
133
25
logp = Base. Fix1 (logpdf, target)
134
26
27
+
135
28
# #####################################
136
29
# learn the target using Affine coupling flow
137
30
# #####################################
138
31
@leaf MvNormal
139
- q0 = MvNormal (zeros (T, 2 ), ones (T, 2 ) )
32
+ q0 = MvNormal (zeros (T, 2 ), I )
140
33
141
34
d = 2
142
- hdims = 32
143
-
144
- # alternating the coupling layers
145
- Ls = [AffineCoupling (d, hdims, [1 ]) ∘ AffineCoupling (d, hdims, [2 ]) for i in 1 : 3 ]
35
+ hdims = [16 , 16 ]
36
+ nlayers = 3
146
37
147
- flow = create_flow (Ls, q0)
38
+ # use NormalizingFlows.realnvp to create a RealNVP flow
39
+ flow = realnvp (q0, hdims, nlayers; paramtype= T)
148
40
flow_untrained = deepcopy (flow)
149
41
150
42
151
43
# #####################################
152
44
# start training
153
45
# #####################################
154
- sample_per_iter = 64
46
+ sample_per_iter = 16
155
47
156
48
# callback function to log training progress
157
49
cb (iter, opt_stats, re, θ) = (sample_per_iter= sample_per_iter,ad= adtype)
50
+ # TODO: this example currently breaks with AutoMooncake but works with AutoZygote; needs debugging
158
51
adtype = ADTypes. AutoMooncake (; config = Mooncake. Config ())
159
52
checkconv (iter, stat, re, θ, st) = stat. gradient_norm < one (T)/ 1000
160
53
flow_trained, stats, _ = train_flow (
0 commit comments