|
1 |
| -@testset "learining 2d Gaussian" begin |
| 1 | +@testset "testing mean-field Gaussian VI" begin |
2 | 2 | chunksize = 4
|
3 | 3 | @testset "$adtype" for adtype in [
|
4 | 4 | ADTypes.AutoZygote(),
|
|
48 | 48 | end
|
49 | 49 | end
|
50 | 50 | end
|
| 51 | + |
| 52 | +# function create_planar_flow(n_layers::Int, q₀, T) |
| 53 | +# d = length(q₀) |
| 54 | +# if T == Float32 |
| 55 | +# Ls = reduce(∘, [f32(PlanarLayer(d)) for _ in 1:n_layers]) |
| 56 | +# else |
| 57 | +# Ls = reduce(∘, [PlanarLayer(d) for _ in 1:n_layers]) |
| 58 | +# end |
| 59 | +# return Bijectors.transformed(q₀, Ls) |
| 60 | +# end |
| 61 | + |
| 62 | +# @testset "testing planar flow" begin |
| 63 | +# chunksize = 4 |
| 64 | +# @testset "$adtype" for adtype in [ |
| 65 | +# ADTypes.AutoZygote(), |
| 66 | +# ADTypes.AutoForwardDiff(; chunksize=chunksize), |
| 67 | +# ADTypes.AutoForwardDiff(), |
| 68 | +# ADTypes.AutoReverseDiff(), |
| 69 | +# ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), |
| 70 | +# # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 |
| 71 | +# ] |
| 72 | +# @testset "$T" for T in [Float32, Float64] |
| 73 | +# μ = 10 * ones(T, 2) |
| 74 | +# Σ = Diagonal(4 * ones(T, 2)) |
| 75 | + |
| 76 | +# target = MvNormal(μ, Σ) |
| 77 | +# logp(z) = logpdf(target, z) |
| 78 | + |
| 79 | +# @leaf MvNormal |
| 80 | +# q₀ = MvNormal(zeros(T, 2), ones(T, 2)) |
| 81 | +# nlayers = 10 |
| 82 | +# flow = create_planar_flow(nlayers, q₀, T) |
| 83 | + |
| 84 | +# sample_per_iter = 10 |
| 85 | +# cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) |
| 86 | +# checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 |
| 87 | +# flow_trained, stats, _, _ = train_flow( |
| 88 | +# elbo, |
| 89 | +# flow, |
| 90 | +# logp, |
| 91 | +# sample_per_iter; |
| 92 | +# max_iters=10_000, |
| 93 | +# optimiser=Optimisers.Adam(one(T)/100), |
| 94 | +# ADbackend=adtype, |
| 95 | +# show_progress=false, |
| 96 | +# callback=cb, |
| 97 | +# hasconverged=checkconv, |
| 98 | +# ) |
| 99 | +# θ, re = Optimisers.destructure(flow_trained) |
| 100 | + |
| 101 | +# el_untrained = elbo(flow, logp, 1000) |
| 102 | +# el_trained = elbo(flow_trained, logp, 1000) |
| 103 | + |
| 104 | +# @test el_trained > el_untrained |
| 105 | +# @test el_trained > -1 |
| 106 | +# end |
| 107 | +# end |
| 108 | +# end |
0 commit comments