1
+ using MonotonicSplines
2
+ # a new implementation of Neural Spline Flow based on MonotonicSplines.jl
3
+ # the construction of the RQS seems to be more efficient than the one in Bijectors.jl
4
+ # and supports batched operations.
5
+
1
6
"""
2
7
Neural Rational quadratic Spline layer
3
-
4
8
# References
5
9
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
6
10
"""
@@ -18,9 +22,9 @@ function NeuralSplineCoupling(
18
22
hdims:: AbstractVector{T1} , # dimension of hidden units for s and t
19
23
K:: T1 , # number of knots
20
24
B:: T2 , # bound of the knots
21
- mask_idx:: AbstractVector{<:Int } , # index of dimensione that one wants to apply transformations on
22
- paramtype:: Type{T3 } , # type of the parameters, e.g., Float64 or Float32
23
- ) where {T1<: Int ,T2<: Real ,T3 <: AbstractFloat }
25
+ mask_idx:: AbstractVector{T1 } , # index of dimensione that one wants to apply transformations on
26
+ paramtype:: Type{T2 } , # type of the parameters, e.g., Float64 or Float32
27
+ ) where {T1<: Int ,T2<: AbstractFloat }
24
28
num_of_transformed_dims = length (mask_idx)
25
29
input_dims = dim - num_of_transformed_dims
26
30
@@ -30,86 +34,75 @@ function NeuralSplineCoupling(
30
34
nn = fnn (input_dims, hdims, output_dims; output_activation= nothing , paramtype= paramtype)
31
35
32
36
mask = Bijectors. PartitionMask (dim, mask_idx)
33
- return NeuralSplineCoupling (dim, K, num_of_transformed_dims, B, nn, mask)
37
+ return NeuralSplineCoupling {T2, typeof(nn)} (dim, K, num_of_transformed_dims, B, nn, mask)
34
38
end
35
39
36
40
@functor NeuralSplineCoupling (nn,)
37
41
38
- """
39
- Build a rational quadratic spline (RQS) from the nn output
40
- Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline
41
-
42
- we just need to map the nn output to the knots and derivatives of the RQS
43
- """
44
- function instantiate_rqs (nsl:: NeuralSplineCoupling , x:: AbstractVector )
45
- K, B = nsl. K, nsl. B
46
- nnoutput = reshape (nsl. nn (x), nsl. n_dims_transferred, :)
47
- ws = @view nnoutput[:, 1 : K]
48
- hs = @view nnoutput[:, (K + 1 ): (2 K)]
49
- ds = @view nnoutput[:, (2 K + 1 ): (3 K - 1 )]
50
- return Bijectors. RationalQuadraticSpline (ws, hs, ds, B)
42
+ function get_nsc_params (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
43
+ nnoutput = nsc. nn (x)
44
+ px, py, dydx = MonotonicSplines. rqs_params_from_nn (nnoutput, nsc. n_dims_transferred, nsc. B)
45
+ return px, py, dydx
51
46
end
52
47
53
- function Bijectors. transform (nsl:: NeuralSplineCoupling , x:: AbstractVector )
54
- x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
55
- # instantiate rqs knots and derivatives
56
- rqs = instantiate_rqs (nsl, x_2)
57
- y_1 = Bijectors. transform (rqs, x_1)
58
- return Bijectors. combine (nsl. mask, y_1, x_2, x_3)
59
- end
48
+ # when input x is a vector instead of a matrix
49
+ # need this to transform it to a matrix with one row
50
+ # otherwise, rqs_forward and rqs_inverse will throw an error
51
+ _ensure_matrix (x) = x isa AbstractVector ? reshape (x, 1 , length (x)) : x
60
52
61
- function Bijectors. transform (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVector )
62
- nsl = insl. orig
63
- y1, y2, y3 = partition (nsl. mask, y)
64
- rqs = instantiate_rqs (nsl, y2)
65
- x1 = Bijectors. transform (Inverse (rqs), y1)
66
- return Bijectors. combine (nsl. mask, x1, y2, y3)
53
+ function Bijectors. transform (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
54
+ x1, x2, x3 = Bijectors. partition (nsc. mask, x)
55
+ # instantiate rqs knots and derivatives
56
+ px, py, dydx = get_nsc_params (nsc, x2)
57
+ x1 = _ensure_matrix (x1)
58
+ y1, _ = MonotonicSplines. rqs_forward (x1, px, py, dydx)
59
+ return Bijectors. combine (nsc. mask, y1, x2, x3)
67
60
end
68
61
69
- function (nsl:: NeuralSplineCoupling )(x:: AbstractVector )
70
- return Bijectors. transform (nsl, x)
62
+ function Bijectors. with_logabsdet_jacobian (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
63
+ x1, x2, x3 = Bijectors. partition (nsc. mask, x)
64
+ # instantiate rqs knots and derivatives
65
+ px, py, dydx = get_nsc_params (nsc, x2)
66
+ x1 = _ensure_matrix (x1)
67
+ y1, logjac = MonotonicSplines. rqs_forward (x1, px, py, dydx)
68
+ return Bijectors. combine (nsc. mask, y1, x2, x3), logjac isa Real ? logjac : vec (logjac)
71
69
end
72
70
73
- # define logabsdetjac
74
- function Bijectors. logabsdetjac (nsl:: NeuralSplineCoupling , x:: AbstractVector )
75
- x_1, x_2, _ = Bijectors. partition (nsl. mask, x)
76
- rqs = instantiate_rqs (nsl, x_2)
77
- logjac = logabsdetjac (rqs, x_1)
78
- return logjac
71
+ function Bijectors. transform (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVecOrMat )
72
+ nsc = insl. orig
73
+ y1, y2, y3 = partition (nsc. mask, y)
74
+ px, py, dydx = get_nsc_params (nsc, y2)
75
+ y1 = _ensure_matrix (y1)
76
+ x1, _ = MonotonicSplines. rqs_inverse (y1, px, py, dydx)
77
+ return Bijectors. combine (nsc. mask, x1, y2, y3)
79
78
end
80
79
81
- function Bijectors. logabsdetjac (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVector )
82
- nsl = insl. orig
83
- y1, y2, _ = partition (nsl. mask, y)
84
- rqs = instantiate_rqs (nsl, y2)
85
- logjac = logabsdetjac (Inverse (rqs), y1)
86
- return logjac
80
+ function Bijectors. with_logabsdet_jacobian (insl:: Inverse{<:NeuralSplineCoupling} , y:: AbstractVecOrMat )
81
+ nsc = insl. orig
82
+ y1, y2, y3 = partition (nsc. mask, y)
83
+ px, py, dydx = get_nsc_params (nsc, y2)
84
+ y1 = _ensure_matrix (y1)
85
+ x1, logjac = MonotonicSplines. rqs_inverse (y1, px, py, dydx)
86
+ return Bijectors. combine (nsc. mask, x1, y2, y3), logjac isa Real ? logjac : vec (logjac)
87
87
end
88
88
89
- function Bijectors. with_logabsdet_jacobian (nsl:: NeuralSplineCoupling , x:: AbstractVector )
90
- x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
91
- rqs = instantiate_rqs (nsl, x_2)
92
- y_1, logjac = with_logabsdet_jacobian (rqs, x_1)
93
- return Bijectors. combine (nsl. mask, y_1, x_2, x_3), logjac
89
+ function (nsc:: NeuralSplineCoupling )(x:: AbstractVecOrMat )
90
+ return Bijectors. transform (nsc, x)
94
91
end
95
92
96
93
97
94
"""
98
95
NSF_layer(dims, hdims; paramtype = Float64)
99
-
100
96
Default constructor of single layer of Neural Spline Flow (NSF)
101
97
which is a composition of 2 neural spline coupling transformations with complementary masks.
102
98
The masking strategy is odd-even masking.
103
-
104
99
# Arguments
105
100
- `dims::Int`: dimension of the problem
106
101
- `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
107
102
- `K::Int`: number of knots
108
103
- `B::AbstractFloat`: bound of the knots
109
-
110
104
# Keyword Arguments
111
105
- `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
112
-
113
106
# Returns
114
107
- A `Bijectors.Bijector` representing the NSF layer.
115
108
"""
0 commit comments