Skip to content

Commit 2903b83

Browse files
committed
dump the previous nsf implementation
1 parent 1f30b33 commit 2903b83

File tree

6 files changed

+54
-244
lines changed

6 files changed

+54
-244
lines changed

.github/workflows/Examples.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,5 @@ jobs:
3838
include("demo_RealNVP.jl");
3939
@info "Running neural spline flow demo";
4040
include("demo_neural_spline_flow.jl");
41-
@info "Running new neural spline flow demo";
42-
include("demo_new_nsf.jl");
4341
@info "Running Hamiltonian flow demo";
4442
include("demo_hamiltonian_flow.jl");'

example/demo_neural_spline_flow.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Bijectors: partition, combine, PartitionMask
44
using Random, Distributions, LinearAlgebra
55
using Functors
66
using Optimisers, ADTypes
7-
using Mooncake
7+
using Zygote
88
using NormalizingFlows
99

1010
include("SyntheticTargets.jl")
@@ -20,9 +20,9 @@ T = Float32
2020
######################################
2121
# a difficult banana target
2222
######################################
23-
2423
target = Banana(2, one(T), 100one(T))
2524
logp = Base.Fix1(logpdf, target)
25+
2626
######################################
2727
# learn the target using Neural Spline Flow
2828
######################################
@@ -39,14 +39,16 @@ sample_per_iter = 64
3939

4040
# callback function to log training progress
4141
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
42-
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
42+
# TODO: mooncake has some issues with kernelabstractions?
43+
# adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
44+
adtype = ADTypes.AutoZygote()
4345
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
4446
flow_trained, stats, _ = train_flow(
45-
elbo,
47+
elbo_batch,
4648
flow,
4749
logp,
4850
sample_per_iter;
49-
max_iters=10, # change to larger number of iterations (e.g., 50_000) for better results
51+
max_iters=10000, # change to larger number of iterations (e.g., 50_000) for better results
5052
optimiser=Optimisers.Adam(1e-4),
5153
ADbackend=adtype,
5254
show_progress=true,

example/demo_new_nsf.jl

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/NormalizingFlows.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,10 @@ end
130130
include("flows/utils.jl")
131131
include("flows/realnvp.jl")
132132
include("flows/neuralspline.jl")
133-
# a new implementation of Neural Spline Flow based on MonotonicSplines.jl
134-
# the construction of the RQS seems to be more efficient than the one in Bijectors.jl
135-
# and supports batched operations.
136-
include("flows/new_nsf.jl")
137133

138134
export create_flow
139135
export AffineCoupling, RealNVP_layer, realnvp
140136
export NeuralSplineCoupling, NSF_layer, nsf
141-
export NSC, new_NSF_layer, new_nsf
142137

143138

144139
end

src/flows/neuralspline.jl

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
using MonotonicSplines
2+
# a new implementation of Neural Spline Flow based on MonotonicSplines.jl
3+
# the construction of the RQS seems to be more efficient than the one in Bijectors.jl
4+
# and supports batched operations.
5+
16
"""
27
Neural Rational quadratic Spline layer
3-
48
# References
59
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
610
"""
@@ -18,9 +22,9 @@ function NeuralSplineCoupling(
1822
hdims::AbstractVector{T1}, # dimension of hidden units for s and t
1923
K::T1, # number of knots
2024
B::T2, # bound of the knots
21-
mask_idx::AbstractVector{<:Int}, # index of dimensione that one wants to apply transformations on
22-
paramtype::Type{T3}, # type of the parameters, e.g., Float64 or Float32
23-
) where {T1<:Int,T2<:Real,T3<:AbstractFloat}
25+
mask_idx::AbstractVector{T1}, # index of dimensione that one wants to apply transformations on
26+
paramtype::Type{T2}, # type of the parameters, e.g., Float64 or Float32
27+
) where {T1<:Int,T2<:AbstractFloat}
2428
num_of_transformed_dims = length(mask_idx)
2529
input_dims = dim - num_of_transformed_dims
2630

@@ -30,86 +34,75 @@ function NeuralSplineCoupling(
3034
nn = fnn(input_dims, hdims, output_dims; output_activation=nothing, paramtype=paramtype)
3135

3236
mask = Bijectors.PartitionMask(dim, mask_idx)
33-
return NeuralSplineCoupling(dim, K, num_of_transformed_dims, B, nn, mask)
37+
return NeuralSplineCoupling{T2, typeof(nn)}(dim, K, num_of_transformed_dims, B, nn, mask)
3438
end
3539

3640
@functor NeuralSplineCoupling (nn,)
3741

38-
"""
39-
Build a rational quadratic spline (RQS) from the nn output
40-
Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline
41-
42-
we just need to map the nn output to the knots and derivatives of the RQS
43-
"""
44-
function instantiate_rqs(nsl::NeuralSplineCoupling, x::AbstractVector)
45-
K, B = nsl.K, nsl.B
46-
nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :)
47-
ws = @view nnoutput[:, 1:K]
48-
hs = @view nnoutput[:, (K + 1):(2K)]
49-
ds = @view nnoutput[:, (2K + 1):(3K - 1)]
50-
return Bijectors.RationalQuadraticSpline(ws, hs, ds, B)
42+
function get_nsc_params(nsc::NeuralSplineCoupling, x::AbstractVecOrMat)
43+
nnoutput = nsc.nn(x)
44+
px, py, dydx = MonotonicSplines.rqs_params_from_nn(nnoutput, nsc.n_dims_transferred, nsc.B)
45+
return px, py, dydx
5146
end
5247

