|
73 | 73 | end |
74 | 74 | end |
75 | 75 | end |
| 76 | + |
| 77 | + |
@testset "AD for ELBO on realnvp" begin
    # AD backends exercised by this test.
    backends = [
        ADTypes.AutoZygote(),
        ADTypes.AutoForwardDiff(),
        ADTypes.AutoReverseDiff(; compile=false),
        ADTypes.AutoEnzyme(;
            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
            function_annotation=Enzyme.Const,
        ),
        ADTypes.AutoMooncake(; config=Mooncake.Config()),
    ]
    @testset "$at" for at in backends
        @testset "$T" for T in [Float32, Float64]
            # Diagonal Gaussian target; logp is its log-density.
            target_mean = 10 * ones(T, 2)
            target_cov = Diagonal(4 * ones(T, 2))
            target = MvNormal(target_mean, target_cov)
            logp(z) = logpdf(target, z)

            # necessary for Zygote/mooncake to differentiate through the flow
            # prevent updating params of q0
            @leaf MvNormal
            reference = MvNormal(zeros(T, 2), ones(T, 2))
            flow = realnvp(reference, [8, 8], 3; paramtype=T)

            # Flatten the flow's parameters; reconstruct rebuilds the flow
            # from a parameter vector.
            params, reconstruct = Optimisers.destructure(flow)

            # Training loss: negative batched ELBO of the reconstructed flow.
            function loss(θ, rng, logp, sample_per_iter)
                return -NormalizingFlows.elbo_batch(
                    rng, reconstruct(θ), logp, sample_per_iter
                )
            end

            rng = Random.default_rng()
            sample_per_iter = 10

            # Prepare and evaluate the gradient through the chosen backend.
            prep = NormalizingFlows._prepare_gradient(
                loss, at, params, rng, logp, sample_per_iter
            )
            value, grad = NormalizingFlows._value_and_gradient(
                loss, prep, at, params, rng, logp, sample_per_iter
            )

            # The backend must produce a finite value and a fully-populated
            # gradient vector.
            @test value !== nothing
            @test all(grad .!= nothing)
        end
    end
end
0 commit comments