
Commit c4128fa

add new nsf implementation and demo; much faster than the original nsf
1 parent 99a0fed · commit c4128fa

6 files changed: +195 −42 lines

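The commit message claims the new implementation is much faster than the original nsf, but the diff itself contains no benchmark. A minimal sketch of how one might check the claim (hypothetical, not part of this commit; it assumes BenchmarkTools.jl plus the `nsf` and `new_nsf` constructors exported below):

# hypothetical benchmark sketch -- not included in this commit
using NormalizingFlows, Distributions, LinearAlgebra, BenchmarkTools

T = Float32
q0 = MvNormal(zeros(T, 2), I)

flow_old = nsf(q0; paramtype=T)      # original Bijectors.jl-based NSF
flow_new = new_nsf(q0; paramtype=T)  # new MonotonicSplines.jl-based NSF

x = rand(q0)                  # a single draw from the reference
@btime logpdf($flow_old, $x)  # density evaluation through the old flow
@btime logpdf($flow_new, $x)  # density evaluation through the new flow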

.github/workflows/Examples.yml

Lines changed: 2 additions & 0 deletions
@@ -38,5 +38,7 @@ jobs:
 include("demo_RealNVP.jl");
 @info "Running neural spline flow demo";
 include("demo_neural_spline_flow.jl");
+@info "Running new neural spline flow demo";
+include("demo_new_nsf.jl");
 @info "Running Hamiltonian flow demo";
 include("demo_hamiltonian_flow.jl");'

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,6 +32,7 @@ Distributions = "0.25"
 DocStringExtensions = "0.9"
 Flux = "0.16"
 Functors = "0.5.2"
+MonotonicSplines = "0.3.3"
 Optimisers = "0.2.16, 0.3, 0.4"
 ProgressMeter = "1.0.0"
 StatsBase = "0.33, 0.34"

example/demo_neural_spline_flow.jl

Lines changed: 8 additions & 42 deletions
@@ -1,4 +1,3 @@
-using Flux
 using Bijectors
 using Bijectors: partition, combine, PartitionMask
 
@@ -19,21 +18,20 @@ rng = Random.default_rng()
 T = Float32
 
 ######################################
-# neals funnel target
+# a difficult banana target
 ######################################
-target = Funnel(2, 0.0f0, 9.0f0)
-logp = Base.Fix1(logpdf, target)
 
+target = Banana(2, one(T), 100one(T))
+logp = Base.Fix1(logpdf, target)
 ######################################
-# learn the target using Affine coupling flow
+# learn the target using Neural Spline Flow
 ######################################
 @leaf MvNormal
 q0 = MvNormal(zeros(T, 2), I)
 
-flow = nsf(q0; paramtype=Float32)
-flow_untrained = deepcopy(flow)
-
 
+flow = nsf(q0; paramtype=T)
+flow_untrained = deepcopy(flow)
 ######################################
 # start training
 ######################################
@@ -48,8 +46,8 @@ flow_trained, stats, _ = train_flow(
     flow,
     logp,
     sample_per_iter;
-    max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
-    optimiser=Optimisers.Adam(5e-5),
+    max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results
+    optimiser=Optimisers.Adam(1e-4),
     ADbackend=adtype,
     show_progress=true,
     callback=cb,
@@ -63,35 +61,3 @@ losses = map(x -> x.loss, stats)
 ######################################
 plot(losses; label="Loss", linewidth=2) # plot the loss
 compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
-
-
-
-
-
-
-
-
-
-# using MonotonicSplines, Plots, InverseFunctions, ChangesOfVariables
-
-# f = rand(RQSpline)
-# f.pX, f.pY, f.dYdX
-
-# plot(f, xlims = (-6, 6)); plot!(inverse(f), xlims = (-6, 6))
-
-# x = 1.2
-# y = f(x)
-# with_logabsdet_jacobian(f, x)
-# inverse(f)(y)
-# with_logabsdet_jacobian(inverse(f), y)
-
-
-
-# # test auto grad
-# function loss(x)
-#     y, laj = MonotonicSplines.rqs_forward(x, f.pX, f.pY, f.dYdX)
-#     return laj + 0.5 * sum((y .- 1).^2)
-# end
-
-# xx = rand()
-# val, g = DifferentiationInterface.value_and_gradient(loss, adtype, xx)

example/demo_new_nsf.jl

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+using Bijectors
+using Bijectors: partition, combine, PartitionMask
+
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using NormalizingFlows
+
+include("SyntheticTargets.jl")
+include("utils.jl")
+
+##################################
+# start demo
+#################################
+Random.seed!(123)
+rng = Random.default_rng()
+T = Float32
+
+######################################
+# a difficult banana target
+######################################
+target = Banana(2, one(T), 100one(T))
+logp = Base.Fix1(logpdf, target)
+
+######################################
+# learn the target using Neural Spline Flow
+######################################
+@leaf MvNormal
+q0 = MvNormal(zeros(T, 2), I)
+
+
+flow = new_nsf(q0; paramtype=T)
+flow_untrained = deepcopy(flow)
+######################################
+# start training
+######################################
+sample_per_iter = 64
+
+# callback function to log training progress
+cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
+# TODO: mooncake has some issues with kernelabstractions?
+# adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
+adtype = ADTypes.AutoZygote()
+checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
+flow_trained, stats, _ = train_flow(
+    elbo_batch,
+    flow,
+    logp,
+    sample_per_iter;
+    max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results
+    optimiser=Optimisers.Adam(1e-4),
+    ADbackend=adtype,
+    show_progress=true,
+    callback=cb,
+    hasconverged=checkconv,
+)
+θ, re = Optimisers.destructure(flow_trained)
+losses = map(x -> x.loss, stats)
+
+######################################
+# evaluate trained flow
+######################################
+plot(losses; label="Loss", linewidth=2) # plot the loss
+compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)

src/NormalizingFlows.jl

Lines changed: 5 additions & 0 deletions
@@ -130,10 +130,15 @@ end
 include("flows/utils.jl")
 include("flows/realnvp.jl")
 include("flows/neuralspline.jl")
+# a new implementation of Neural Spline Flow based on MonotonicSplines.jl
+# the construction of the RQS seems to be more efficient than the one in Bijectors.jl
+# and supports batched operations.
+include("flows/new_nsf.jl")
 
 export create_flow
 export AffineCoupling, RealNVP_layer, realnvp
 export NeuralSplineCoupling, NSF_layer, nsf
+export NSC, new_NSF_layer, new_nsf
 
 
 end
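The comment added above is the rationale for the new file: MonotonicSplines.jl provides rational-quadratic spline (RQS) kernels that act on whole parameter batches at once. A small round-trip sketch of the primitives that new_nsf.jl wraps (adapted from the scratch snippet deleted from demo_neural_spline_flow.jl above; `rand(RQSpline)` draws a random spline whose knot parameters live in the fields pX, pY, dYdX):

using MonotonicSplines, InverseFunctions, ChangesOfVariables

f = rand(RQSpline)                        # random rational-quadratic spline
x = 1.2
y, logjac = with_logabsdet_jacobian(f, x) # forward pass with log|det J|
x_back = inverse(f)(y)                    # analytic inverse of the spline

# the batched kernel that NSC below calls directly:
y2, lj = MonotonicSplines.rqs_forward(x, f.pX, f.pY, f.dYdX)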

src/flows/new_nsf.jl

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+using MonotonicSplines
+
+struct NSC{T,A<:Flux.Chain} <: Bijectors.Bijector
+    dim::Int                # dimension of input
+    K::Int                  # number of knots
+    n_dims_transferred::Int # number of dimensions that are transformed
+    B::T                    # bound of the knots
+    nn::A                   # network that parameterizes the knots and derivatives
+    mask::Bijectors.PartitionMask
+end
+
+function NSC(
+    dim::T1,                      # dimension of input
+    hdims::AbstractVector{T1},    # dimensions of the hidden units of the nn
+    K::T1,                        # number of knots
+    B::T2,                        # bound of the knots
+    mask_idx::AbstractVector{T1}, # indices of the dimensions to transform
+    paramtype::Type{T2},          # type of the parameters, e.g., Float64 or Float32
+) where {T1<:Int,T2<:AbstractFloat}
+    num_of_transformed_dims = length(mask_idx)
+    input_dims = dim - num_of_transformed_dims
+
+    # output dim of the NN
+    output_dims = (3K - 1) * num_of_transformed_dims
+    # one big mlp that outputs all the knots and derivatives for all the transformed dimensions
+    nn = fnn(input_dims, hdims, output_dims; output_activation=nothing, paramtype=paramtype)
+
+    mask = Bijectors.PartitionMask(dim, mask_idx)
+    return NSC{T2,typeof(nn)}(dim, K, num_of_transformed_dims, B, nn, mask)
+end
+
+@functor NSC (nn,)
+
+function get_nsl_params(nsl::NSC, x::AbstractVecOrMat)
+    nnoutput = nsl.nn(x)
+    px, py, dydx = MonotonicSplines.rqs_params_from_nn(nnoutput, nsl.n_dims_transferred, nsl.B)
+    return px, py, dydx
+end
+
+function Bijectors.transform(nsl::NSC, x::AbstractVecOrMat)
+    x1, x2, x3 = Bijectors.partition(nsl.mask, x)
+    # instantiate rqs knots and derivatives
+    px, py, dydx = get_nsl_params(nsl, x2)
+    if x1 isa AbstractVector
+        x1 = reshape(x1, 1, length(x1)) # ensure x1 is a matrix
+    end
+    y1, _ = MonotonicSplines.rqs_forward(x1, px, py, dydx)
+    return Bijectors.combine(nsl.mask, y1, x2, x3)
+end
+
+function Bijectors.with_logabsdet_jacobian(nsl::NSC, x::AbstractVecOrMat)
+    x1, x2, x3 = Bijectors.partition(nsl.mask, x)
+    # instantiate rqs knots and derivatives
+    px, py, dydx = get_nsl_params(nsl, x2)
+    y1, logjac = MonotonicSplines.rqs_forward(x1, px, py, dydx)
+    return Bijectors.combine(nsl.mask, y1, x2, x3), vec(logjac)
+end
+
+function Bijectors.transform(insl::Inverse{<:NSC}, y::AbstractVecOrMat)
+    nsl = insl.orig
+    y1, y2, y3 = partition(nsl.mask, y)
+    px, py, dydx = get_nsl_params(nsl, y2)
+    x1, _ = MonotonicSplines.rqs_inverse(y1, px, py, dydx)
+    return Bijectors.combine(nsl.mask, x1, y2, y3)
+end
+
+function Bijectors.with_logabsdet_jacobian(insl::Inverse{<:NSC}, y::AbstractVecOrMat)
+    nsl = insl.orig
+    y1, y2, y3 = partition(nsl.mask, y)
+    px, py, dydx = get_nsl_params(nsl, y2)
+    x1, logjac = MonotonicSplines.rqs_inverse(y1, px, py, dydx)
+    return Bijectors.combine(nsl.mask, x1, y2, y3), logjac isa Real ? logjac : vec(logjac)
+end
+
+function (nsl::NSC)(x::AbstractVecOrMat)
+    return Bijectors.transform(nsl, x)
+end
+
+
+function new_NSF_layer(
+    dims::T1,                  # dimension of the problem
+    hdims::AbstractVector{T1}, # dimensions of the hidden units of the nn
+    K::T1,                     # number of knots
+    B::T2;                     # bound of the knots
+    paramtype::Type{T2} = Float64, # type of the parameters
+) where {T1<:Int,T2<:AbstractFloat}
+
+    mask_idx1 = 1:2:dims
+    mask_idx2 = 2:2:dims
+
+    # by default use the odd-even masking strategy
+    nsf1 = NSC(dims, hdims, K, B, mask_idx1, paramtype)
+    nsf2 = NSC(dims, hdims, K, B, mask_idx2, paramtype)
+    return reduce(∘, (nsf1, nsf2))
+end
+
+function new_nsf(
+    q0::Distribution{Multivariate,Continuous},
+    hdims::AbstractVector{Int}, # dimensions of the hidden units of the nn
+    K::Int,
+    B::T,
+    nlayers::Int;               # number of new_NSF_layers
+    paramtype::Type{T} = Float64, # type of the parameters
+) where {T<:AbstractFloat}
+
+    dims = length(q0) # dimension of the reference distribution == dim of the problem
+    Ls = [new_NSF_layer(dims, hdims, K, B; paramtype=paramtype) for _ in 1:nlayers]
+    create_flow(Ls, q0)
+end
+
+new_nsf(q0; paramtype::Type{T} = Float64) where {T<:AbstractFloat} = new_nsf(
+    q0, [32, 32], 10, 30 * one(T), 10; paramtype=paramtype
+)
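For reference, the convenience method at the bottom fixes the defaults to two hidden layers of 32 units, K = 10 knots, bound B = 30, and 10 layers. A minimal usage sketch (assuming the exports added in src/NormalizingFlows.jl above; the explicit hyperparameters for flow2 are illustrative, not recommendations):

using NormalizingFlows, Distributions, LinearAlgebra

T = Float32
q0 = MvNormal(zeros(T, 2), I)

flow = new_nsf(q0; paramtype=T)  # defaults: hdims=[32, 32], K=10, B=30, nlayers=10

# explicit hyperparameters: wider nn, fewer knots, tighter bound, 4 layers
flow2 = new_nsf(q0, [64, 64], 8, 5 * one(T), 4; paramtype=T)

xs = rand(flow, 100)        # 2x100 matrix of samples from the (untrained) flow
lp = logpdf(flow, xs[:, 1]) # density of one sample under the flow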
