|
14 | 14 | ADTypes.AutoReverseDiff(false),
|
15 | 15 | ADTypes.AutoMooncake(; config=Mooncake.Config()),
|
16 | 16 | ]
|
17 |
| - value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z) |
| 17 | + prep = NormalizingFlows._prepare_gradient(f, at, x, y, z) |
| 18 | + value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z) |
18 | 19 | @test DiffResults.value(out) ≈ f(x, y, z)
|
19 | 20 | @test DiffResults.gradient(out) ≈ 2 * (x .+ y .+ z)
|
20 | 21 | end
|
|
42 | 43 |
|
43 | 44 | # check grad computation for elbo
|
44 | 45 | loss(θ, args...) = -NormalizingFlows.elbo(re(θ), args...)
|
| 46 | + prep = NormalizingFlows._prepare_gradient(loss, at, θ, logp, randn(T, 2, sample_per_iter)) |
45 | 47 | value, grad = NormalizingFlows._value_and_gradient(
|
46 |
| - loss, at, θ, logp, randn(T, 2, sample_per_iter) |
| 48 | + loss, prep, at, θ, logp, randn(T, 2, sample_per_iter) |
47 | 49 | )
|
48 | 50 |
|
49 | 51 | @test !isnothing(value)
|
|
0 commit comments