Skip to content

Commit 96a79f3

Browse files
authored
Fix memoization issue (#1414)
1 parent e6430f1 commit 96a79f3

File tree

4 files changed

+36
-21
lines changed

4 files changed

+36
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.14.3"
3+
version = "0.14.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/core/compat/reversediff.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function gradient_logp(
2323
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
2424
)
2525
T = typeof(getlogp(vi))
26-
26+
2727
# Specify objective function.
2828
function f(θ)
2929
new_vi = VarInfo(vi, sampler, θ)
@@ -46,12 +46,10 @@ end
4646
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
4747
setrdcache(::Val{true}) = RDCache[] = true
4848
function emptyrdcache()
49-
for k in keys(Memoization.caches)
50-
if k[1] === typeof(memoized_taperesult)
51-
pop!(Memoization.caches, k)
52-
end
53-
end
49+
Memoization.empty_cache!(memoized_taperesult)
50+
return
5451
end
52+
5553
function gradient_logp(
5654
backend::ReverseDiffAD{true},
5755
θ::AbstractVector{<:Real},
@@ -61,7 +59,7 @@ end
6159
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
6260
)
6361
T = typeof(getlogp(vi))
64-
62+
6563
# Specify objective function.
6664
function f(θ)
6765
new_vi = VarInfo(vi, sampler, θ)
@@ -81,15 +79,13 @@ end
8179
f::F
8280
x::Tx
8381
end
84-
function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
82+
function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any})
8583
key = keys[1][1]
86-
return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
84+
return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid()))
8785
end
8886
memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
89-
Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
87+
Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
9088
return compiledtape(k.f, k.x), GradientResult(k.x)
9189
end
92-
memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
93-
Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
9490
compiledtape(f, x) = compile(GradientTape(f, x))
9591
end

test/core/ad.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,13 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
276276
sample(dir(), HMC(0.01, 1), 1000);
277277
Turing.setrdcache(true)
278278
sample(dir(), HMC(0.01, 1), 1000);
279-
@test length(Memoization.caches) == 1
279+
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
280+
@test length(caches) == 1
281+
@test !isempty(first(values(caches)))
280282
Turing.emptyrdcache()
281-
@test length(Memoization.caches) == 0
283+
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
284+
@test length(caches) == 1
285+
@test isempty(first(values(caches)))
282286
end
283287
# FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff
284288
@testset "PDMatDistribution AD" begin
@@ -340,4 +344,24 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
340344
@test H_f == [1.0 0.0; 0.0 1.0]
341345
@test H_f == H_r
342346
end
347+
348+
@testset "memoization: issue #1393" begin
349+
Turing.setadbackend(:reversediff)
350+
Turing.setrdcache(true)
351+
352+
@model function demo(data)
353+
sigma ~ Uniform(0.0, 20.0)
354+
data ~ Normal(0, sigma)
355+
end
356+
357+
N = 1000
358+
for i in 1:5
359+
d = Normal(0.0, i)
360+
data = rand(d, N)
361+
chn = sample(demo(data), NUTS(0.65), 1000)
362+
@test mean(Array(chn[:sigma])) std(data) atol=0.5
363+
end
364+
365+
Turing.emptyrdcache()
366+
end
343367
end

test/runtests.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,8 @@ include("test_utils/AllUtils.jl")
1515
include("core/container.jl")
1616
end
1717

18-
test_adbackends = if VERSION >= v"1.2"
19-
[:forwarddiff, :tracker, :reversediff]
20-
else
21-
[:forwarddiff, :tracker]
22-
end
2318
Turing.setrdcache(false)
24-
for adbackend in test_adbackends
19+
for adbackend in (:forwarddiff, :tracker, :reversediff)
2520
Turing.setadbackend(adbackend)
2621
@testset "inference: $adbackend" begin
2722
@testset "samplers" begin

0 commit comments

Comments
 (0)