Skip to content

Commit fb6b80b

Browse files
committed
add planar and radial flow; updating docs
1 parent 4dec51a commit fb6b80b

File tree

7 files changed

+209
-20
lines changed

7 files changed

+209
-20
lines changed

docs/src/api.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ NormalizingFlows.optimize
5757

5858
## Available Flows
5959

60-
`NormalizingFlows.jl` provides two commonly used normalizing flows: `RealNVP` and
61-
`Neural Spline Flow (NSF)`.
60+
`NormalizingFlows.jl` provides two commonly used normalizing flows---`RealNVP` and
61+
`Neural Spline Flow (NSF)`---and two simple flows---`Planar Flow` and `Radial Flow`.
6262

6363
### RealNVP (Affine Coupling Flow)
6464

@@ -78,7 +78,14 @@ NormalizingFlows.NSF_layer
7878
NormalizingFlows.NeuralSplineCoupling
7979
```
8080

81-
## Utility Functions
81+
#### Planar and Radial Flows
82+
83+
```@docs
84+
NormalizingFlows.planarflow
85+
NormalizingFlows.radialflow
86+
```
87+
88+
## Utility Functions
8289

8390
```@docs
8491
NormalizingFlows.create_flow

example/demo_RealNVP.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
4848
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
4949
flow_trained, stats, _ = train_flow(
5050
rng,
51-
elbo, # using elbo_batch instead of elbo achieves 4-5 times speedup
51+
elbo_batch, # using elbo_batch instead of elbo achieves 4-5 times speedup
5252
flow,
5353
logp,
5454
sample_per_iter;

example/demo_planar_flow.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,9 @@ logp = Base.Fix1(logpdf, target)
2020
######################################
2121
# setup planar flow
2222
######################################
23-
function create_planar_flow(n_layers::Int, q₀)
24-
d = length(q₀)
25-
Ls = [PlanarLayer(d) for _ in 1:n_layers]
26-
ts = reduce(, Ls)
27-
return transformed(q₀, ts)
28-
end
29-
3023
@leaf MvNormal
3124
q0 = MvNormal(zeros(T, 2), ones(T, 2))
32-
flow = create_planar_flow(10, q0)
25+
flow = planarflow(q0, 10; paramtype=T)
3326
flow_untrained = deepcopy(flow)
3427

3528
######################################

example/demo_radial_flow.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,10 @@ logp = Base.Fix1(logpdf, target)
1919
######################################
2020
# setup radial flow
2121
######################################
22-
function create_radial_flow(n_layers::Int, q₀)
23-
d = length(q₀)
24-
Ls = [RadialLayer(d) for _ in 1:n_layers]
25-
ts = reduce(, Ls)
26-
return transformed(q₀, ts)
27-
end
28-
2922
# create a 10-layer radial flow
3023
@leaf MvNormal
3124
q0 = MvNormal(zeros(T, 2), ones(T, 2))
32-
flow = create_radial_flow(10, q0)
25+
flow = radialflow(q0, 10; paramtype=T)
3326

3427
flow_untrained = deepcopy(flow)
3528

src/NormalizingFlows.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,14 @@ end
129129

130130
# interface of contructing common flow layers
131131
include("flows/utils.jl")
132+
include("flows/planar_radial.jl")
132133
include("flows/realnvp.jl")
133134

134135
using MonotonicSplines
135136
include("flows/neuralspline.jl")
136137

137138
export create_flow
139+
export planarflow, radialflow
138140
export AffineCoupling, RealNVP_layer, realnvp
139141
export NeuralSplineCoupling, NSF_layer, nsf
140142

src/flows/planar_radial.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
planarflow(q0, nlayers; paramtype = Float64)
3+
4+
Construct a Planar Flow by stacking `nlayers` `Bijectors.PlanarLayer` blocks
5+
on top of a base distribution `q0`.
6+
7+
Arguments
8+
- `q0::Distribution{Multivariate,Continuous}`: base distribution (e.g., `MvNormal(zeros(d), I)`).
9+
- `nlayers::Int`: number of planar layers to compose.
10+
11+
Keyword Arguments
12+
- `paramtype::Type{T} = Float64`: parameter element type (use `Float32` for GPU friendliness).
13+
14+
Returns
15+
- `Bijectors.TransformedDistribution` representing the planar flow.
16+
17+
Example
18+
- `q0 = MvNormal(zeros(2), I); flow = planarflow(q0, 10)`
19+
- `x = rand(flow, 128); lp = logpdf(flow, x)`
20+
"""
21+
function planarflow(
22+
q0::Distribution{Multivariate,Continuous},
23+
nlayers::Int;
24+
paramtype::Type{T} = Float64,
25+
) where {T<:AbstractFloat}
26+
dim = length(q0)
27+
Ls = [Flux._paramtype(paramtype, Bijectors.PlanarLayer(dim)) for _ in 1:nlayers]
28+
return create_flow(Ls, q0)
29+
end
30+
31+
32+
"""
33+
radialflow(q0, nlayers; paramtype = Float64)
34+
35+
Construct a Radial Flow by stacking `nlayers` `Bijectors.RadialLayer` blocks
36+
on top of a base distribution `q0`.
37+
38+
Arguments
39+
- `q0::Distribution{Multivariate,Continuous}`: base distribution (e.g., `MvNormal(zeros(d), I)`).
40+
- `nlayers::Int`: number of radial layers to compose.
41+
42+
Keyword Arguments
43+
- `paramtype::Type{T} = Float64`: parameter element type (use `Float32` for GPU friendliness).
44+
45+
Returns
46+
- `Bijectors.TransformedDistribution` representing the radial flow.
47+
48+
Example
49+
- `q0 = MvNormal(zeros(2), I); flow = radialflow(q0, 6)`
50+
- `x = rand(flow); lp = logpdf(flow, x)`
51+
"""
52+
function radialflow(
53+
q0::Distribution{Multivariate,Continuous},
54+
nlayers::Int;
55+
paramtype::Type{T} = Float64,
56+
) where {T<:AbstractFloat}
57+
dim = length(q0)
58+
Ls = [Flux._paramtype(paramtype, Bijectors.RadialLayer(dim)) for _ in 1:nlayers]
59+
return create_flow(Ls, q0)
60+
end

test/flow.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,137 @@ end
128128
end
129129
end
130130
end
131+
132+
133+
134+
@testset "Planar flow" begin
135+
Random.seed!(123)
136+
137+
dim = 5
138+
nlayers = 10
139+
for T in [Float32, Float64]
140+
# Create a nsf
141+
q₀ = MvNormal(zeros(T, dim), I)
142+
@leaf MvNormal
143+
144+
flow = NormalizingFlows.planarflow(q₀, nlayers; paramtype=T)
145+
146+
@testset "Sampling and density estimation for type: $T" begin
147+
ys = rand(flow, 100)
148+
ℓs = logpdf(flow, ys)
149+
150+
@test size(ys) == (dim, 100)
151+
@test length(ℓs) == 100
152+
153+
@test eltype(ys) == T
154+
@test eltype(ℓs) == T
155+
end
156+
157+
158+
@testset "Inverse compatibility for type: $T" begin
159+
x = rand(q₀)
160+
y, lj_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x)
161+
x_reconstructed, lj_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y)
162+
163+
@test x x_reconstructed rtol=1e-4
164+
@test lj_fwd -lj_bwd rtol=1e-4
165+
166+
x_batch = rand(q₀, 10)
167+
y_batch, ljs_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x_batch)
168+
x_batch_reconstructed, ljs_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y_batch)
169+
170+
@test x_batch x_batch_reconstructed rtol=1e-4
171+
@test ljs_fwd -ljs_bwd rtol=1e-4
172+
end
173+
174+
175+
@testset "ELBO test for type: $T" begin
176+
μ = randn(T, dim)
177+
Σ = Diagonal(rand(T, dim) .+ T(1e-3))
178+
target = MvNormal(μ, Σ)
179+
logp(z) = logpdf(target, z)
180+
181+
# Compute ELBO
182+
batchsize = 64
183+
elbo_value = elbo(Random.default_rng(), flow, logp, batchsize)
184+
elbo_batch_value = elbo_batch(Random.default_rng(), flow, logp, batchsize)
185+
186+
# test when batchsize == 1
187+
batchsize_single = 1
188+
elbo_value_single = elbo(Random.default_rng(), flow, logp, batchsize_single)
189+
190+
# test elbo_value is not NaN and not Inf
191+
@test isfinite(elbo_value)
192+
@test isfinite(elbo_batch_value)
193+
@test isfinite(elbo_value_single)
194+
end
195+
end
196+
end
197+
198+
199+
200+
@testset "Radial flow" begin
201+
Random.seed!(123)
202+
203+
dim = 5
204+
nlayers = 10
205+
for T in [Float32, Float64]
206+
# Create a nsf
207+
q₀ = MvNormal(zeros(T, dim), I)
208+
@leaf MvNormal
209+
210+
flow = NormalizingFlows.radialflow(q₀, nlayers; paramtype=T)
211+
212+
@testset "Sampling and density estimation for type: $T" begin
213+
ys = rand(flow, 100)
214+
ℓs = logpdf(flow, ys)
215+
216+
@test size(ys) == (dim, 100)
217+
@test length(ℓs) == 100
218+
219+
@test eltype(ys) == T
220+
@test eltype(ℓs) == T
221+
end
222+
223+
224+
@testset "Inverse compatibility for type: $T" begin
225+
x = rand(q₀)
226+
y, lj_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x)
227+
x_reconstructed, lj_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y)
228+
229+
@test x x_reconstructed rtol=1e-4
230+
@test lj_fwd -lj_bwd rtol=1e-4
231+
232+
x_batch = rand(q₀, 10)
233+
y_batch, ljs_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x_batch)
234+
x_batch_reconstructed, ljs_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y_batch)
235+
236+
@test x_batch x_batch_reconstructed rtol=1e-4
237+
@test ljs_fwd -ljs_bwd rtol=1e-4
238+
end
239+
240+
241+
@testset "ELBO test for type: $T" begin
242+
μ = randn(T, dim)
243+
Σ = Diagonal(rand(T, dim) .+ T(1e-3))
244+
target = MvNormal(μ, Σ)
245+
logp(z) = logpdf(target, z)
246+
247+
# Compute ELBO
248+
batchsize = 64
249+
elbo_value = elbo(Random.default_rng(), flow, logp, batchsize)
250+
elbo_batch_value = elbo_batch(Random.default_rng(), flow, logp, batchsize)
251+
252+
# test when batchsize == 1
253+
batchsize_single = 1
254+
elbo_value_single = elbo(Random.default_rng(), flow, logp, batchsize_single)
255+
256+
# test elbo_value is not NaN and not Inf
257+
@test isfinite(elbo_value)
258+
@test isfinite(elbo_batch_value)
259+
@test isfinite(elbo_value_single)
260+
end
261+
end
262+
end
263+
264+

0 commit comments

Comments
 (0)