|
120 | 120 | end
|
121 | 121 | end
|
122 | 122 | end
|
| 123 | + |
| 124 | +@testset "AD for ELBO on NSF" begin |
| 125 | + @testset "$at" for at in [ |
| 126 | + ADTypes.AutoZygote(), |
| 127 | + ADTypes.AutoForwardDiff(), |
| 128 | + ADTypes.AutoReverseDiff(; compile=false), |
| 129 | + ADTypes.AutoEnzyme(; |
| 130 | + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), |
| 131 | + function_annotation=Enzyme.Const, |
| 132 | + ), |
| 133 | + # it doesn't work with mooncake yet |
| 134 | + ADTypes.AutoMooncake(; config=Mooncake.Config()), |
| 135 | + ] |
| 136 | + @testset "$T" for T in [Float32, Float64] |
| 137 | + μ = 10 * ones(T, 2) |
| 138 | + Σ = Diagonal(4 * ones(T, 2)) |
| 139 | + target = MvNormal(μ, Σ) |
| 140 | + logp(z) = logpdf(target, z) |
| 141 | + |
| 142 | + # necessary for Zygote/mooncake to differentiate through the flow |
| 143 | + # prevent updating params of q0 |
| 144 | + @leaf MvNormal |
| 145 | + q₀ = MvNormal(zeros(T, 2), ones(T, 2)) |
| 146 | + flow = realnvp(q₀, [8, 8], 3; paramtype=T) |
| 147 | + |
| 148 | + θ, re = Optimisers.destructure(flow) |
| 149 | + |
| 150 | + # check grad computation for elbo |
| 151 | + function loss(θ, rng, logp, sample_per_iter) |
| 152 | + return -NormalizingFlows.elbo_batch(rng, re(θ), logp, sample_per_iter) |
| 153 | + end |
| 154 | + |
| 155 | + rng = Random.default_rng() |
| 156 | + sample_per_iter = 10 |
| 157 | + |
| 158 | + prep = NormalizingFlows._prepare_gradient( |
| 159 | + loss, at, θ, rng, logp, sample_per_iter |
| 160 | + ) |
| 161 | + value, grad = NormalizingFlows._value_and_gradient( |
| 162 | + loss, prep, at, θ, rng, logp, sample_per_iter |
| 163 | + ) |
| 164 | + |
| 165 | + @test value !== nothing |
| 166 | + @test all(grad .!= nothing) |
| 167 | + end |
| 168 | + end |
| 169 | +end |
0 commit comments