1
- # #################################
2
- # define affine coupling layer using Bijectors.jl interface
3
- # ################################
1
+ """
2
+ Default constructor of Affine Coupling flow layer
3
+
4
+ following the general architecture as Eq(3) in [^AD2025]
5
+
6
+ [^AD2024]: Agrawal, J., & Domke, J. (2025). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *AISTATS*
7
+ """
4
8
struct AffineCoupling <: Bijectors.Bijector
5
9
dim:: Int
6
10
mask:: Bijectors.PartitionMask
12
16
@functor AffineCoupling (s, t)
13
17
14
18
function AffineCoupling (
15
- dim:: Int , # dimension of input
16
- hdims:: Int , # dimension of hidden units for s and t
17
- mask_idx:: AbstractVector , # index of dimensione that one wants to apply transformations on
18
- )
19
- cdims = length (mask_idx) # dimension of parts used to construct coupling law
20
- s = mlp3 (cdims, hdims, cdims)
21
- t = mlp3 (cdims, hdims, cdims)
19
+ dim:: Int , # dimension of the problem
20
+ hdims:: AbstractVector{Int} , # dimension of hidden units for s and t
21
+ mask_idx:: AbstractVector{Int} , # index of dimensione that one wants to apply transformations on
22
+ paramtype:: Type{T}
23
+ ) where {T<: AbstractFloat }
24
+ cdims = length (mask_idx) # dimension of parts used to construct coupling law
25
+ # for the scaling network s, add tanh to the output to ensure stability during training
26
+ s = fnn (dim- cdims, hdims, cdims; output_activation= Flux. tanh, paramtype= paramtype)
27
+ # no transfomration for the output of the translation network t
28
+ t = fnn (dim- cdims, hdims, cdims; output_activation= nothing , paramtype= paramtype)
22
29
mask = PartitionMask (dim, mask_idx)
23
30
return AffineCoupling (dim, mask, s, t)
24
31
end
25
32
26
33
function Bijectors. transform (af:: AffineCoupling , x:: AbstractVecOrMat )
27
34
# partition vector using 'af.mask::PartitionMask`
28
35
x₁, x₂, x₃ = partition (af. mask, x)
29
- y₁ = x₁ .* af. s (x₂) .+ af. t (x₂)
36
+ s_x₂ = af. s (x₂)
37
+ y₁ = x₁ .* exp .(s_x₂) .+ af. t (x₂)
30
38
return combine (af. mask, y₁, x₂, x₃)
31
39
end
32
40
36
44
37
45
function Bijectors. with_logabsdet_jacobian (af:: AffineCoupling , x:: AbstractVector )
38
46
x_1, x_2, x_3 = Bijectors. partition (af. mask, x)
39
- y_1 = af. s (x_2) .* x_1 .+ af. t (x_2)
40
- logjac = sum (log ∘ abs, af. s (x_2)) # this is a scalar
47
+ s_x2 = af. s (x_2)
48
+ y_1 = exp .(s_x2) .* x_1 .+ af. t (x_2)
49
+ logjac = sum (s_x2) # this is a scalar
41
50
return combine (af. mask, y_1, x_2, x_3), logjac
42
51
end
43
52
44
53
function Bijectors. with_logabsdet_jacobian (af:: AffineCoupling , x:: AbstractMatrix )
45
54
x_1, x_2, x_3 = Bijectors. partition (af. mask, x)
46
- y_1 = af. s (x_2) .* x_1 .+ af. t (x_2)
47
- logjac = sum (log ∘ abs, af. s (x_2); dims = 1 ) # 1 × size(x, 2)
55
+ s_x2 = af. s (x_2)
56
+ y_1 = exp .(s_x2) .* x_1 .+ af. t (x_2)
57
+ logjac = sum (s_x2; dims= 1 ) # 1 × size(x, 2)
48
58
return combine (af. mask, y_1, x_2, x_3), vec (logjac)
49
59
end
50
60
@@ -56,8 +66,9 @@ function Bijectors.with_logabsdet_jacobian(
56
66
# partition vector using `af.mask::PartitionMask`
57
67
y_1, y_2, y_3 = partition (af. mask, y)
58
68
# inverse transformation
59
- x_1 = (y_1 .- af. t (y_2)) ./ af. s (y_2)
60
- logjac = - sum (log ∘ abs, af. s (y_2))
69
+ s_y2 = af. s (y_2)
70
+ x_1 = (y_1 .- af. t (y_2)) .* exp .(- s_y2)
71
+ logjac = - sum (s_y2)
61
72
return combine (af. mask, x_1, y_2, y_3), logjac
62
73
end
63
74
@@ -68,8 +79,9 @@ function Bijectors.with_logabsdet_jacobian(
68
79
# partition vector using `af.mask::PartitionMask`
69
80
y_1, y_2, y_3 = partition (af. mask, y)
70
81
# inverse transformation
71
- x_1 = (y_1 .- af. t (y_2)) ./ af. s (y_2)
72
- logjac = - sum (log ∘ abs, af. s (y_2); dims = 1 )
82
+ s_y2 = af. s (y_2)
83
+ x_1 = (y_1 .- af. t (y_2)) .* exp .(- s_y2)
84
+ logjac = - sum (s_y2; dims= 1 )
73
85
return combine (af. mask, x_1, y_2, y_3), vec (logjac)
74
86
end
75
87
104
116
# return AffineCoupling(dim, mask, s, t)
105
117
# end
106
118
119
+ """
120
+ Default constructor of RealNVP flow layer
121
+
122
+ single layer of realnvp flow, which is a composition of 2 affine coupling transformations
123
+ with complementary masks
124
+ """
125
+ function RealNVP_layer (
126
+ dims:: Int , # dimension of problem
127
+ hdims:: AbstractVector{Int} ; # dimension of hidden units for s and t
128
+ paramtype:: Type{T} = Float64, # type of the parameters
129
+ ) where {T<: AbstractFloat }
130
+
131
+ mask_idx1 = 1 : 2 : dims
132
+ mask_idx2 = 2 : 2 : dims
133
+
134
+ # by default use the odd-even masking strategy
135
+ af1 = AffineCoupling (dims, hdims, mask_idx1, paramtype)
136
+ af2 = AffineCoupling (dims, hdims, mask_idx2, paramtype)
137
+
138
+ return reduce (∘ , (af1, af2))
139
+ end
140
+
141
+
142
+ function RealNVP (
143
+ dims:: Int , # dimension of problem
144
+ hdims:: AbstractVector{Int} , # dimension of hidden units for s and t
145
+ nlayers:: Int ; # number of RealNVP_layer
146
+ paramtype:: Type{T} = Float64, # type of the parameters
147
+ ) where {T<: AbstractFloat }
148
+
149
+ q0 = MvNormal (zeros (dims), I) # std Gaussian as the reference distribution
150
+ Ls = [RealNVP_layer (dims, hdims; paramtype= paramtype) for _ in 1 : nlayers]
151
+
152
+ create_flow (Ls, q0)
153
+ end
154
+
155
+ function RealNVP (dims: Int; paramtype:: Type{T} = Float64) where {T<: AbstractFloat }
156
+ # default RealNVP with 10 layers, each couplling function has 2 hidden layers with 32 units
157
+ return RealNVP (dims, [32 , 32 ], 10 ; paramtype= paramtype)
158
+ end
0 commit comments