Skip to content

Commit 2431e57

Browse files
committed
incorpoerate comments from red-portal and sunxd3
1 parent 8a04147 commit 2431e57

File tree

5 files changed

+40
-72
lines changed

5 files changed

+40
-72
lines changed

docs/src/example.md

Whitespace-only changes.

src/flows/neuralspline.jl

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
NeuralSplineCoupling(dim, hdims, K, B, mask_idx, paramtype)
3-
NeuralSplineCoupling(dim, K, n_dims_transferred, B, nn, mask)
3+
NeuralSplineCoupling(dim, K, n_dims_transformed, B, nn, mask)
44
55
Neural Rational Quadratic Spline (RQS) coupling bijector [^DBMP2019].
66
@@ -19,7 +19,7 @@ Keyword Arguments
1919
- `paramtype::Type{<:AbstractFloat}`: parameter element type.
2020
2121
Fields
22-
- `nn::Flux.Chain`: conditioner that outputs all spline params for all transformed dims.
22+
- `nn::Flux.Chain`: conditioner that outputs all spline params for all transformed dim.
2323
- `mask::Bijectors.PartitionMask`: partition specification.
2424
2525
Notes
@@ -35,9 +35,9 @@ and log-determinant computations.
3535
struct NeuralSplineCoupling{T,A<:Flux.Chain} <: Bijectors.Bijector
3636
dim::Int # dimension of input
3737
K::Int # number of knots
38-
n_dims_transferred::Int # number of dimensions that are transformed
38+
n_dims_transformed::Int # number of dimensions that are transformed
3939
B::T # bound of the knots
40-
nn::A # networks that parmaterize the knots and derivatives
40+
nn::A # networks that parameterize the knots and derivatives
4141
mask::Bijectors.PartitionMask
4242
end
4343

@@ -46,13 +46,12 @@ function NeuralSplineCoupling(
4646
hdims::AbstractVector{T1}, # dimension of hidden units for s and t
4747
K::T1, # number of knots
4848
B::T2, # bound of the knots
49-
mask_idx::AbstractVector{T1}, # index of dimensione that one wants to apply transformations on
49+
mask_idx::AbstractVector{T1}, # indices of the transformed dimensions
5050
paramtype::Type{T2}, # type of the parameters, e.g., Float64 or Float32
5151
) where {T1<:Int,T2<:AbstractFloat}
5252
num_of_transformed_dims = length(mask_idx)
5353
input_dims = dim - num_of_transformed_dims
5454

55-
# output dim of the NN
5655
output_dims = (3K - 1)*num_of_transformed_dims
5756
# one big mlp that outputs all the knots and derivatives for all the transformed dimensions
5857
nn = fnn(input_dims, hdims, output_dims; output_activation=nothing, paramtype=paramtype)
@@ -66,7 +65,7 @@ end
6665
function get_nsc_params(nsc::NeuralSplineCoupling, x::AbstractVecOrMat)
6766
nnoutput = nsc.nn(x)
6867
px, py, dydx = MonotonicSplines.rqs_params_from_nn(
69-
nnoutput, nsc.n_dims_transferred, nsc.B
68+
nnoutput, nsc.n_dims_transformed, nsc.B
7069
)
7170
return px, py, dydx
7271
end
@@ -146,13 +145,13 @@ end
146145

147146

148147
"""
149-
NSF_layer(dims, hdims, K, B; paramtype = Float64)
148+
NSF_layer(dim, hdims, K, B; paramtype = Float64)
150149
151150
Build a single Neural Spline Flow (NSF) layer by composing two
152151
`NeuralSplineCoupling` bijectors with complementary odd–even masks.
153152
154153
Arguments
155-
- `dims::Int`: dimensionality of the problem.
154+
- `dim::Int`: dimensionality of the problem.
156155
- `hdims::AbstractVector{Int}`: hidden sizes of the conditioner network.
157156
- `K::Int`: number of spline knots.
158157
- `B::AbstractFloat`: spline boundary.
@@ -168,19 +167,19 @@ Example
168167
- `y = layer(randn(4, 32))`
169168
"""
170169
function NSF_layer(
171-
dims::T1, # dimension of problem
170+
dim::T1, # dimension of problem
172171
hdims::AbstractVector{T1}, # dimension of hidden units for nn
173172
K::T1, # number of knots
174173
B::T2; # bound of the knots
175174
paramtype::Type{T2} = Float64, # type of the parameters
176175
) where {T1<:Int,T2<:AbstractFloat}
177176

