|
1 | 1 | @testset "DI.AD with context wrapper" begin
|
2 | 2 | f(x, y, z) = sum(abs2, x .+ y .+ z)
|
| 3 | + T = Float32 |
3 | 4 |
|
4 | 5 | @testset "$T" for T in [Float32, Float64]
|
5 | 6 | x = randn(T, 10)
|
|
11 | 12 | ADTypes.AutoZygote(),
|
12 | 13 | ADTypes.AutoForwardDiff(; chunksize=chunksize),
|
13 | 14 | ADTypes.AutoForwardDiff(),
|
14 |
| - ADTypes.AutoReverseDiff(false), |
| 15 | + ADTypes.AutoReverseDiff(; false), |
15 | 16 | ADTypes.AutoMooncake(; config=Mooncake.Config()),
|
16 | 17 | ]
|
| 18 | + at = ADTypes.AutoMooncake(; config=Mooncake.Config()) |
17 | 19 | prep = NormalizingFlows._prepare_gradient(f, at, x, y, z)
|
18 | 20 | value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z)
|
19 | 21 | @test DiffResults.value(out) ≈ f(x, y, z)
|
|
26 | 28 | @testset "$at" for at in [
|
27 | 29 | ADTypes.AutoZygote(),
|
28 | 30 | ADTypes.AutoForwardDiff(),
|
29 |
| - ADTypes.AutoReverseDiff(false), |
| 31 | + ADTypes.AutoReverseDiff(; false), |
30 | 32 | ADTypes.AutoMooncake(; config=Mooncake.Config()),
|
31 | 33 | ]
|
32 | 34 | @testset "$T" for T in [Float32, Float64]
|
33 | 35 | μ = 10 * ones(T, 2)
|
34 | 36 | Σ = Diagonal(4 * ones(T, 2))
|
35 | 37 | target = MvNormal(μ, Σ)
|
36 | 38 | logp(z) = logpdf(target, z)
|
37 |
| - |
| 39 | + |
| 40 | + # necessary for Zygote/mooncake to differentiate through the flow |
| 41 | + # prevent opt q0 |
| 42 | + @leaf MvNormal |
38 | 43 | q₀ = MvNormal(zeros(T, 2), ones(T, 2))
|
39 | 44 | flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ)))
|
40 | 45 |
|
|
0 commit comments