Commit 1970b09

fixing test bug

1 parent: b7f9f08

File tree

2 files changed: +5 additions, -3 deletions

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ Requires = "1"
 ReverseDiff = "1.14"
 StatsBase = "0.33, 0.34"
 Zygote = "0.6, 0.7"
-julia = "1.6"
+julia = "1.10"

 [extras]
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

test/ad.jl

Lines changed: 4 additions & 2 deletions

@@ -14,7 +14,8 @@
 ADTypes.AutoReverseDiff(false),
 ADTypes.AutoMooncake(; config=Mooncake.Config()),
 ]
-value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z)
+prep = NormalizingFlows._prepare_gradient(f, at, x, y, z)
+value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z)
 @test DiffResults.value(out) ≈ f(x, y, z)
 @test DiffResults.gradient(out) ≈ 2 * (x .+ y .+ z)
 end
@@ -42,8 +43,9 @@ end

 # check grad computation for elbo
 loss(θ, args...) = -NormalizingFlows.elbo(re(θ), args...)
+prep = NormalizingFlows._prepare_gradient(loss, at, θ, logp, randn(T, 2, sample_per_iter))
 value, grad = NormalizingFlows._value_and_gradient(
-    loss, at, θ, logp, randn(T, 2, sample_per_iter)
+    loss, prep, at, θ, logp, randn(T, 2, sample_per_iter)
 )

 @test !isnothing(value)
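
For context, the updated tests split gradient computation into a one-time preparation step and a repeated evaluation step. The sketch below is a minimal, assumed illustration of how helpers with these signatures could be written on top of DifferentiationInterface.jl, treating the trailing arguments as constant contexts; it is not NormalizingFlows' actual implementation, and the helper bodies are an assumption for illustration only.

# Illustrative sketch only (assumed implementation, not the package's source).
import DifferentiationInterface as DI
using ADTypes, ReverseDiff

# Hypothetical helpers: prepare once per backend/argument shape, then reuse the prep object.
_prepare_gradient(f, at::ADTypes.AbstractADType, θ, args...) =
    DI.prepare_gradient(f, at, θ, map(DI.Constant, args)...)

_value_and_gradient(f, prep, at::ADTypes.AbstractADType, θ, args...) =
    DI.value_and_gradient(f, prep, at, θ, map(DI.Constant, args)...)

# Usage mirroring the updated test: gradient of f w.r.t. x, with y and z held constant.
f(x, y, z) = sum(abs2, x .+ y .+ z)
x, y, z = randn(3), randn(3), randn(3)
at = ADTypes.AutoReverseDiff()
prep = _prepare_gradient(f, at, x, y, z)
value, grad = _value_and_gradient(f, prep, at, x, y, z)   # grad ≈ 2 .* (x .+ y .+ z)

Reusing the same prep object across calls is the usual motivation for this split: backend-specific setup (tapes, configs) is built once instead of on every gradient evaluation.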
