Skip to content

Commit 8db828f

Browse files
committed
make realnvp and nsf layers as part of the pkg
1 parent e0eb8ab commit 8db828f

File tree

5 files changed

+233
-0
lines changed

5 files changed

+233
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
88
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
11+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
12+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1113
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1214
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1315
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -27,6 +29,8 @@ CUDA = "5"
2729
DifferentiationInterface = "0.6, 0.7"
2830
Distributions = "0.25"
2931
DocStringExtensions = "0.9"
32+
Flux = "0.16"
33+
Functors = "0.5.2"
3034
Optimisers = "0.2.16, 0.3, 0.4"
3135
ProgressMeter = "1.0.0"
3236
StatsBase = "0.33, 0.34"

src/NormalizingFlows.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,11 @@ function _device_specific_rand(
123123
return Random.rand(rng, td, n)
124124
end
125125

126+
127+
# interface of contructing common flow layers
128+
include("flows/utils.jl")
129+
include("flows/realnvp.jl")
130+
include("flows/neuralspline.jl")
131+
132+
126133
end

src/flows/neuralspline.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
##################################
2+
# define neural spline layer using Bijectors.jl interface
3+
#################################
4+
"""
5+
Neural Rational quadratic Spline layer
6+
7+
# References
8+
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
9+
"""
10+
struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector
11+
dim::Int # dimension of input
12+
K::Int # number of knots
13+
n_dims_transferred::Int # number of dimensions that are transformed
14+
nn::A # networks that parmaterize the knots and derivatives
15+
B::T # bound of the knots
16+
mask::Bijectors.PartitionMask
17+
end
18+
19+
function NeuralSplineLayer(
20+
dim::T1, # dimension of input
21+
hdims::T1, # dimension of hidden units for s and t
22+
K::T1, # number of knots
23+
B::T2, # bound of the knots
24+
mask_idx::AbstractVector{<:Int}, # index of dimensione that one wants to apply transformations on
25+
) where {T1<:Int,T2<:Real}
26+
num_of_transformed_dims = length(mask_idx)
27+
input_dims = dim - num_of_transformed_dims
28+
29+
# output dim of the NN
30+
output_dims = (3K - 1)*num_of_transformed_dims
31+
# one big mlp that outputs all the knots and derivatives for all the transformed dimensions
32+
nn = mlp3(input_dims, hdims, output_dims)
33+
34+
mask = Bijectors.PartitionMask(dim, mask_idx)
35+
return NeuralSplineLayer(dim, K, num_of_transformed_dims, nn, B, mask)
36+
end
37+
38+
@functor NeuralSplineLayer (nn,)
39+
40+
# define forward and inverse transformation
41+
"""
42+
Build a rational quadratic spline from the nn output
43+
Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline
44+
45+
we just need to map the nn output to the knots and derivatives of the RQS
46+
"""
47+
function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
48+
K, B = nsl.K, nsl.B
49+
nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :)
50+
ws = @view nnoutput[:, 1:K]
51+
hs = @view nnoutput[:, (K + 1):(2K)]
52+
ds = @view nnoutput[:, (2K + 1):(3K - 1)]
53+
return Bijectors.RationalQuadraticSpline(ws, hs, ds, B)
54+
end
55+
56+
function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector)
57+
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
58+
# instantiate rqs knots and derivatives
59+
rqs = instantiate_rqs(nsl, x_2)
60+
y_1 = Bijectors.transform(rqs, x_1)
61+
return Bijectors.combine(nsl.mask, y_1, x_2, x_3)
62+
end
63+
64+
function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
65+
nsl = insl.orig
66+
y1, y2, y3 = partition(nsl.mask, y)
67+
rqs = instantiate_rqs(nsl, y2)
68+
x1 = Bijectors.transform(Inverse(rqs), y1)
69+
return Bijectors.combine(nsl.mask, x1, y2, y3)
70+
end
71+
72+
function (nsl::NeuralSplineLayer)(x::AbstractVector)
73+
return Bijectors.transform(nsl, x)
74+
end
75+
76+
# define logabsdetjac
77+
function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector)
78+
x_1, x_2, _ = Bijectors.partition(nsl.mask, x)
79+
rqs = instantiate_rqs(nsl, x_2)
80+
logjac = logabsdetjac(rqs, x_1)
81+
return logjac
82+
end
83+
84+
function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
85+
nsl = insl.orig
86+
y1, y2, _ = partition(nsl.mask, y)
87+
rqs = instantiate_rqs(nsl, y2)
88+
logjac = logabsdetjac(Inverse(rqs), y1)
89+
return logjac
90+
end
91+
92+
function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
93+
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
94+
rqs = instantiate_rqs(nsl, x_2)
95+
y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
96+
return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
97+
end
98+

