Commit 4caae49

add nsf interface
1 parent 81000b9 commit 4caae49

5 files changed (+115, -39 lines)

example/Project.toml

Lines changed: 3 additions & 0 deletions
@@ -2,16 +2,19 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

example/demo_RealNVP.jl

Lines changed: 4 additions & 6 deletions
@@ -5,7 +5,7 @@ using Bijectors: partition, combine, PartitionMask
 using Random, Distributions, LinearAlgebra
 using Functors
 using Optimisers, ADTypes
-using Mooncake
+using Mooncake, Zygote
 using NormalizingFlows

 include("SyntheticTargets.jl")
@@ -47,18 +47,16 @@ sample_per_iter = 16

 # callback function to log training progress
 cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
-# TODO: now using AutoMooncake the example broke, but AutoZygote works, need to debug
-adtype = ADTypes.AutoMooncake(; config = nothing)
-# adtype = ADTypes.AutoZygote()
+adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())

 checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
 flow_trained, stats, _ = train_flow(
     rng,
-    elbo_batch, # using elbo_batch instead of elbo achieves 4-5 times speedup
+    elbo, # elbo_batch achieves a 4-5x speedup over elbo
     flow,
     logp,
     sample_per_iter;
-    max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
+    max_iters=10, # increase (e.g., to 50_000) for better results
     optimiser=Optimisers.Adam(5e-4),
     ADbackend=adtype,
     show_progress=true,
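
Aside (not part of the commit): since this change resolves the AutoMooncake TODO, the backend can be sanity-checked in isolation by differentiating a toy objective through DifferentiationInterface, the same pattern as the commented scratch code at the end of the NSF demo below. A minimal sketch; `toy_loss` and `x0` are illustrative names, not part of the package.

    # Sketch: smoke-test the Mooncake AD backend before handing it to train_flow.
    using ADTypes, Mooncake
    import DifferentiationInterface as DI

    adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())

    toy_loss(x) = sum(abs2, x .- 1)   # simple stand-in for the ELBO objective
    x0 = randn(2)
    val, g = DI.value_and_gradient(toy_loss, adtype, x0)
    @assert g ≈ 2 .* (x0 .- 1)        # analytic gradient of sum((x .- 1).^2)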

example/demo_neural_spline_flow.jl

Lines changed: 35 additions & 13 deletions
@@ -28,18 +28,9 @@ logp = Base.Fix1(logpdf, target)
 # learn the target using Affine coupling flow
 ######################################
 @leaf MvNormal
-q0 = MvNormal(zeros(T, 2), ones(T, 2))
-
-d = 2
-hdims = 64
-K = 10
-B = 30
-Ls = [
-    NeuralSplineLayer(d, hdims, K, B, [1]) ∘ NeuralSplineLayer(d, hdims, K, B, [2]) for
-    i in 1:3
-]
-
-flow = create_flow(Ls, q0)
+q0 = MvNormal(zeros(T, 2), I)
+
+flow = nsf(q0; paramtype=Float32)
 flow_untrained = deepcopy(flow)
@@ -50,7 +41,6 @@ sample_per_iter = 64

 # callback function to log training progress
 cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
-# TODO: now using AutoMooncake the example broke, but AutoZygote works, need to debug
 adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
 checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
 flow_trained, stats, _ = train_flow(
flow_trained, stats, _ = train_flow(
@@ -73,3 +63,35 @@ losses = map(x -> x.loss, stats)
 ######################################
 plot(losses; label="Loss", linewidth=2) # plot the loss
 compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
+
+
+# using MonotonicSplines, Plots, InverseFunctions, ChangesOfVariables
+
+# f = rand(RQSpline)
+# f.pX, f.pY, f.dYdX
+
+# plot(f, xlims = (-6, 6)); plot!(inverse(f), xlims = (-6, 6))
+
+# x = 1.2
+# y = f(x)
+# with_logabsdet_jacobian(f, x)
+# inverse(f)(y)
+# with_logabsdet_jacobian(inverse(f), y)
+
+# # test auto grad
+# function loss(x)
+#     y, laj = MonotonicSplines.rqs_forward(x, f.pX, f.pY, f.dYdX)
+#     return laj + 0.5 * sum((y .- 1).^2)
+# end
+
+# xx = rand()
+# val, g = DifferentiationInterface.value_and_gradient(loss, adtype, xx)
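
For reference, a minimal sketch (not part of the commit) of the full `nsf` signature introduced here, `nsf(q0, hdims, K, B, nlayers; paramtype)`. The hyperparameter values below are illustrative, and `B` must share the element type requested via `paramtype` (hence `30f0` with `Float32`):

    using Distributions, LinearAlgebra
    using NormalizingFlows

    q0 = MvNormal(zeros(Float32, 2), I)
    flow = nsf(q0, [64, 64], 10, 30f0, 5; paramtype=Float32)  # 5 NSF layers

    # the demo's one-liner uses the defaults: hdims=[32, 32], K=10, B=30, 10 layers
    flow_default = nsf(q0; paramtype=Float32)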

src/NormalizingFlows.jl

Lines changed: 2 additions & 2 deletions
@@ -132,8 +132,8 @@ include("flows/realnvp.jl")
 include("flows/neuralspline.jl")

 export create_flow
-export RealNVP_layer, realnvp, AffineCoupling
-export NeuralSplineLayer
+export AffineCoupling, RealNVP_layer, realnvp
+export NeuralSplineCoupling, NSF_layer, nsf

 end

src/flows/neuralspline.jl

Lines changed: 71 additions & 18 deletions
@@ -1,49 +1,48 @@
-##################################
-# define neural spline layer using Bijectors.jl interface
-#################################
 """
 Neural rational quadratic spline layer

 # References
 [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
 """
-struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector
+struct NeuralSplineCoupling{T,A<:Flux.Chain} <: Bijectors.Bijector
     dim::Int # dimension of input
     K::Int # number of knots
     n_dims_transferred::Int # number of dimensions that are transformed
-    nn::A # networks that parmaterize the knots and derivatives
     B::T # bound of the knots
+    nn::A # networks that parameterize the knots and derivatives
     mask::Bijectors.PartitionMask
 end

-function NeuralSplineLayer(
+function NeuralSplineCoupling(
     dim::T1, # dimension of input
-    hdims::T1, # dimension of hidden units for s and t
+    hdims::AbstractVector{T1}, # dimensions of the hidden units of the nn
     K::T1, # number of knots
     B::T2, # bound of the knots
     mask_idx::AbstractVector{<:Int}, # indices of the dimensions to apply the transformation to
-) where {T1<:Int,T2<:Real}
+    paramtype::Type{T3}, # type of the parameters, e.g., Float64 or Float32
+) where {T1<:Int,T2<:Real,T3<:AbstractFloat}
     num_of_transformed_dims = length(mask_idx)
     input_dims = dim - num_of_transformed_dims

     # output dim of the NN
     output_dims = (3K - 1)*num_of_transformed_dims
     # one big mlp that outputs all the knots and derivatives for all the transformed dimensions
-    nn = mlp3(input_dims, hdims, output_dims)
+    # todo: ensure type stability
+    nn = fnn(input_dims, hdims, output_dims; output_activation=nothing, paramtype=paramtype)

     mask = Bijectors.PartitionMask(dim, mask_idx)
-    return NeuralSplineLayer(dim, K, num_of_transformed_dims, nn, B, mask)
+    return NeuralSplineCoupling(dim, K, num_of_transformed_dims, B, nn, mask)
 end

-@functor NeuralSplineLayer (nn,)
+@functor NeuralSplineCoupling (nn,)

 """
 Build a rational quadratic spline (RQS) from the nn output.
 Bijectors.jl has implemented the inverse and logabsdetjac for the rational quadratic spline;
 we just need to map the nn output to the knots and derivatives of the RQS.
 """
-function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
+function instantiate_rqs(nsl::NeuralSplineCoupling, x::AbstractVector)
     K, B = nsl.K, nsl.B
     nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :)
     ws = @view nnoutput[:, 1:K]
@@ -52,46 +51,100 @@ function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
     return Bijectors.RationalQuadraticSpline(ws, hs, ds, B)
 end

-function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector)
+function Bijectors.transform(nsl::NeuralSplineCoupling, x::AbstractVector)
     x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
     # instantiate rqs knots and derivatives
     rqs = instantiate_rqs(nsl, x_2)
     y_1 = Bijectors.transform(rqs, x_1)
     return Bijectors.combine(nsl.mask, y_1, x_2, x_3)
 end

-function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
+function Bijectors.transform(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVector)
     nsl = insl.orig
     y1, y2, y3 = partition(nsl.mask, y)
     rqs = instantiate_rqs(nsl, y2)
     x1 = Bijectors.transform(Inverse(rqs), y1)
     return Bijectors.combine(nsl.mask, x1, y2, y3)
 end

-function (nsl::NeuralSplineLayer)(x::AbstractVector)
+function (nsl::NeuralSplineCoupling)(x::AbstractVector)
     return Bijectors.transform(nsl, x)
 end

 # define logabsdetjac
-function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector)
+function Bijectors.logabsdetjac(nsl::NeuralSplineCoupling, x::AbstractVector)
     x_1, x_2, _ = Bijectors.partition(nsl.mask, x)
     rqs = instantiate_rqs(nsl, x_2)
     logjac = logabsdetjac(rqs, x_1)
     return logjac
 end

-function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
+function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVector)
     nsl = insl.orig
     y1, y2, _ = partition(nsl.mask, y)
     rqs = instantiate_rqs(nsl, y2)
     logjac = logabsdetjac(Inverse(rqs), y1)
     return logjac
 end

-function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
+function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineCoupling, x::AbstractVector)
     x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
     rqs = instantiate_rqs(nsl, x_2)
     y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
     return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
 end

+"""
+    NSF_layer(dims, hdims, K, B; paramtype = Float64)
+
+Default constructor for a single layer of Neural Spline Flow (NSF),
+which is a composition of 2 neural spline coupling transformations with
+complementary (odd-even) masks.
+
+# Arguments
+- `dims::Int`: dimension of the problem
+- `hdims::AbstractVector{Int}`: dimensions of the hidden units of the nn
+- `K::Int`: number of knots
+- `B::AbstractFloat`: bound of the knots
+
+# Keyword Arguments
+- `paramtype::Type{T} = Float64`: type of the parameters
+
+# Returns
+- A `Bijectors.Bijector` representing the NSF layer.
+"""
+function NSF_layer(
+    dims::T1, # dimension of problem
+    hdims::AbstractVector{T1}, # dimensions of the hidden units of the nn
+    K::T1, # number of knots
+    B::T2; # bound of the knots
+    paramtype::Type{T2} = Float64, # type of the parameters
+) where {T1<:Int,T2<:AbstractFloat}
+
+    # by default use the odd-even masking strategy
+    mask_idx1 = 1:2:dims
+    mask_idx2 = 2:2:dims
+
+    nsf1 = NeuralSplineCoupling(dims, hdims, K, B, mask_idx1, paramtype)
+    nsf2 = NeuralSplineCoupling(dims, hdims, K, B, mask_idx2, paramtype)
+    return reduce(∘, (nsf1, nsf2))
+end
+
+function nsf(
+    q0::Distribution{Multivariate,Continuous},
+    hdims::AbstractVector{Int}, # dimensions of the hidden units of the nn
+    K::Int, # number of knots
+    B::T, # bound of the knots
+    nlayers::Int; # number of NSF_layer
+    paramtype::Type{T} = Float64, # type of the parameters
+) where {T<:AbstractFloat}
+
+    dims = length(q0) # dimension of the reference distribution == dim of the problem
+    Ls = [NSF_layer(dims, hdims, K, B; paramtype=paramtype) for _ in 1:nlayers]
+    return create_flow(Ls, q0)
+end
+
+nsf(q0; paramtype::Type{T} = Float64) where {T<:AbstractFloat} = nsf(
+    q0, [32, 32], 10, 30*one(T), 10; paramtype=paramtype
+)
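
As a side note, here is a small sketch (not in the commit) of the `(3K - 1)` output layout that `instantiate_rqs` slices per transformed coordinate: `K` unnormalized knot widths, `K` heights, and `K - 1` interior derivatives, which `Bijectors.RationalQuadraticSpline` normalizes internally. The random `raw` vector stands in for one coordinate's slice of the nn output.

    using Bijectors

    K, B = 10, 30.0
    raw = randn(3K - 1)              # one coordinate's slice of the nn output
    ws = raw[1:K]                    # unnormalized widths
    hs = raw[(K + 1):(2K)]           # unnormalized heights
    ds = raw[(2K + 1):(3K - 1)]      # unnormalized interior derivatives
    rqs = Bijectors.RationalQuadraticSpline(ws, hs, ds, B)

    y, logjac = Bijectors.with_logabsdet_jacobian(rqs, 1.2)  # forward pass + log|det J|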
