Skip to content

Commit 1c4118b

Browse files
committed
minor update of realnnvp constructor and add some doc
1 parent 34a964e commit 1c4118b

File tree

3 files changed

+52
-15
lines changed

3 files changed

+52
-15
lines changed

src/NormalizingFlows.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,7 @@ include("flows/utils.jl")
131131
include("flows/realnvp.jl")
132132
include("flows/neuralspline.jl")
133133

134+
export RealNVP_layer, realnvp
135+
134136

135137
end

src/flows/neuralspline.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ end
3737

3838
@functor NeuralSplineLayer (nn,)
3939

40-
# define forward and inverse transformation
4140
"""
4241
Build a rational quadratic spline (RQS) from the nn output
4342
Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline

src/flows/realnvp.jl

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Default constructor of Affine Coupling flow layer
33
44
following the general architecture as Eq(3) in [^AD2025]
55
6-
[^AD2024]: Agrawal, J., & Domke, J. (2025). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *AISTATS*
6+
[^AD2025]: Agrawal, J., & Domke, J. (2025). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *AISTATS*
77
"""
88
struct AffineCoupling <: Bijectors.Bijector
99
dim::Int
@@ -117,10 +117,21 @@ end
117117
# end
118118

119119
"""
120-
Default constructor of RealNVP flow layer
120+
RealNVP_layer(dims, hdims; paramtype = Float64)
121121
122-
single layer of realnvp flow, which is a composition of 2 affine coupling transformations
123-
with complementary masks
122+
Default constructor of single layer of realnvp flow,
123+
which is a composition of 2 affine coupling transformations with complementary masks.
124+
The masking strategy is odd-even masking.
125+
126+
# Arguments
127+
- `dims::Int`: dimension of the problem
128+
- `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
129+
130+
# Keyword Arguments
131+
- `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
132+
133+
# Returns
134+
- A `Bijectors.Bijector` representing the RealNVP layer.
124135
"""
125136
function RealNVP_layer(
126137
dims::Int, # dimension of problem
@@ -134,25 +145,50 @@ function RealNVP_layer(
134145
# by default use the odd-even masking strategy
135146
af1 = AffineCoupling(dims, hdims, mask_idx1, paramtype)
136147
af2 = AffineCoupling(dims, hdims, mask_idx2, paramtype)
137-
138148
return reduce(, (af1, af2))
139149
end
140150

151+
"""
152+
realnvp(q0, dims, hdims, nlayers; paramtype = Float64)
141153
142-
function RealNVP(
143-
dims::Int, # dimension of problem
154+
Default constructor of RealNVP flow, which is a composition of `nlayers` RealNVP_layer.
155+
# Arguments
156+
- `q0::Distribution{Continuous, Multivariate}`: reference distribution, e.g. `MvNormal(zeros(dims), I)`
157+
- `dims::Int`: dimension of problem
158+
- `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
159+
- `nlayers::Int`: number of RealNVP_layer
160+
# Keyword Arguments
161+
- `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
162+
163+
# Returns
164+
- A `Bijectors.MultivariateTransformed` representing the RealNVP flow.
165+
166+
"""
167+
168+
function realnvp(
169+
q0::Distribution{Continuous, Multivariate},
144170
hdims::AbstractVector{Int}, # dimension of hidden units for s and t
145171
nlayers::Int; # number of RealNVP_layer
146172
paramtype::Type{T} = Float64, # type of the parameters
147173
) where {T<:AbstractFloat}
148174

149-
q0 = MvNormal(zeros(dims), I) # std Gaussian as the reference distribution
150-
Ls = [RealNVP_layer(dims, hdims; paramtype=paramtype) for _ in 1:nlayers]
151-
175+
dims = length(q0) # dimension of the reference distribution == dim of the problem
176+
Ls = [RealNVP_layer(dims, hdims; paramtype=paramtype) for _ in 1:nlayers]
152177
create_flow(Ls, q0)
153178
end
154179

155-
function RealNVP(dims::Int; paramtype::Type{T} = Float64) where {T<:AbstractFloat}
156-
# default RealNVP with 10 layers, each couplling function has 2 hidden layers with 32 units
157-
return RealNVP(dims, [32, 32], 10; paramtype=paramtype)
158-
end
180+
"""
181+
realnvp(q0; paramtype = Float64)
182+
183+
Default constructor of RealNVP with 10 layers,
184+
each coupling function has 2 hidden layers with 32 units.
185+
Following the general architecture as in [^ASD2020] (see Apdx. E).
186+
187+
188+
[^ASD2020]: Agrawal, A., & Sheldon, D., & Domke, J. (2020).
189+
Advances in Black-Box VI: Normalizing Flows, Importance Weighting, and Optimization.
190+
In *NeurIPS*.
191+
"""
192+
realnvp(q0; paramtype::Type{T} = Float64) where {T<:AbstractFloat} = RealNVP(
193+
q0, [32, 32], 10; paramtype=paramtype
194+
)

0 commit comments

Comments
 (0)