Skip to content

Commit e4fa67b

Browse files
committed
add realnvp test
1 parent 731e657 commit e4fa67b

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

test/flow.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
@testset "RealNVP flow" begin
2+
Random.seed!(123)
3+
4+
dim = 5
5+
nlayers = 2
6+
hdims = [32, 32]
7+
for T in [Float32, Float64]
8+
# Create a RealNVP flow
9+
q₀ = MvNormal(zeros(T, dim), I)
10+
@leaf MvNormal
11+
flow = NormalizingFlows.realnvp(q₀, hdims, nlayers; paramtype=T)
12+
13+
@testset "Sampling and density estimation for type: $T" begin
14+
ys = rand(flow, 100)
15+
ℓs = logpdf(flow, ys)
16+
17+
@test size(ys) == (dim, 100)
18+
@test length(ℓs) == 100
19+
20+
@test eltype(ys) == T
21+
@test eltype(ℓs) == T
22+
end
23+
24+
25+
@testset "Inverse compatibility for type: $T" begin
26+
x = rand(q₀)
27+
y, lj_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x)
28+
x_reconstructed, lj_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y)
29+
30+
@test x x_reconstructed rtol=1e-6
31+
@test lj_fwd -lj_bwd rtol=1e-6
32+
33+
x_batch = rand(q₀, 10)
34+
y_batch, ljs_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x_batch)
35+
x_batch_reconstructed, ljs_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y_batch)
36+
37+
@test x_batch x_batch_reconstructed rtol=1e-6
38+
@test ljs_fwd -ljs_bwd rtol=1e-6
39+
end
40+
41+
42+
@testset "ELBO test for type: $T" begin
43+
μ = randn(T, dim)
44+
Σ = Diagonal(rand(T, dim) .+ T(1e-3))
45+
target = MvNormal(μ, Σ)
46+
logp(z) = logpdf(target, z)
47+
48+
# Define a simple log-likelihood function
49+
logp(z) = logpdf(q₀, z)
50+
51+
# Compute ELBO
52+
batchsize = 64
53+
elbo_value = elbo(Random.default_rng(), flow, logp, batchsize)
54+
elbo_batch_value = elbo_batch(Random.default_rng(), flow, logp, batchsize)
55+
56+
# test elbo_value is not NaN and not Inf
57+
@test !isnan(elbo_value)
58+
@test !isinf(elbo_value)
59+
@test !isnan(elbo_batch_value)
60+
@test !isinf(elbo_batch_value)
61+
end
62+
63+
#todo add tests for ad
64+
end
65+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ using Test
1414
include("ad.jl")
1515
include("objectives.jl")
1616
include("interface.jl")
17+
include("flow.jl")

0 commit comments

Comments
 (0)