1
- # #################################
2
- # define neural spline layer using Bijectors.jl interface
3
- # ################################
4
1
"""
5
2
Neural Rational quadratic Spline layer
6
3
7
4
# References
8
5
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
9
6
"""
10
- struct NeuralSplineLayer {T,A<: Flux.Chain } <: Bijectors.Bijector
7
+ struct NeuralSplineCoupling {T,A<: Flux.Chain } <: Bijectors.Bijector
11
8
dim:: Int # dimension of input
12
9
K:: Int # number of knots
13
10
n_dims_transferred:: Int # number of dimensions that are transformed
14
- nn:: A # networks that parmaterize the knots and derivatives
15
11
B:: T # bound of the knots
12
+ nn:: A # networks that parmaterize the knots and derivatives
16
13
mask:: Bijectors.PartitionMask
17
14
end
18
15
19
- function NeuralSplineLayer (
16
+ function NeuralSplineCoupling (
20
17
dim:: T1 , # dimension of input
21
- hdims:: T1 , # dimension of hidden units for s and t
18
+ hdims:: AbstractVector{T1} , # dimension of hidden units for s and t
22
19
K:: T1 , # number of knots
23
20
B:: T2 , # bound of the knots
24
21
mask_idx:: AbstractVector{<:Int} , # index of dimensione that one wants to apply transformations on
25
- ) where {T1<: Int ,T2<: Real }
22
+ paramtype:: Type{T3} , # type of the parameters, e.g., Float64 or Float32
23
+ ) where {T1<: Int ,T2<: Real ,T3<: AbstractFloat }
26
24
num_of_transformed_dims = length (mask_idx)
27
25
input_dims = dim - num_of_transformed_dims
28
26
29
27
# output dim of the NN
30
28
output_dims = (3 K - 1 )* num_of_transformed_dims
31
29
# one big mlp that outputs all the knots and derivatives for all the transformed dimensions
32
- nn = mlp3 (input_dims, hdims, output_dims)
30
+ # todo: ensure type stability
31
+ nn = fnn (input_dims, hdims, output_dims; output_activation= nothing , paramtype= paramtype)
33
32
34
33
mask = Bijectors. PartitionMask (dim, mask_idx)
35
- return NeuralSplineLayer (dim, K, num_of_transformed_dims, nn, B , mask)
34
+ return NeuralSplineCoupling (dim, K, num_of_transformed_dims, B, nn , mask)
36
35
end
37
36
38
- @functor NeuralSplineLayer (nn,)
37
+ @functor NeuralSplineCoupling (nn,)
39
38
40
39
"""
41
40
Build a rational quadratic spline (RQS) from the nn output
42
41
Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline
43
42
44
43
we just need to map the nn output to the knots and derivatives of the RQS
45
44
"""
46
- function instantiate_rqs (nsl:: NeuralSplineLayer , x:: AbstractVector )
45
+ function instantiate_rqs (nsl:: NeuralSplineCoupling , x:: AbstractVector )
47
46
K, B = nsl. K, nsl. B
48
47
nnoutput = reshape (nsl. nn (x), nsl. n_dims_transferred, :)
49
48
ws = @view nnoutput[:, 1 : K]
@@ -52,46 +51,100 @@ function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
52
51
return Bijectors. RationalQuadraticSpline (ws, hs, ds, B)
53
52
end
54
53
55
- function Bijectors. transform (nsl:: NeuralSplineLayer , x:: AbstractVector )
54
+ function Bijectors. transform (nsl:: NeuralSplineCoupling , x:: AbstractVector )
56
55
x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
57
56
# instantiate rqs knots and derivatives
58
57
rqs = instantiate_rqs (nsl, x_2)
59
58
y_1 = Bijectors. transform (rqs, x_1)
60
59
return Bijectors. combine (nsl. mask, y_1, x_2, x_3)
61
60
end
62
61
63
- function Bijectors. transform (insl:: Inverse{<:NeuralSplineLayer } , y:: AbstractVector )
62
+ function Bijectors. transform (insl:: Inverse{<:NeuralSplineCoupling } , y:: AbstractVector )
64
63
nsl = insl. orig
65
64
y1, y2, y3 = partition (nsl. mask, y)
66
65
rqs = instantiate_rqs (nsl, y2)
67
66
x1 = Bijectors. transform (Inverse (rqs), y1)
68
67
return Bijectors. combine (nsl. mask, x1, y2, y3)
69
68
end
70
69
71
- function (nsl:: NeuralSplineLayer )(x:: AbstractVector )
70
+ function (nsl:: NeuralSplineCoupling )(x:: AbstractVector )
72
71
return Bijectors. transform (nsl, x)
73
72
end
74
73
75
74
# define logabsdetjac
76
- function Bijectors. logabsdetjac (nsl:: NeuralSplineLayer , x:: AbstractVector )
75
+ function Bijectors. logabsdetjac (nsl:: NeuralSplineCoupling , x:: AbstractVector )
77
76
x_1, x_2, _ = Bijectors. partition (nsl. mask, x)
78
77
rqs = instantiate_rqs (nsl, x_2)
79
78
logjac = logabsdetjac (rqs, x_1)
80
79
return logjac
81
80
end
82
81
83
- function Bijectors. logabsdetjac (insl:: Inverse{<:NeuralSplineLayer } , y:: AbstractVector )
82
+ function Bijectors. logabsdetjac (insl:: Inverse{<:NeuralSplineCoupling } , y:: AbstractVector )
84
83
nsl = insl. orig
85
84
y1, y2, _ = partition (nsl. mask, y)
86
85
rqs = instantiate_rqs (nsl, y2)
87
86
logjac = logabsdetjac (Inverse (rqs), y1)
88
87
return logjac
89
88
end
90
89
91
- function Bijectors. with_logabsdet_jacobian (nsl:: NeuralSplineLayer , x:: AbstractVector )
90
+ function Bijectors. with_logabsdet_jacobian (nsl:: NeuralSplineCoupling , x:: AbstractVector )
92
91
x_1, x_2, x_3 = Bijectors. partition (nsl. mask, x)
93
92
rqs = instantiate_rqs (nsl, x_2)
94
93
y_1, logjac = with_logabsdet_jacobian (rqs, x_1)
95
94
return Bijectors. combine (nsl. mask, y_1, x_2, x_3), logjac
96
95
end
97
96
97
+
98
+ """
99
+ NSF_layer(dims, hdims; paramtype = Float64)
100
+
101
+ Default constructor of single layer of Neural Spline Flow (NSF)
102
+ which is a composition of 2 neural spline coupling transformations with complementary masks.
103
+ The masking strategy is odd-even masking.
104
+
105
+ # Arguments
106
+ - `dims::Int`: dimension of the problem
107
+ - `hdims::AbstractVector{Int}`: dimension of hidden units for s and t
108
+ - `K::Int`: number of knots
109
+ - `B::AbstractFloat`: bound of the knots
110
+
111
+ # Keyword Arguments
112
+ - `paramtype::Type{T} = Float64`: type of the parameters, defaults to `Float64`
113
+
114
+ # Returns
115
+ - A `Bijectors.Bijector` representing the NSF layer.
116
+ """
117
+ function NSF_layer (
118
+ dims:: T1 , # dimension of problem
119
+ hdims:: AbstractVector{T1} , # dimension of hidden units for nn
120
+ K:: T1 , # number of knots
121
+ B:: T2 ; # bound of the knots
122
+ paramtype:: Type{T2} = Float64, # type of the parameters
123
+ ) where {T1<: Int ,T2<: AbstractFloat }
124
+
125
+ mask_idx1 = 1 : 2 : dims
126
+ mask_idx2 = 2 : 2 : dims
127
+
128
+ # by default use the odd-even masking strategy
129
+ nsf1 = NeuralSplineCoupling (dims, hdims, K, B, mask_idx1, paramtype)
130
+ nsf2 = NeuralSplineCoupling (dims, hdims, K, B, mask_idx2, paramtype)
131
+ return reduce (∘ , (nsf1, nsf2))
132
+ end
133
+
134
+ function nsf (
135
+ q0:: Distribution{Multivariate,Continuous} ,
136
+ hdims:: AbstractVector{Int} , # dimension of hidden units for s and t
137
+ K:: Int ,
138
+ B:: T ,
139
+ nlayers:: Int ; # number of RealNVP_layer
140
+ paramtype:: Type{T} = Float64, # type of the parameters
141
+ ) where {T<: AbstractFloat }
142
+
143
+ dims = length (q0) # dimension of the reference distribution == dim of the problem
144
+ Ls = [NSF_layer (dims, hdims, K, B; paramtype= paramtype) for _ in 1 : nlayers]
145
+ create_flow (Ls, q0)
146
+ end
147
+
148
+ nsf (q0; paramtype:: Type{T} = Float64) where {T<: AbstractFloat } = nsf (
149
+ q0, [32 , 32 ], 10 , 30 * one (T), 10 ; paramtype= paramtype
150
+ )
0 commit comments