11"""
22 NeuralSplineCoupling(dim, hdims, K, B, mask_idx, paramtype)
3- NeuralSplineCoupling(dim, K, n_dims_transferred , B, nn, mask)
3+ NeuralSplineCoupling(dim, K, n_dims_transformed , B, nn, mask)
44
55Neural Rational Quadratic Spline (RQS) coupling bijector [^DBMP2019].
66
@@ -19,7 +19,7 @@ Keyword Arguments
1919- `paramtype::Type{<:AbstractFloat}`: parameter element type.
2020
2121Fields
22- - `nn::Flux.Chain`: conditioner that outputs all spline params for all transformed dims .
22+ - `nn::Flux.Chain`: conditioner that outputs all spline params for all transformed dim .
2323- `mask::Bijectors.PartitionMask`: partition specification.
2424
2525Notes
@@ -35,9 +35,9 @@ and log-determinant computations.
3535struct NeuralSplineCoupling{T,A<: Flux.Chain } <: Bijectors.Bijector
3636 dim:: Int # dimension of input
3737 K:: Int # number of knots
38- n_dims_transferred :: Int # number of dimensions that are transformed
38+ n_dims_transformed :: Int # number of dimensions that are transformed
3939 B:: T # bound of the knots
40- nn:: A # networks that parmaterize the knots and derivatives
40+ nn:: A # networks that parameterize the knots and derivatives
4141 mask:: Bijectors.PartitionMask
4242end
4343
@@ -46,13 +46,12 @@ function NeuralSplineCoupling(
4646 hdims:: AbstractVector{T1} , # dimension of hidden units for s and t
4747 K:: T1 , # number of knots
4848 B:: T2 , # bound of the knots
49- mask_idx:: AbstractVector{T1} , # index of dimensione that one wants to apply transformations on
49+ mask_idx:: AbstractVector{T1} , # indices of the transformed dimensions
5050 paramtype:: Type{T2} , # type of the parameters, e.g., Float64 or Float32
5151) where {T1<: Int ,T2<: AbstractFloat }
5252 num_of_transformed_dims = length (mask_idx)
5353 input_dims = dim - num_of_transformed_dims
5454
55- # output dim of the NN
5655 output_dims = (3 K - 1 )* num_of_transformed_dims
5756 # one big mlp that outputs all the knots and derivatives for all the transformed dimensions
5857 nn = fnn (input_dims, hdims, output_dims; output_activation= nothing , paramtype= paramtype)
6665function get_nsc_params (nsc:: NeuralSplineCoupling , x:: AbstractVecOrMat )
6766 nnoutput = nsc. nn (x)
6867 px, py, dydx = MonotonicSplines. rqs_params_from_nn (
69- nnoutput, nsc. n_dims_transferred , nsc. B
68+ nnoutput, nsc. n_dims_transformed , nsc. B
7069 )
7170 return px, py, dydx
7271end
@@ -146,13 +145,13 @@ end
146145
147146
148147"""
149- NSF_layer(dims , hdims, K, B; paramtype = Float64)
148+ NSF_layer(dim , hdims, K, B; paramtype = Float64)
150149
151150Build a single Neural Spline Flow (NSF) layer by composing two
152151`NeuralSplineCoupling` bijectors with complementary odd–even masks.
153152
154153Arguments
155- - `dims ::Int`: dimensionality of the problem.
154+ - `dim ::Int`: dimensionality of the problem.
156155- `hdims::AbstractVector{Int}`: hidden sizes of the conditioner network.
157156- `K::Int`: number of spline knots.
158157- `B::AbstractFloat`: spline boundary.
@@ -168,19 +167,19 @@ Example
168167- `y = layer(randn(4, 32))`
169168"""
170169function NSF_layer (
171- dims :: T1 , # dimension of problem
170+ dim :: T1 , # dimension of problem
172171 hdims:: AbstractVector{T1} , # dimension of hidden units for nn
173172 K:: T1 , # number of knots
174173 B:: T2 ; # bound of the knots
175174 paramtype:: Type{T2} = Float64, # type of the parameters
176175) where {T1<: Int ,T2<: AbstractFloat }
177176
178- mask_idx1 = 1 : 2 : dims
179- mask_idx2 = 2 : 2 : dims
177+ mask_idx1 = 1 : 2 : dim
178+ mask_idx2 = 2 : 2 : dim
180179
181180 # by default use the odd-even masking strategy
182- nsf1 = NeuralSplineCoupling (dims , hdims, K, B, mask_idx1, paramtype)
183- nsf2 = NeuralSplineCoupling (dims , hdims, K, B, mask_idx2, paramtype)
181+ nsf1 = NeuralSplineCoupling (dim , hdims, K, B, mask_idx1, paramtype)
182+ nsf2 = NeuralSplineCoupling (dim , hdims, K, B, mask_idx2, paramtype)
184183 return reduce (∘ , (nsf1, nsf2))
185184end
186185
@@ -205,11 +204,11 @@ Keyword Arguments
205204Returns
206205- `Bijectors.TransformedDistribution` representing the NSF flow.
207206
208- Notes:
209- - Under the hood, `nsf` relies on the rational quadratic spline function implememented in
210- `MonotonicSplines.jl` for performance reasons. `MonotonicSplines.jl` uses
211- `KernelAbstractions.jl` to support batched operations.
212- Because of this, so far `nsf` only supports `Zygote` as the AD type.
207+ !!! note
208+ Under the hood, `nsf` relies on the rational quadratic spline function implememented in
209+ `MonotonicSplines.jl` for performance reasons. `MonotonicSplines.jl` uses
210+ `KernelAbstractions.jl` to support batched operations.
211+ Because of this, so far `nsf` only supports `Zygote` as the AD type.
213212
214213
215214Example
@@ -225,8 +224,8 @@ function nsf(
225224 paramtype:: Type{T} = Float64, # type of the parameters
226225) where {T<: AbstractFloat }
227226
228- dims = length (q0) # dimension of the reference distribution == dim of the problem
229- Ls = [NSF_layer (dims , hdims, K, B; paramtype= paramtype) for _ in 1 : nlayers]
227+ dim = length (q0) # dimension of the reference distribution == dim of the problem
228+ Ls = [NSF_layer (dim , hdims, K, B; paramtype= paramtype) for _ in 1 : nlayers]
230229 create_flow (Ls, q0)
231230end
232231
0 commit comments