Skip to content

Commit 48bc3d3

Browse files
committed
add test for nsf
1 parent 2903b83 commit 48bc3d3

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

test/flow.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,72 @@
3939
end
4040

4141

42+
@testset "ELBO test for type: $T" begin
43+
μ = randn(T, dim)
44+
Σ = Diagonal(rand(T, dim) .+ T(1e-3))
45+
target = MvNormal(μ, Σ)
46+
logp(z) = logpdf(target, z)
47+
48+
# Define a simple log-likelihood function
49+
logp(z) = logpdf(q₀, z)
50+
51+
# Compute ELBO
52+
batchsize = 64
53+
elbo_value = elbo(Random.default_rng(), flow, logp, batchsize)
54+
elbo_batch_value = elbo_batch(Random.default_rng(), flow, logp, batchsize)
55+
56+
# test elbo_value is not NaN and not Inf
57+
@test !isnan(elbo_value)
58+
@test !isinf(elbo_value)
59+
@test !isnan(elbo_batch_value)
60+
@test !isinf(elbo_batch_value)
61+
end
62+
63+
#todo add tests for ad
64+
end
65+
end
66+
67+
@testset "Neural Spline flow" begin
68+
Random.seed!(123)
69+
70+
dim = 5
71+
nlayers = 2
72+
hdims = [32, 32]
73+
for T in [Float32, Float64]
74+
# Create a RealNVP flow
75+
q₀ = MvNormal(zeros(T, dim), I)
76+
@leaf MvNormal
77+
flow = NormalizingFlows.nsf(q₀; paramtype=T)
78+
79+
@testset "Sampling and density estimation for type: $T" begin
80+
ys = rand(flow, 100)
81+
ℓs = logpdf(flow, ys)
82+
83+
@test size(ys) == (dim, 100)
84+
@test length(ℓs) == 100
85+
86+
@test eltype(ys) == T
87+
@test eltype(ℓs) == T
88+
end
89+
90+
91+
@testset "Inverse compatibility for type: $T" begin
92+
x = rand(q₀)
93+
y, lj_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x)
94+
x_reconstructed, lj_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y)
95+
96+
@test x x_reconstructed rtol=1e-6
97+
@test lj_fwd -lj_bwd rtol=1e-6
98+
99+
x_batch = rand(q₀, 10)
100+
y_batch, ljs_fwd = Bijectors.with_logabsdet_jacobian(flow.transform, x_batch)
101+
x_batch_reconstructed, ljs_bwd = Bijectors.with_logabsdet_jacobian(inverse(flow.transform), y_batch)
102+
103+
@test x_batch x_batch_reconstructed rtol=1e-6
104+
@test ljs_fwd -ljs_bwd rtol=1e-6
105+
end
106+
107+
42108
@testset "ELBO test for type: $T" begin
43109
μ = randn(T, dim)
44110
Σ = Diagonal(rand(T, dim) .+ T(1e-3))

0 commit comments

Comments
 (0)