
Commit e03213c

update realnvp default constructor
1 parent cd63565 commit e03213c

File tree

4 files changed: +99 −30 lines changed


src/NormalizingFlows.jl

Lines changed: 2 additions & 1 deletion
@@ -1,13 +1,14 @@
 module NormalizingFlows

 using ADTypes
-using Bijectors
 using Distributions
 using LinearAlgebra
 using Optimisers
 using ProgressMeter
 using Random
 using StatsBase
+using Bijectors
+using Bijectors: PartitionMask, Inverse, combine, partition
 using Functors
 import DifferentiationInterface as DI

src/flows/neuralspline.jl

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ end

 # define forward and inverse transformation
 """
-Build a rational quadratic spline from the nn output
+Build a rational quadratic spline (RQS) from the nn output
 Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline

 we just need to map the nn output to the knots and derivatives of the RQS

src/flows/realnvp.jl

Lines changed: 71 additions & 19 deletions
@@ -1,6 +1,10 @@
-##################################
-# define affine coupling layer using Bijectors.jl interface
-#################################
+"""
+Default constructor of the affine coupling flow layer,
+
+following the general architecture of Eq. (3) in [^AD2025].
+
+[^AD2025]: Agrawal, A., & Domke, J. (2025). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *AISTATS*.
+"""
 struct AffineCoupling <: Bijectors.Bijector
     dim::Int
     mask::Bijectors.PartitionMask
@@ -12,21 +16,25 @@ end
 @functor AffineCoupling (s, t)

 function AffineCoupling(
-    dim::Int, # dimension of input
-    hdims::Int, # dimension of hidden units for s and t
-    mask_idx::AbstractVector, # indices of the dimensions to apply the transformation to
-)
-    cdims = length(mask_idx) # dimension of parts used to construct coupling law
-    s = mlp3(cdims, hdims, cdims)
-    t = mlp3(cdims, hdims, cdims)
+    dim::Int, # dimension of the problem
+    hdims::AbstractVector{Int}, # dimensions of hidden units for s and t
+    mask_idx::AbstractVector{Int}, # indices of the dimensions to apply the transformation to
+    paramtype::Type{T},
+) where {T<:AbstractFloat}
+    cdims = length(mask_idx) # dimension of the parts used to construct the coupling law
+    # for the scaling network s, add tanh to the output to ensure stability during training
+    s = fnn(dim - cdims, hdims, cdims; output_activation=Flux.tanh, paramtype=paramtype)
+    # no transformation for the output of the translation network t
+    t = fnn(dim - cdims, hdims, cdims; output_activation=nothing, paramtype=paramtype)
     mask = PartitionMask(dim, mask_idx)
     return AffineCoupling(dim, mask, s, t)
 end

 function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
     # partition the input using `af.mask::PartitionMask`
     x₁, x₂, x₃ = partition(af.mask, x)
-    y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
+    s_x₂ = af.s(x₂)
+    y₁ = x₁ .* exp.(s_x₂) .+ af.t(x₂)
     return combine(af.mask, y₁, x₂, x₃)
 end

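For orientation, a minimal sketch of how the new constructor signature would be called inside the package (the dimension, hidden sizes, and mask indices below are illustrative, not part of the commit):

# a coupling layer on a 4-dimensional problem that transforms the odd
# coordinates (1 and 3) conditioned on the even ones (2 and 4)
layer = AffineCoupling(4, [16, 16], [1, 3], Float64)
y = Bijectors.transform(layer, randn(4))   # forward pass through the coupling layer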
@@ -36,15 +44,17 @@ end

 function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
     x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
-    y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
-    logjac = sum(log ∘ abs, af.s(x_2)) # this is a scalar
+    s_x2 = af.s(x_2)
+    y_1 = exp.(s_x2) .* x_1 .+ af.t(x_2)
+    logjac = sum(s_x2) # this is a scalar
     return combine(af.mask, y_1, x_2, x_3), logjac
 end

 function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
     x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
-    y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
-    logjac = sum(log ∘ abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
+    s_x2 = af.s(x_2)
+    y_1 = exp.(s_x2) .* x_1 .+ af.t(x_2)
+    logjac = sum(s_x2; dims=1) # 1 × size(x, 2)
     return combine(af.mask, y_1, x_2, x_3), vec(logjac)
 end

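The new log-Jacobian follows directly from the exp-parameterisation: since y₁ = x₁ .* exp.(s(x₂)) .+ t(x₂), the Jacobian of the transformed block is diagonal with entries exp.(s(x₂)), so log|det J| = sum(s(x₂)) and no log or abs is needed. A standalone sketch with toy values (not from the commit) confirming the identity:

using LinearAlgebra

s  = [0.3, -1.2]            # pretend output of the scaling network on x₂
t  = [0.5, 0.0]             # pretend output of the translation network on x₂
x₁ = [1.0, 2.0]
y₁ = x₁ .* exp.(s) .+ t
logjac = sum(s)             # equals logabsdet(Diagonal(exp.(s)))[1]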
@@ -56,8 +66,9 @@ function Bijectors.with_logabsdet_jacobian(
     # partition vector using `af.mask::PartitionMask`
     y_1, y_2, y_3 = partition(af.mask, y)
     # inverse transformation
-    x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
-    logjac = -sum(log ∘ abs, af.s(y_2))
+    s_y2 = af.s(y_2)
+    x_1 = (y_1 .- af.t(y_2)) .* exp.(-s_y2)
+    logjac = -sum(s_y2)
     return combine(af.mask, x_1, y_2, y_3), logjac
 end

@@ -68,8 +79,9 @@ function Bijectors.with_logabsdet_jacobian(
     # partition vector using `af.mask::PartitionMask`
     y_1, y_2, y_3 = partition(af.mask, y)
     # inverse transformation
-    x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
-    logjac = -sum(log ∘ abs, af.s(y_2); dims = 1)
+    s_y2 = af.s(y_2)
+    x_1 = (y_1 .- af.t(y_2)) .* exp.(-s_y2)
+    logjac = -sum(s_y2; dims=1)
     return combine(af.mask, x_1, y_2, y_3), vec(logjac)
 end

@@ -104,3 +116,43 @@ end
 #     return AffineCoupling(dim, mask, s, t)
 # end

+"""
+Default constructor of a RealNVP flow layer:
+
+a single layer of the RealNVP flow, which is the composition of 2 affine coupling
+transformations with complementary masks.
+"""
+function RealNVP_layer(
+    dims::Int, # dimension of the problem
+    hdims::AbstractVector{Int}; # dimensions of hidden units for s and t
+    paramtype::Type{T} = Float64, # type of the parameters
+) where {T<:AbstractFloat}
+
+    mask_idx1 = 1:2:dims
+    mask_idx2 = 2:2:dims
+
+    # by default use the odd-even masking strategy
+    af1 = AffineCoupling(dims, hdims, mask_idx1, paramtype)
+    af2 = AffineCoupling(dims, hdims, mask_idx2, paramtype)
+
+    return reduce(∘, (af1, af2))
+end
+
+
+function RealNVP(
+    dims::Int, # dimension of the problem
+    hdims::AbstractVector{Int}, # dimensions of hidden units for s and t
+    nlayers::Int; # number of RealNVP_layer
+    paramtype::Type{T} = Float64, # type of the parameters
+) where {T<:AbstractFloat}
+
+    q0 = MvNormal(zeros(dims), I) # std Gaussian as the reference distribution
+    Ls = [RealNVP_layer(dims, hdims; paramtype=paramtype) for _ in 1:nlayers]
+
+    return create_flow(Ls, q0)
+end
+
+function RealNVP(dims::Int; paramtype::Type{T} = Float64) where {T<:AbstractFloat}
+    # default RealNVP with 10 layers; each coupling function has 2 hidden layers with 32 units
+    return RealNVP(dims, [32, 32], 10; paramtype=paramtype)
+end

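To show how the new default constructors compose, here is a sketch based on the code added above (sampling assumes create_flow returns a Bijectors-style transformed distribution, as elsewhere in the package):

# 5-dimensional RealNVP with the defaults: 10 layers, each coupling
# network with 2 hidden layers of 32 units, Float64 parameters
flow = RealNVP(5)

# the same flow spelled out explicitly
flow = RealNVP(5, [32, 32], 10; paramtype=Float64)

xs = rand(flow, 100)   # 5 × 100 matrix of samples, if flow is a transformed distribution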
src/flows/utils.jl

Lines changed: 25 additions & 9 deletions
@@ -6,21 +6,29 @@ using Flux

 A simple wrapper for a 3 layer dense MLP
 """
-function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu)
-    return Chain(
+function mlp3(
+    input_dim::Int,
+    hidden_dims::Int,
+    output_dim::Int;
+    activation=Flux.leakyrelu,
+    paramtype::Type{T} = Float64
+) where {T<:AbstractFloat}
+    m = Chain(
         Flux.Dense(input_dim, hidden_dims, activation),
         Flux.Dense(hidden_dims, hidden_dims, activation),
         Flux.Dense(hidden_dims, output_dim),
     )
+    return Flux._paramtype(paramtype, m)
 end

 """
     fnn(
         input_dim::Int,
-        hidden_dims::AbstractVector{<:Int},
+        hidden_dims::AbstractVector{Int},
         output_dim::Int;
         inlayer_activation=Flux.leakyrelu,
-        output_activation=Flux.tanh,
+        output_activation=nothing,
+        paramtype::Type{T} = Float64,
     )

 Create a fully connected neural network (FNN).
@@ -31,17 +39,19 @@ Create a fully connected neural network (FNN).
 - `output_dim::Int`: The dimension of the output layer.
 - `inlayer_activation`: The activation function for the hidden layers. Defaults to `Flux.leakyrelu`.
 - `output_activation`: The activation function for the output layer. Defaults to `Flux.tanh`.
+- `paramtype::Type{T} = Float64`: The type of the parameters in the network; defaults to `Float64`.

 # Returns
 - A `Flux.Chain` representing the FNN.
 """
 function fnn(
     input_dim::Int,
-    hidden_dims::AbstractVector{<:Int},
+    hidden_dims::AbstractVector{Int},
     output_dim::Int;
     inlayer_activation=Flux.leakyrelu,
-    output_activation=Flux.tanh,
-)
+    output_activation=nothing,
+    paramtype::Type{T} = Float64,
+) where {T<:AbstractFloat}
     # Create a chain of dense layers
     # First layer
     layers = Any[Flux.Dense(input_dim, hidden_dims[1], inlayer_activation)]
@@ -55,8 +65,14 @@ function fnn(
     end

     # Output layer
-    push!(layers, Flux.Dense(hidden_dims[end], output_dim, output_activation))
-    return Chain(layers...)
+    if output_activation === nothing
+        push!(layers, Flux.Dense(hidden_dims[end], output_dim))
+    else
+        push!(layers, Flux.Dense(hidden_dims[end], output_dim, output_activation))
+    end
+
+    m = Chain(layers...)
+    return Flux._paramtype(paramtype, m)
 end

 function create_flow(Ls, q₀)

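A quick sketch of how the updated fnn helper is now called by the coupling constructor (the sizes and the Float32 choice are illustrative only):

using Flux

# scaling network: tanh-squashed output for training stability, Float32 parameters;
# the translation network would instead pass output_activation=nothing
s_net = fnn(3, [32, 32], 3; output_activation=Flux.tanh, paramtype=Float32)
h = s_net(randn(Float32, 3))   # 3-element Float32 output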