@@ -3,7 +3,7 @@ Default constructor of Affine Coupling flow layer
3
3
4
4
following the general architecture as Eq(3) in [^AD2025]
5
5
6
- [^AD2024 ]: Agrawal, J., & Domke, J. (2025). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *AISTATS*
6
+ [^AD2025 ]: Agrawal, J., & Domke, J. (2025). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *AISTATS*
7
7
"""
8
8
struct AffineCoupling <: Bijectors.Bijector
9
9
dim:: Int
@@ -117,10 +117,21 @@ end
117
117
# end
118
118
119
119
"""
120
- Default constructor of RealNVP flow layer
120
+ RealNVP_layer(dims, hdims; paramtype = Float64)
121
121
122
- single layer of realnvp flow, which is a composition of 2 affine coupling transformations
123
- with complementary masks
122
+ Default constructor of single layer of realnvp flow,
123
+ which is a composition of 2 affine coupling transformations with complementary masks.
124
+ The masking strategy is odd-even masking.
125
+
126
+ # Arguments
127
+ - `dims::Int`: dimension of the problem
128
+ - `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
129
+
130
+ # Keyword Arguments
131
+ - `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
132
+
133
+ # Returns
134
+ - A `Bijectors.Bijector` representing the RealNVP layer.
124
135
"""
125
136
function RealNVP_layer (
126
137
dims:: Int , # dimension of problem
@@ -134,25 +145,50 @@ function RealNVP_layer(
134
145
# by default use the odd-even masking strategy
135
146
af1 = AffineCoupling (dims, hdims, mask_idx1, paramtype)
136
147
af2 = AffineCoupling (dims, hdims, mask_idx2, paramtype)
137
-
138
148
return reduce (∘ , (af1, af2))
139
149
end
140
150
151
+ """
152
+ realnvp(q0, dims, hdims, nlayers; paramtype = Float64)
141
153
142
- function RealNVP (
143
- dims:: Int , # dimension of problem
154
+ Default constructor of RealNVP flow, which is a composition of `nlayers` RealNVP_layer.
155
+ # Arguments
156
+ - `q0::Distribution{Continuous, Multivariate}`: reference distribution, e.g. `MvNormal(zeros(dims), I)`
157
+ - `dims::Int`: dimension of problem
158
+ - `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
159
+ - `nlayers::Int`: number of RealNVP_layer
160
+ # Keyword Arguments
161
+ - `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
162
+
163
+ # Returns
164
+ - A `Bijectors.MultivariateTransformed` representing the RealNVP flow.
165
+
166
+ """
167
+
168
+ function realnvp (
169
+ q0:: Distribution{Continuous, Multivariate} ,
144
170
hdims:: AbstractVector{Int} , # dimension of hidden units for s and t
145
171
nlayers:: Int ; # number of RealNVP_layer
146
172
paramtype:: Type{T} = Float64, # type of the parameters
147
173
) where {T<: AbstractFloat }
148
174
149
- q0 = MvNormal (zeros (dims), I) # std Gaussian as the reference distribution
150
- Ls = [RealNVP_layer (dims, hdims; paramtype= paramtype) for _ in 1 : nlayers]
151
-
175
+ dims = length (q0) # dimension of the reference distribution == dim of the problem
176
+ Ls = [RealNVP_layer (dims, hdims; paramtype= paramtype) for _ in 1 : nlayers]
152
177
create_flow (Ls, q0)
153
178
end
154
179
155
- function RealNVP (dims:: Int ; paramtype:: Type{T} = Float64) where {T<: AbstractFloat }
156
- # default RealNVP with 10 layers, each couplling function has 2 hidden layers with 32 units
157
- return RealNVP (dims, [32 , 32 ], 10 ; paramtype= paramtype)
158
- end
180
+ """
181
+ realnvp(q0; paramtype = Float64)
182
+
183
+ Default constructor of RealNVP with 10 layers,
184
+ each coupling function has 2 hidden layers with 32 units.
185
+ Following the general architecture as in [^ASD2020] (see Apdx. E).
186
+
187
+
188
+ [^ASD2020]: Agrawal, A., & Sheldon, D., & Domke, J. (2020).
189
+ Advances in Black-Box VI: Normalizing Flows, Importance Weighting, and Optimization.
190
+ In *NeurIPS*.
191
+ """
192
+ realnvp (q0; paramtype:: Type{T} = Float64) where {T<: AbstractFloat } = RealNVP (
193
+ q0, [32 , 32 ], 10 ; paramtype= paramtype
194
+ )
0 commit comments