src/flows/realnvp.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
##################################
2+
# define affine coupling layer using Bijectors.jl interface
3+
#################################
4+
struct AffineCoupling <: Bijectors.Bijector
5+
dim::Int
6+
mask::Bijectors.PartitionMask
7+
s::Flux.Chain
8+
t::Flux.Chain
9+
end
10+
11+
# let params track field s and t
12+
@functor AffineCoupling (s, t)
13+
14+
function AffineCoupling(
15+
dim::Int, # dimension of input
16+
hdims::Int, # dimension of hidden units for s and t
17+
mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
18+
)
19+
cdims = length(mask_idx) # dimension of parts used to construct coupling law
20+
s = mlp3(cdims, hdims, cdims)
21+
t = mlp3(cdims, hdims, cdims)
22+
mask = PartitionMask(dim, mask_idx)
23+
return AffineCoupling(dim, mask, s, t)
24+
end
25+
26+
function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
27+
# partition vector using 'af.mask::PartitionMask`
28+
x₁, x₂, x₃ = partition(af.mask, x)
29+
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
30+
return combine(af.mask, y₁, x₂, x₃)
31+
end
32+
33+
function (af::AffineCoupling)(x::AbstractArray)
34+
return transform(af, x)
35+
end
36+
37+
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
38+
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
39+
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
40+
logjac = sum(log abs, af.s(x_2)) # this is a scalar
41+
return combine(af.mask, y_1, x_2, x_3), logjac
42+
end
43+
44+
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
45+
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
46+
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
47+
logjac = sum(log abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
48+
return combine(af.mask, y_1, x_2, x_3), vec(logjac)
49+
end
50+
51+
52+
function Bijectors.with_logabsdet_jacobian(
53+
iaf::Inverse{<:AffineCoupling}, y::AbstractVector
54+
)
55+
af = iaf.orig
56+
# partition vector using `af.mask::PartitionMask`
57+
y_1, y_2, y_3 = partition(af.mask, y)
58+
# inverse transformation
59+
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
60+
logjac = -sum(log abs, af.s(y_2))
61+
return combine(af.mask, x_1, y_2, y_3), logjac
62+
end
63+
64+
function Bijectors.with_logabsdet_jacobian(
65+
iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
66+
)
67+
af = iaf.orig
68+
# partition vector using `af.mask::PartitionMask`
69+
y_1, y_2, y_3 = partition(af.mask, y)
70+
# inverse transformation
71+
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
72+
logjac = -sum(log abs, af.s(y_2); dims = 1)
73+
return combine(af.mask, x_1, y_2, y_3), vec(logjac)
74+
end
75+
76+
###################
77+
# an equivalent definition of AffineCoupling using Bijectors.Coupling
78+
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
79+
###################
80+
81+
# struct AffineCoupling <: Bijectors.Bijector
82+
# dim::Int
83+
# mask::Bijectors.PartitionMask
84+
# s::Flux.Chain
85+
# t::Flux.Chain
86+
# end
87+
88+
# # let params track field s and t
89+
# @functor AffineCoupling (s, t)
90+
91+
# function AffineCoupling(dim, mask, s, t)
92+
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
93+
# end
94+
95+
# function AffineCoupling(
96+
# dim::Int, # dimension of input
97+
# hdims::Int, # dimension of hidden units for s and t
98+
# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
99+
# )
100+
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
101+
# s = mlp3(cdims, hdims, cdims)
102+
# t = mlp3(cdims, hdims, cdims)
103+
# mask = PartitionMask(dim, mask_idx)
104+
# return AffineCoupling(dim, mask, s, t)
105+
# end
106+

src/flows/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using Bijectors: transformed
2+
using Flux
3+
4+
"""
5+
A simple wrapper for a 3 layer dense MLP
6+
"""
7+
function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu)
8+
return Chain(
9+
Flux.Dense(input_dim, hidden_dims, activation),
10+
Flux.Dense(hidden_dims, hidden_dims, activation),
11+
Flux.Dense(hidden_dims, output_dim),
12+
)
13+
end
14+
15+
function create_flow(Ls, q₀)
16+
ts = reduce(, Ls)
17+
return transformed(q₀, ts)
18+
end

0 commit comments

Comments
 (0)