1- # #################################
2- # define neural spline layer using Bijectors.jl interface
3- # ################################
41"""
52Neural rational quadratic spline coupling layer.
63
74# References
85[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
96"""
10- struct NeuralSplineLayer {T,A<: Flux.Chain } <: Bijectors.Bijector
7+ struct NeuralSplineCoupling {T,A<: Flux.Chain } <: Bijectors.Bijector
118 dim:: Int # dimension of input
129 K:: Int # number of knots
1310 n_dims_transferred:: Int # number of dimensions that are transformed
14- nn:: A # networks that parmaterize the knots and derivatives
1511 B:: T # bound of the knots
12+ nn:: A # networks that parameterize the knots and derivatives
1613 mask:: Bijectors.PartitionMask
1714end
1815
19- function NeuralSplineLayer (
16+ function NeuralSplineCoupling (
2017 dim:: T1 , # dimension of input
21- hdims:: T1 , # dimension of hidden units for s and t
18+ hdims:: AbstractVector{T1} , # dimension of hidden units for s and t
2219 K:: T1 , # number of knots
2320 B:: T2 , # bound of the knots
2421 mask_idx:: AbstractVector{<:Int} , # indices of the dimensions that one wants to apply transformations on
25- ) where {T1<: Int ,T2<: Real }
22+ paramtype:: Type{T3} , # type of the parameters, e.g., Float64 or Float32
23+ ) where {T1<: Int ,T2<: Real ,T3<: AbstractFloat }
2624 num_of_transformed_dims = length (mask_idx)
2725 input_dims = dim - num_of_transformed_dims
2826
2927 # output dim of the NN
3028 output_dims = (3 K - 1 )* num_of_transformed_dims
3129 # one big mlp that outputs all the knots and derivatives for all the transformed dimensions
32- nn = mlp3 (input_dims, hdims, output_dims)
30+ # todo: ensure type stability
31+ nn = fnn (input_dims, hdims, output_dims; output_activation= nothing , paramtype= paramtype)
3332
3433 mask = Bijectors. PartitionMask (dim, mask_idx)
35- return NeuralSplineLayer (dim, K, num_of_transformed_dims, nn, B , mask)
34+ return NeuralSplineCoupling (dim, K, num_of_transformed_dims, B, nn , mask)
3635end
3736
38- @functor NeuralSplineLayer (nn,)
37+ @functor NeuralSplineCoupling (nn,)
3938
4039"""
4140Build a rational quadratic spline (RQS) from the nn output.
4241Bijectors.jl already implements the inverse and logabsdetjac of the rational quadratic spline.
4342
4443We only need to map the nn output to the knots and derivatives of the RQS.
4544"""
46- function instantiate_rqs (nsl:: NeuralSplineLayer , x:: AbstractVector )
45+ function instantiate_rqs (nsl:: NeuralSplineCoupling , x:: AbstractVector )
4746 K, B = nsl. K, nsl. B
4847 nnoutput = reshape (nsl. nn (x), nsl. n_dims_transferred, :)
4948 ws = @view nnoutput[:, 1 : K]
@@ -52,46 +51,100 @@ function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
5251 return Bijectors. RationalQuadraticSpline (ws, hs, ds, B)
5352end
5453
55- function Bijectors. transform (nsl:: NeuralSplineLayer , x:: AbstractVector )
54+ function Bijectors. transform (nsl:: NeuralSplineCoupling , x:: AbstractVector )
5655 x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
5756 # instantiate rqs knots and derivatives
5857 rqs = instantiate_rqs (nsl, x_2)
5958 y_1 = Bijectors. transform (rqs, x_1)
6059 return Bijectors. combine (nsl. mask, y_1, x_2, x_3)
6160end
6261
63- function Bijectors. transform (insl:: Inverse{<:NeuralSplineLayer } , y:: AbstractVector )
62+ function Bijectors. transform (insl:: Inverse{<:NeuralSplineCoupling } , y:: AbstractVector )
6463 nsl = insl. orig
6564 y1, y2, y3 = partition (nsl. mask, y)
6665 rqs = instantiate_rqs (nsl, y2)
6766 x1 = Bijectors. transform (Inverse (rqs), y1)
6867 return Bijectors. combine (nsl. mask, x1, y2, y3)
6968end
7069
71- function (nsl:: NeuralSplineLayer )(x:: AbstractVector )
70+ function (nsl:: NeuralSplineCoupling )(x:: AbstractVector )
7271 return Bijectors. transform (nsl, x)
7372end
7473
7574# define logabsdetjac
76- function Bijectors. logabsdetjac (nsl:: NeuralSplineLayer , x:: AbstractVector )
75+ function Bijectors. logabsdetjac (nsl:: NeuralSplineCoupling , x:: AbstractVector )
7776 x_1, x_2, _ = Bijectors. partition (nsl. mask, x)
7877 rqs = instantiate_rqs (nsl, x_2)
7978 logjac = logabsdetjac (rqs, x_1)
8079 return logjac
8180end
8281
83- function Bijectors. logabsdetjac (insl:: Inverse{<:NeuralSplineLayer } , y:: AbstractVector )
82+ function Bijectors. logabsdetjac (insl:: Inverse{<:NeuralSplineCoupling } , y:: AbstractVector )
8483 nsl = insl. orig
8584 y1, y2, _ = partition (nsl. mask, y)
8685 rqs = instantiate_rqs (nsl, y2)
8786 logjac = logabsdetjac (Inverse (rqs), y1)
8887 return logjac
8988end
9089
91- function Bijectors. with_logabsdet_jacobian (nsl:: NeuralSplineLayer , x:: AbstractVector )
90+ function Bijectors. with_logabsdet_jacobian (nsl:: NeuralSplineCoupling , x:: AbstractVector )
9291 x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
9392 rqs = instantiate_rqs (nsl, x_2)
9493 y_1, logjac = with_logabsdet_jacobian (rqs, x_1)
9594 return Bijectors. combine (nsl. mask, y_1, x_2, x_3), logjac
9695end
9796
97+
98+ """
99+ NSF_layer(dims, hdims, K, B; paramtype = Float64)
100+
101+ Default constructor of a single layer of Neural Spline Flow (NSF),
102+ which is a composition of 2 neural spline coupling transformations with complementary masks.
103+ The masking strategy is odd-even masking.
104+
105+ # Arguments
106+ - `dims::Int`: dimension of the problem
107+ - `hdims::AbstractVector{Int}`: dimension of hidden units for the neural network
108+ - `K::Int`: number of knots
109+ - `B::AbstractFloat`: bound of the knots
110+
111+ # Keyword Arguments
112+ - `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
113+
114+ # Returns
115+ - A `Bijectors.Bijector` representing the NSF layer.
116+ """
117+ function NSF_layer (
118+ dims:: T1 , # dimension of problem
119+ hdims:: AbstractVector{T1} , # dimension of hidden units for nn
120+ K:: T1 , # number of knots
121+ B:: T2 ; # bound of the knots
122+ paramtype:: Type{T2} = Float64, # type of the parameters
123+ ) where {T1<: Int ,T2<: AbstractFloat }
124+
125+ mask_idx1 = 1 : 2 : dims
126+ mask_idx2 = 2 : 2 : dims
127+
128+ # by default use the odd-even masking strategy
129+ nsf1 = NeuralSplineCoupling (dims, hdims, K, B, mask_idx1, paramtype)
130+ nsf2 = NeuralSplineCoupling (dims, hdims, K, B, mask_idx2, paramtype)
131+ return reduce (∘ , (nsf1, nsf2))
132+ end
133+
134+ function nsf (
135+ q0:: Distribution{Multivariate,Continuous} ,
136+ hdims:: AbstractVector{Int} , # dimension of hidden units for s and t
137+ K:: Int ,
138+ B:: T ,
139+ nlayers:: Int ; # number of NSF layers
140+ paramtype:: Type{T} = Float64, # type of the parameters
141+ ) where {T<: AbstractFloat }
142+
143+ dims = length (q0) # dimension of the reference distribution == dim of the problem
144+ Ls = [NSF_layer (dims, hdims, K, B; paramtype= paramtype) for _ in 1 : nlayers]
145+ create_flow (Ls, q0)
146+ end
147+
148+ nsf (q0; paramtype:: Type{T} = Float64) where {T<: AbstractFloat } = nsf (
149+ q0, [32 , 32 ], 10 , 30 * one (T), 10 ; paramtype= paramtype
150+ )
0 commit comments