1+ using MonotonicSplines
2+ # a new implementation of Neural Spline Flow based on MonotonicSplines.jl
3+ # the construction of the RQS seems to be more efficient than the one in Bijectors.jl
4+ # and supports batched operations.
5+
16"""
27Neural Rational quadratic Spline layer
3-
48# References
59[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
610"""
@@ -18,9 +22,9 @@ function NeuralSplineCoupling(
1822 hdims:: AbstractVector{T1} , # dimension of hidden units for s and t
1923 K:: T1 , # number of knots
2024 B:: T2 , # bound of the knots
21- mask_idx:: AbstractVector{<:Int } , # index of dimensione that one wants to apply transformations on
22- paramtype:: Type{T3 } , # type of the parameters, e.g., Float64 or Float32
23- ) where {T1<: Int ,T2<: Real ,T3 <: AbstractFloat }
25+ mask_idx:: AbstractVector{T1 } , # index of dimensione that one wants to apply transformations on
26+ paramtype:: Type{T2 } , # type of the parameters, e.g., Float64 or Float32
27+ ) where {T1<: Int ,T2<: AbstractFloat }
2428 num_of_transformed_dims = length (mask_idx)
2529 input_dims = dim - num_of_transformed_dims
2630
@@ -30,86 +34,75 @@ function NeuralSplineCoupling(
3034 nn = fnn (input_dims, hdims, output_dims; output_activation= nothing , paramtype= paramtype)
3135
3236 mask = Bijectors. PartitionMask (dim, mask_idx)
33- return NeuralSplineCoupling (dim, K, num_of_transformed_dims, B, nn, mask)
37+ return NeuralSplineCoupling {T2, typeof(nn)} (dim, K, num_of_transformed_dims, B, nn, mask)
3438end
3539
3640@functor NeuralSplineCoupling (nn,)
3741
38- """
39- Build a rational quadratic spline (RQS) from the nn output
40- Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline
41-
42- we just need to map the nn output to the knots and derivatives of the RQS
43- """
44- function instantiate_rqs (nsl:: NeuralSplineCoupling , x:: AbstractVector )
45- K, B = nsl. K, nsl. B
46- nnoutput = reshape (nsl. nn (x), nsl. n_dims_transferred, :)
47- ws = @view nnoutput[:, 1 : K]
48- hs = @view nnoutput[:, (K + 1 ): (2 K)]
49- ds = @view nnoutput[:, (2 K + 1 ): (3 K - 1 )]
50- return Bijectors. RationalQuadraticSpline (ws, hs, ds, B)
42+ function get_nsc_params (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
43+ nnoutput = nsc. nn (x)
44+ px, py, dydx = MonotonicSplines. rqs_params_from_nn (nnoutput, nsc. n_dims_transferred, nsc. B)
45+ return px, py, dydx
5146end
5247
53- function Bijectors. transform (nsl:: NeuralSplineCoupling , x:: AbstractVector )
54- x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
55- # instantiate rqs knots and derivatives
56- rqs = instantiate_rqs (nsl, x_2)
57- y_1 = Bijectors. transform (rqs, x_1)
58- return Bijectors. combine (nsl. mask, y_1, x_2, x_3)
59- end
48+ # when input x is a vector instead of a matrix
49+ # need this to transform it to a matrix with one row
50+ # otherwise, rqs_forward and rqs_inverse will throw an error
51+ _ensure_matrix (x) = x isa AbstractVector ? reshape (x, 1 , length (x)) : x
6052
61- function Bijectors. transform (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVector )
62- nsl = insl. orig
63- y1, y2, y3 = partition (nsl. mask, y)
64- rqs = instantiate_rqs (nsl, y2)
65- x1 = Bijectors. transform (Inverse (rqs), y1)
66- return Bijectors. combine (nsl. mask, x1, y2, y3)
53+ function Bijectors. transform (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
54+ x1, x2, x3 = Bijectors. partition (nsc. mask, x)
55+ # instantiate rqs knots and derivatives
56+ px, py, dydx = get_nsc_params (nsc, x2)
57+ x1 = _ensure_matrix (x1)
58+ y1, _ = MonotonicSplines. rqs_forward (x1, px, py, dydx)
59+ return Bijectors. combine (nsc. mask, y1, x2, x3)
6760end
6861
69- function (nsl:: NeuralSplineCoupling )(x:: AbstractVector )
70- return Bijectors. transform (nsl, x)
62+ function Bijectors. with_logabsdet_jacobian (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
63+ x1, x2, x3 = Bijectors. partition (nsc. mask, x)
64+ # instantiate rqs knots and derivatives
65+ px, py, dydx = get_nsc_params (nsc, x2)
66+ x1 = _ensure_matrix (x1)
67+ y1, logjac = MonotonicSplines. rqs_forward (x1, px, py, dydx)
68+ return Bijectors. combine (nsc. mask, y1, x2, x3), logjac isa Real ? logjac : vec (logjac)
7169end
7270
73- # define logabsdetjac
74- function Bijectors. logabsdetjac (nsl:: NeuralSplineCoupling , x:: AbstractVector )
75- x_1, x_2, _ = Bijectors. partition (nsl. mask, x)
76- rqs = instantiate_rqs (nsl, x_2)
77- logjac = logabsdetjac (rqs, x_1)
78- return logjac
71+ function Bijectors. transform (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVecOrMat )
72+ nsc = insl. orig
73+ y1, y2, y3 = partition (nsc. mask, y)
74+ px, py, dydx = get_nsc_params (nsc, y2)
75+ y1 = _ensure_matrix (y1)
76+ x1, _ = MonotonicSplines. rqs_inverse (y1, px, py, dydx)
77+ return Bijectors. combine (nsc. mask, x1, y2, y3)
7978end
8079
81- function Bijectors. logabsdetjac (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVector )
82- nsl = insl. orig
83- y1, y2, _ = partition (nsl. mask, y)
84- rqs = instantiate_rqs (nsl, y2)
85- logjac = logabsdetjac (Inverse (rqs), y1)
86- return logjac
80+ function Bijectors. with_logabsdet_jacobian (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVecOrMat )
81+ nsc = insl. orig
82+ y1, y2, y3 = partition (nsc. mask, y)
83+ px, py, dydx = get_nsc_params (nsc, y2)
84+ y1 = _ensure_matrix (y1)
85+ x1, logjac = MonotonicSplines. rqs_inverse (y1, px, py, dydx)
86+ return Bijectors. combine (nsc. mask, x1, y2, y3), logjac isa Real ? logjac : vec (logjac)
8787end
8888
89- function Bijectors. with_logabsdet_jacobian (nsl:: NeuralSplineCoupling , x:: AbstractVector )
90- x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
91- rqs = instantiate_rqs (nsl, x_2)
92- y_1, logjac = with_logabsdet_jacobian (rqs, x_1)
93- return Bijectors. combine (nsl. mask, y_1, x_2, x_3), logjac
89+ function (nsc:: NeuralSplineCoupling )(x:: AbstractVecOrMat )
90+ return Bijectors. transform (nsc, x)
9491end
9592
9693
9794"""
9895 NSF_layer(dims, hdims; paramtype = Float64)
99-
10096Default constructor of single layer of Neural Spline Flow (NSF)
10197which is a composition of 2 neural spline coupling transformations with complementary masks.
10298The masking strategy is odd-even masking.
103-
10499# Arguments
105100- `dims::Int`: dimension of the problem
106101- `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
107102- `K::Int`: number of knots
108103- `B::AbstractFloat`: bound of the knots
109-
110104# Keyword Arguments
111105- `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
112-
113106# Returns
114107- A `Bijectors.Bijector` representing the NSF layer.
115108"""
0 commit comments