Skip to content

Commit 1d87a1d

Browse files
authored
Fix MLE/MAP with Zygote and ReverseDiff (#1408)
* Fix AD for modes * Increment version number * Remove display calls * Move tests
1 parent ae24c28 commit 1d87a1d

File tree

5 files changed

+12
-9
lines changed

5 files changed

+12
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.14.2"
3+
version = "0.14.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/core/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ getADbackend(spl::SampleFromPrior) = ADBackend()()
6565
θ::AbstractVector{<:Real},
6666
vi::VarInfo,
6767
model::Model,
68-
sampler::AbstractSampler=SampleFromPrior(),
68+
sampler::AbstractSampler,
69+
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
6970
)
7071
7172
Computes the value of the log joint of `θ` and its gradient for the model
@@ -89,6 +90,7 @@ gradient_logp(
8990
vi::VarInfo,
9091
model::Model,
9192
sampler::AbstractSampler = SampleFromPrior(),
93+
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
9294
)
9395
9496
Compute the value of the log joint of `θ` and its gradient for the model
@@ -160,7 +162,7 @@ function gradient_logp(
160162
# Specify objective function.
161163
function f(θ)
162164
new_vi = VarInfo(vi, sampler, θ)
163-
model(new_vi, sampler)
165+
model(new_vi, sampler, context)
164166
return getlogp(new_vi)
165167
end
166168

src/core/compat/reversediff.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function gradient_logp(
2727
# Specify objective function.
2828
function f(θ)
2929
new_vi = VarInfo(vi, sampler, θ)
30-
model(new_vi, sampler)
30+
model(new_vi, sampler, context)
3131
return getlogp(new_vi)
3232
end
3333
tp, result = taperesult(f, θ)
@@ -65,7 +65,7 @@ end
6565
# Specify objective function.
6666
function f(θ)
6767
new_vi = VarInfo(vi, sampler, θ)
68-
model(new_vi, sampler)
68+
model(new_vi, sampler, context)
6969
return getlogp(new_vi)
7070
end
7171
ctp, result = memoized_taperesult(f, θ)

test/modes/ModeEstimation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using NamedArrays
66
using ReverseDiff
77
using Random
88
using LinearAlgebra
9+
using Zygote
910

1011
dir = splitdir(splitdir(pathof(Turing))[1])[1]
1112
include(dir*"/test/test_utils/AllUtils.jl")

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ include("test_utils/AllUtils.jl")
4040
@testset "variational algorithms : $adbackend" begin
4141
include("variational/advi.jl")
4242
end
43+
44+
@testset "modes" begin
45+
include("modes/ModeEstimation.jl")
46+
end
4347
end
4448
@testset "variational optimisers" begin
4549
include("variational/optimisers.jl")
@@ -55,8 +59,4 @@ include("test_utils/AllUtils.jl")
5559
# include("utilities/stan-interface.jl")
5660
include("inference/utilities.jl")
5761
end
58-
59-
@testset "modes" begin
60-
include("modes/ModeEstimation.jl")
61-
end
6262
end

0 commit comments

Comments
 (0)