53-
function Bijectors.transform(nsl::NeuralSplineCoupling, x::AbstractVector)
54-
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
55-
# instantiate rqs knots and derivatives
56-
rqs = instantiate_rqs(nsl, x_2)
57-
y_1 = Bijectors.transform(rqs, x_1)
58-
return Bijectors.combine(nsl.mask, y_1, x_2, x_3)
59-
end
48+
# when input x is a vector instead of a matrix
49+
# need this to transform it to a matrix with one row
50+
# otherwise, rqs_forward and rqs_inverse will throw an error
51+
_ensure_matrix(x) = x isa AbstractVector ? reshape(x, 1, length(x)) : x
6052

61-
function Bijectors.transform(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVector)
62-
nsl = insl.orig
63-
y1, y2, y3 = partition(nsl.mask, y)
64-
rqs = instantiate_rqs(nsl, y2)
65-
x1 = Bijectors.transform(Inverse(rqs), y1)
66-
return Bijectors.combine(nsl.mask, x1, y2, y3)
53+
function Bijectors.transform(nsc::NeuralSplineCoupling, x::AbstractVecOrMat)
54+
x1, x2, x3 = Bijectors.partition(nsc.mask, x)
55+
# instantiate rqs knots and derivatives
56+
px, py, dydx = get_nsc_params(nsc, x2)
57+
x1 = _ensure_matrix(x1)
58+
y1, _ = MonotonicSplines.rqs_forward(x1, px, py, dydx)
59+
return Bijectors.combine(nsc.mask, y1, x2, x3)
6760
end
6861

69-
function (nsl::NeuralSplineCoupling)(x::AbstractVector)
70-
return Bijectors.transform(nsl, x)
62+
function Bijectors.with_logabsdet_jacobian(nsc::NeuralSplineCoupling, x::AbstractVecOrMat)
63+
x1, x2, x3 = Bijectors.partition(nsc.mask, x)
64+
# instantiate rqs knots and derivatives
65+
px, py, dydx = get_nsc_params(nsc, x2)
66+
x1 = _ensure_matrix(x1)
67+
y1, logjac = MonotonicSplines.rqs_forward(x1, px, py, dydx)
68+
return Bijectors.combine(nsc.mask, y1, x2, x3), logjac isa Real ? logjac : vec(logjac)
7169
end
7270

73-
# define logabsdetjac
74-
function Bijectors.logabsdetjac(nsl::NeuralSplineCoupling, x::AbstractVector)
75-
x_1, x_2, _ = Bijectors.partition(nsl.mask, x)
76-
rqs = instantiate_rqs(nsl, x_2)
77-
logjac = logabsdetjac(rqs, x_1)
78-
return logjac
71+
function Bijectors.transform(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVecOrMat)
72+
nsc = insl.orig
73+
y1, y2, y3 = partition(nsc.mask, y)
74+
px, py, dydx = get_nsc_params(nsc, y2)
75+
y1 = _ensure_matrix(y1)
76+
x1, _ = MonotonicSplines.rqs_inverse(y1, px, py, dydx)
77+
return Bijectors.combine(nsc.mask, x1, y2, y3)
7978
end
8079

81-
function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVector)
82-
nsl = insl.orig
83-
y1, y2, _ = partition(nsl.mask, y)
84-
rqs = instantiate_rqs(nsl, y2)
85-
logjac = logabsdetjac(Inverse(rqs), y1)
86-
return logjac
80+
function Bijectors.with_logabsdet_jacobian(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVecOrMat)
81+
nsc = insl.orig
82+
y1, y2, y3 = partition(nsc.mask, y)
83+
px, py, dydx = get_nsc_params(nsc, y2)
84+
y1 = _ensure_matrix(y1)
85+
x1, logjac = MonotonicSplines.rqs_inverse(y1, px, py, dydx)
86+
return Bijectors.combine(nsc.mask, x1, y2, y3), logjac isa Real ? logjac : vec(logjac)
8787
end
8888

89-
function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineCoupling, x::AbstractVector)
90-
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
91-
rqs = instantiate_rqs(nsl, x_2)
92-
y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
93-
return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
89+
function (nsc::NeuralSplineCoupling)(x::AbstractVecOrMat)
90+
return Bijectors.transform(nsc, x)
9491
end
9592

9693

9794
"""
9895
NSF_layer(dims, hdims; paramtype = Float64)
99-
10096
Default constructor of single layer of Neural Spline Flow (NSF)
10197
which is a composition of 2 neural spline coupling transformations with complementary masks.
10298
The masking strategy is odd-even masking.
103-
10499
# Arguments
105100
- `dims::Int`: dimension of the problem
106101
- `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
107102
- `K::Int`: number of knots
108103
- `B::AbstractFloat`: bound of the knots
109-
110104
# Keyword Arguments
111105
- `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
112-
113106
# Returns
114107
- A `Bijectors.Bijector` representing the NSF layer.
115108
"""

src/flows/new_nsf.jl

Lines changed: 0 additions & 113 deletions
This file was deleted.

0 commit comments

Comments
 (0)