
Commit 9494de1

add ad test for nsf

1 parent 48bc3d3

File tree

1 file changed (+47 -0 lines changed)


test/ad.jl

Lines changed: 47 additions & 0 deletions
@@ -120,3 +120,50 @@ end
 end
 end
 end
+
+@testset "AD for ELBO on NSF" begin
+    @testset "$at" for at in [
+        ADTypes.AutoZygote(),
+        ADTypes.AutoForwardDiff(),
+        ADTypes.AutoReverseDiff(; compile=false),
+        ADTypes.AutoEnzyme(;
+            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+            function_annotation=Enzyme.Const,
+        ),
+        # it doesn't work with mooncake yet
+        ADTypes.AutoMooncake(; config=Mooncake.Config()),
+    ]
+        @testset "$T" for T in [Float32, Float64]
+            μ = 10 * ones(T, 2)
+            Σ = Diagonal(4 * ones(T, 2))
+            target = MvNormal(μ, Σ)
+            logp(z) = logpdf(target, z)
+
+            # necessary for Zygote/mooncake to differentiate through the flow
+            # prevent updating params of q0
+            @leaf MvNormal
+            q₀ = MvNormal(zeros(T, 2), ones(T, 2))
+            flow = realnvp(q₀, [8, 8], 3; paramtype=T)
+
+            θ, re = Optimisers.destructure(flow)
+
+            # check grad computation for elbo
+            function loss(θ, rng, logp, sample_per_iter)
+                return -NormalizingFlows.elbo_batch(rng, re(θ), logp, sample_per_iter)
+            end
+
+            rng = Random.default_rng()
+            sample_per_iter = 10
+
+            prep = NormalizingFlows._prepare_gradient(
+                loss, at, θ, rng, logp, sample_per_iter
+            )
+            value, grad = NormalizingFlows._value_and_gradient(
+                loss, prep, at, θ, rng, logp, sample_per_iter
+            )
+
+            @test value !== nothing
+            @test all(grad .!= nothing)
+        end
+    end
+end
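
Note (not part of the commit): the _prepare_gradient / _value_and_gradient calls above follow a prepare-once, evaluate-many pattern. The sketch below shows the same pattern written directly against DifferentiationInterface.jl, which these internal helpers appear to wrap; that wrapping, the toy objective f, and the Constant(2.0) standing in for the fixed arguments (rng, logp, sample_per_iter) are assumptions for illustration only, not code from this repository.

# Minimal standalone sketch, assuming NormalizingFlows' _prepare_gradient /
# _value_and_gradient mirror DifferentiationInterface's prepare + evaluate API.
using ADTypes, Zygote
import DifferentiationInterface as DI

# hypothetical toy objective standing in for the ELBO loss; the Constant
# context plays the role of the non-differentiated extra arguments
f(θ, a) = a * sum(abs2, θ)

backend = ADTypes.AutoZygote()
θ = randn(3)

# prepare once, then compute value and gradient with the prepared object
prep = DI.prepare_gradient(f, backend, θ, DI.Constant(2.0))
value, grad = DI.value_and_gradient(f, prep, backend, θ, DI.Constant(2.0))

Splitting preparation from evaluation lets backends that build tapes or caches (for example ReverseDiff) amortize that setup across repeated gradient calls, which is the code path the test exercises for each backend and element type.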