178-
mask_idx1 = 1:2:dims
179-
mask_idx2 = 2:2:dims
177+
mask_idx1 = 1:2:dim
178+
mask_idx2 = 2:2:dim
180179

181180
# by default use the odd-even masking strategy
182-
nsf1 = NeuralSplineCoupling(dims, hdims, K, B, mask_idx1, paramtype)
183-
nsf2 = NeuralSplineCoupling(dims, hdims, K, B, mask_idx2, paramtype)
181+
nsf1 = NeuralSplineCoupling(dim, hdims, K, B, mask_idx1, paramtype)
182+
nsf2 = NeuralSplineCoupling(dim, hdims, K, B, mask_idx2, paramtype)
184183
return reduce(, (nsf1, nsf2))
185184
end
186185

@@ -205,11 +204,11 @@ Keyword Arguments
205204
Returns
206205
- `Bijectors.TransformedDistribution` representing the NSF flow.
207206
208-
Notes:
209-
- Under the hood, `nsf` relies on the rational quadratic spline function implememented in
210-
`MonotonicSplines.jl` for performance reasons. `MonotonicSplines.jl` uses
211-
`KernelAbstractions.jl` to support batched operations.
212-
Because of this, so far `nsf` only supports `Zygote` as the AD type.
207+
!!! note
208+
Under the hood, `nsf` relies on the rational quadratic spline function implememented in
209+
`MonotonicSplines.jl` for performance reasons. `MonotonicSplines.jl` uses
210+
`KernelAbstractions.jl` to support batched operations.
211+
Because of this, so far `nsf` only supports `Zygote` as the AD type.
213212
214213
215214
Example
@@ -225,8 +224,8 @@ function nsf(
225224
paramtype::Type{T} = Float64, # type of the parameters
226225
) where {T<:AbstractFloat}
227226

228-
dims = length(q0) # dimension of the reference distribution == dim of the problem
229-
Ls = [NSF_layer(dims, hdims, K, B; paramtype=paramtype) for _ in 1:nlayers]
227+
dim = length(q0) # dimension of the reference distribution == dim of the problem
228+
Ls = [NSF_layer(dim, hdims, K, B; paramtype=paramtype) for _ in 1:nlayers]
230229
create_flow(Ls, q0)
231230
end
232231

src/flows/realnvp.jl

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,12 @@ struct AffineCoupling <: Bijectors.Bijector
3737
t::Flux.Chain
3838
end
3939

40-
# let params track field s and t
4140
@functor AffineCoupling (s, t)
4241

