@@ -8,19 +8,19 @@ Neural Rational quadratic Spline layer
8
8
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
9
9
"""
10
10
struct NeuralSplineLayer{T,A<: Flux.Chain } <: Bijectors.Bijector
11
- dim:: Int # dimension of input
12
- K:: Int # number of knots
13
- n_dims_transferred:: Int # number of dimensions that are transformed
14
- nn:: A # networks that parmaterize the knots and derivatives
15
- B:: T # bound of the knots
11
+ dim:: Int # dimension of input
12
+ K:: Int # number of knots
13
+ n_dims_transferred:: Int # number of dimensions that are transformed
14
+ nn:: A # networks that parmaterize the knots and derivatives
15
+ B:: T # bound of the knots
16
16
mask:: Bijectors.PartitionMask
17
17
end
18
18
19
19
function NeuralSplineLayer (
20
- dim:: T1 , # dimension of input
21
- hdims:: T1 , # dimension of hidden units for s and t
22
- K:: T1 , # number of knots
23
- B:: T2 , # bound of the knots
20
+ dim:: T1 , # dimension of input
21
+ hdims:: T1 , # dimension of hidden units for s and t
22
+ K:: T1 , # number of knots
23
+ B:: T2 , # bound of the knots
24
24
mask_idx:: AbstractVector{<:Int} , # index of dimensione that one wants to apply transformations on
25
25
) where {T1<: Int ,T2<: Real }
26
26
num_of_transformed_dims = length (mask_idx)
0 commit comments