Skip to content

Commit a2f6fbe

Browse files
committed
add AD tests for realnvp elbo
1 parent 55fb607 commit a2f6fbe

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

test/ad.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,50 @@ end
7373
end
7474
end
7575
end
76+
77+
78+
@testset "AD for ELBO on realnvp" begin
79+
@testset "$at" for at in [
80+
ADTypes.AutoZygote(),
81+
ADTypes.AutoForwardDiff(),
82+
ADTypes.AutoReverseDiff(; compile=false),
83+
ADTypes.AutoEnzyme(;
84+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
85+
function_annotation=Enzyme.Const,
86+
),
87+
ADTypes.AutoMooncake(; config=Mooncake.Config()),
88+
]
89+
@testset "$T" for T in [Float32, Float64]
90+
μ = 10 * ones(T, 2)
91+
Σ = Diagonal(4 * ones(T, 2))
92+
target = MvNormal(μ, Σ)
93+
logp(z) = logpdf(target, z)
94+
95+
# necessary for Zygote/mooncake to differentiate through the flow
96+
# prevent updating params of q0
97+
@leaf MvNormal
98+
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
99+
flow = realnvp(q₀, [8, 8], 3; paramtype=T)
100+
101+
θ, re = Optimisers.destructure(flow)
102+
103+
# check grad computation for elbo
104+
function loss(θ, rng, logp, sample_per_iter)
105+
return -NormalizingFlows.elbo_batch(rng, re(θ), logp, sample_per_iter)
106+
end
107+
108+
rng = Random.default_rng()
109+
sample_per_iter = 10
110+
111+
prep = NormalizingFlows._prepare_gradient(
112+
loss, at, θ, rng, logp, sample_per_iter
113+
)
114+
value, grad = NormalizingFlows._value_and_gradient(
115+
loss, prep, at, θ, rng, logp, sample_per_iter
116+
)
117+
118+
@test value !== nothing
119+
@test all(grad .!= nothing)
120+
end
121+
end
122+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import DifferentiationInterface as DI
1111

1212
using Test
1313

14-
include("ad.jl")
1514
include("objectives.jl")
1615
include("interface.jl")
1716
include("flow.jl")
17+
include("ad.jl")

0 commit comments

Comments
 (0)