4342
function AffineCoupling(
4443
dim::Int, # dimension of the problem
4544
hdims::AbstractVector{Int}, # dimension of hidden units for s and t
46-
mask_idx::AbstractVector{Int}, # index of dimensione that one wants to apply transformations on
45+
mask_idx::AbstractVector{Int}, # indices of the transformed dimensions
4746
paramtype::Type{T}
4847
) where {T<:AbstractFloat}
4948
cdims = length(mask_idx) # dimension of parts used to construct coupling law
@@ -110,37 +109,6 @@ function Bijectors.with_logabsdet_jacobian(
110109
return combine(af.mask, x_1, y_2, y_3), vec(logjac)
111110
end
112111

113-
###################
114-
# an equivalent definition of AffineCoupling using Bijectors.Coupling
115-
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
116-
###################
117-
118-
# struct AffineCoupling <: Bijectors.Bijector
119-
# dim::Int
120-
# mask::Bijectors.PartitionMask
121-
# s::Flux.Chain
122-
# t::Flux.Chain
123-
# end
124-
125-
# # let params track field s and t
126-
# @functor AffineCoupling (s, t)
127-
128-
# function AffineCoupling(dim, mask, s, t)
129-
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
130-
# end
131-
132-
# function AffineCoupling(
133-
# dim::Int, # dimension of input
134-
# hdims::Int, # dimension of hidden units for s and t
135-
# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
136-
# )
137-
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
138-
# s = mlp3(cdims, hdims, cdims)
139-
# t = mlp3(cdims, hdims, cdims)
140-
# mask = PartitionMask(dim, mask_idx)
141-
# return AffineCoupling(dim, mask, s, t)
142-
# end
143-
144112
"""
145113
RealNVP_layer(dims, hdims; paramtype = Float64)
146114

src/flows/utils.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ Construct a normalizing flow by composing the provided bijector layers and
88
attaching them to the base distribution `q0`.
99
1010
- `layers`: an iterable of `Bijectors.Bijector` objects that are composed in order
11-
(left-to-right) via function composition.
11+
(left-to-right) via function composition
12+
(for instance, if `layers = [l1, l2, l3]`, the flow will be `l3∘l2∘l1(q0)`).
1213
- `q0`: the base distribution (e.g., `MvNormal(zeros(d), I)`).
1314
1415
Returns a `Bijectors.TransformedDistribution` representing the resulting flow.
1516
1617
Example
1718
18-
using Distributions
19+
using Distributions, Bijectors, LinearAlgebra
1920
q0 = MvNormal(zeros(2), I)
2021
flow = create_flow((Bijectors.Scale([1.0, 2.0]), Bijectors.Shift([0.0, 1.0])), q0)
2122
"""
@@ -77,7 +78,7 @@ function fnn(
7778
) where {T<:AbstractFloat}
7879
# Create a chain of dense layers
7980
# First layer
80-
layers = Any[Flux.Dense(input_dim, hidden_dims[1], inlayer_activation)]
81+
layers = [Flux.Dense(input_dim, hidden_dims[1], inlayer_activation)]
8182

8283
# Hidden layers
8384
for i in 1:(length(hidden_dims) - 1)
@@ -96,4 +97,4 @@ function fnn(
9697

9798
m = Chain(layers...)
9899
return Flux._paramtype(paramtype, m)
99-
end
100+
end

test/flow.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,19 @@
4545
target = MvNormal(μ, Σ)
4646
logp(z) = logpdf(target, z)
4747

48-
# Define a simple log-likelihood function
49-
logp(z) = logpdf(q₀, z)
50-
5148
# Compute ELBO
5249
batchsize = 64
5350
elbo_value = elbo(Random.default_rng(), flow, logp, batchsize)
5451
elbo_batch_value = elbo_batch(Random.default_rng(), flow, logp, batchsize)
5552

53+
# test when batchsize == 1
54+
batchsize_single = 1
55+
elbo_value_single = elbo(Random.default_rng(), flow, logp, batchsize_single)
56+
5657
# test elbo_value is not NaN and not Inf
57-
@test !isnan(elbo_value)
58-
@test !isinf(elbo_value)
59-
@test !isnan(elbo_batch_value)
60-
@test !isinf(elbo_batch_value)
58+
@test isfinite(elbo_value)
59+
@test isfinite(elbo_batch_value)
60+
@test isfinite(elbo_value_single)
6161
end
6262
end
6363
end
@@ -112,19 +112,19 @@ end
112112
target = MvNormal(μ, Σ)
113113
logp(z) = logpdf(target, z)
114114

115-
# Define a simple log-likelihood function
116-
logp(z) = logpdf(q₀, z)
117-
118115
# Compute ELBO
119116
batchsize = 64
120117
elbo_value = elbo(Random.default_rng(), flow, logp, batchsize)
121118
elbo_batch_value = elbo_batch(Random.default_rng(), flow, logp, batchsize)
122119

120+
# test when batchsize == 1
121+
batchsize_single = 1
122+
elbo_value_single = elbo(Random.default_rng(), flow, logp, batchsize_single)
123+
123124
# test elbo_value is not NaN and not Inf
124-
@test !isnan(elbo_value)
125-
@test !isinf(elbo_value)
126-
@test !isnan(elbo_batch_value)
127-
@test !isinf(elbo_batch_value)
125+
@test isfinite(elbo_value)
126+
@test isfinite(elbo_batch_value)
127+
@test isfinite(elbo_value_single)
128128
end
129129
end
130130
end

0 commit comments

Comments
 (0)