
Commit f74e039

Merge pull request #114 from TuringLang/mt/update_turing_src
Update Turing src in tests
2 parents e2584c7 + fc1270d commit f74e039

18 files changed (+858, -392 lines)


test/Turing/Turing.jl

Lines changed: 6 additions & 11 deletions
@@ -11,12 +11,11 @@ module Turing
 using Requires, Reexport, ForwardDiff
 using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
 using Statistics, LinearAlgebra
-using Markdown, Libtask, MacroTools
-@reexport using Distributions, MCMCChains, Libtask
+using Libtask
+@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC
 using Tracker: Tracker

-import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
-import DynamicPPL: getspace
+import DynamicPPL: getspace, NoDist, NamedDist

 const PROGRESS = Ref(true)
 function turnprogress(switch::Bool)

@@ -68,6 +67,8 @@ export @model, # modelling
        @varname,
        DynamicPPL,

+       Prior, # Sampling from the prior
+
        MH, # classic sampling
        RWMH,
        ESS,

@@ -90,7 +91,6 @@ export @model, # modelling
        ADVI,

        sample, # inference
-       psample,
        setchunksize,
        resume,
        @logprob_str,

@@ -105,15 +105,10 @@ export @model, # modelling
        Flat,
        FlatPos,
        BinomialLogit,
-       VecBinomialLogit,
+       BernoulliLogit,
        OrderedLogistic,
        LogPoisson,
        NamedDist,
        filldist,
        arraydist
-
-# Reexports
-using AbstractMCMC: sample, psample
-export sample, psample
-
 end
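
For users, the main change here is that `psample` disappears and parallel sampling goes through `sample` with an ensemble argument (see the `AbstractMCMC.AbstractMCMCParallel` method added to dynamichmc.jl below). A minimal sketch of the call-site migration, assuming the `MCMCThreads` ensemble type from AbstractMCMC:

using Turing

@model gauss(x) = begin
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

# Before this change: chains = psample(gauss(1.5), MH(), 1000, 4)
# After: the ensemble argument selects the parallelism strategy.
chains = sample(gauss(1.5), MH(), MCMCThreads(), 1000, 4)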

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 18 additions & 10 deletions
@@ -46,7 +46,7 @@ mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState
     draws::Vector{D}
 end

-getspace(::DynamicNUTS{<:Any, space}) where {space} = space
+DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

 function AbstractMCMC.sample_init!(
     rng::AbstractRNG,

@@ -60,16 +60,23 @@ function AbstractMCMC.sample_init!(
         gradient_logp(x, spl.state.vi, model, spl)
     end

+    # Set the parameters to a starting value.
+    initialize_parameters!(spl; kwargs...)
+
     model(spl.state.vi, SampleFromUniform())
+    link!(spl.state.vi, spl)
+    l, dl = _lp(spl.state.vi[spl])
+    while !isfinite(l) || !isfinite(dl)
+        model(spl.state.vi, SampleFromUniform())
+        link!(spl.state.vi, spl)
+        l, dl = _lp(spl.state.vi[spl])
+    end

-    if spl.selector.tag == :default
+    if spl.selector.tag == :default && !islinked(spl.state.vi, spl)
         link!(spl.state.vi, spl)
         model(spl.state.vi, spl)
     end

-    # Set the parameters to a starting value.
-    initialize_parameters!(spl; kwargs...)
-
     results = mcmc_with_warmup(
         rng,
         FunctionLogDensity(

@@ -114,7 +121,7 @@ end
     model::AbstractModel,
     alg::DynamicNUTS,
     N::Integer;
-    chain_type=Chains,
+    chain_type=MCMCChains.Chains,
     resume_from=nothing,
     progress=PROGRESS[],
     kwargs...

@@ -130,19 +137,20 @@ end
     end
 end

-function AbstractMCMC.psample(
+function AbstractMCMC.sample(
     rng::AbstractRNG,
     model::AbstractModel,
     alg::DynamicNUTS,
+    parallel::AbstractMCMC.AbstractMCMCParallel,
     N::Integer,
     n_chains::Integer;
-    chain_type=Chains,
+    chain_type=MCMCChains.Chains,
     progress=PROGRESS[],
     kwargs...
 )
     if progress
         @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
-    return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
-        chain_type=chain_type, progress=false, kwargs...)
+    return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
+        chain_type=chain_type, progress=false, kwargs...)
 end
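
The new initialization block re-draws from the prior until both the log density and its gradient are finite, so DynamicHMC's warmup never starts from an invalid point. The same retry pattern in isolation, with hypothetical helpers `draw_initial` and `logdensity_and_gradient` standing in for the sampler state machinery:

# Sketch of the finite-initialization retry loop.
# `draw_initial` and `logdensity_and_gradient` are hypothetical stand-ins.
function finite_start(draw_initial, logdensity_and_gradient; maxtries = 1000)
    for _ in 1:maxtries
        x = draw_initial()
        l, g = logdensity_and_gradient(x)
        # Accept only points with finite log density and finite gradient.
        if isfinite(l) && all(isfinite, g)
            return x
        end
    end
    error("no finite starting point found after $maxtries tries")
end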

test/Turing/contrib/inference/sghmc.jl

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ function step(
     spl.selector.tag != :default && link!(vi, spl)

     mssa = AHMC.Adaptation.ManualSSAdaptor(AHMC.Adaptation.MSSState(spl.alg.ϵ))
-    spl.info[:adaptor] = AHMC.NaiveHMCAdaptor(AHMC.UnitPreconditioner(), mssa)
+    spl.info[:adaptor] = AHMC.NaiveHMCAdaptor(AHMC.UnitMassMatrix(), mssa)

     spl.selector.tag != :default && invlink!(vi, spl)
     return vi, true
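
This one-line change appears to track AdvancedHMC's rename of `UnitPreconditioner` to `UnitMassMatrix`; the adaptor still pairs an identity mass matrix with the manual step-size state, so sampling behavior should be unchanged.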

test/Turing/core/Core.jl

Lines changed: 13 additions & 9 deletions
@@ -1,13 +1,12 @@
 module Core

 using DistributionsAD, Bijectors
-using MacroTools, Libtask, ForwardDiff, Random
+using Libtask, ForwardDiff, Random
 using Distributions, LinearAlgebra
 using ..Utilities, Reexport
 using Tracker: Tracker
 using ..Turing: Turing
-using DynamicPPL: Model,
-    AbstractSampler, Sampler, SampleFromPrior
+using DynamicPPL: Model, AbstractSampler, Sampler, SampleFromPrior
 using LinearAlgebra: copytri!
 using Bijectors: PDMatDistribution
 import Bijectors: link, invlink

@@ -17,9 +16,15 @@ using Requires

 include("container.jl")
 include("ad.jl")
-@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-    include("compat/zygote.jl")
-    export ZygoteAD
+function __init__()
+    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+        include("compat/zygote.jl")
+        export ZygoteAD
+    end
+    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
+        include("compat/reversediff.jl")
+        export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache
+    end
 end

 export @model,

@@ -36,10 +41,9 @@ export @model,
        forkr,
        current_trace,
        getweights,
+       getweight,
        effectiveSampleSize,
-       increase_logweight,
-       inrease_logevidence,
-       resample!,
+       sweep!,
        ResampleWithESSThreshold,
        ADBackend,
        setadbackend,
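
Moving the conditional includes from `@init @require` into an explicit `__init__()` is the standard Requires.jl idiom: `__init__` runs when the module is loaded, and each `@require` body executes only once the optional package is imported. A minimal sketch of the pattern, reusing the Zygote UUID from the diff above (module and file names here are placeholders):

module MyCore

using Requires

function __init__()
    # Registered at module load; the body runs only
    # after `using Zygote` in the same session.
    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
        include("compat/zygote.jl")
    end
end

end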

test/Turing/core/ad.jl

Lines changed: 14 additions & 5 deletions
@@ -1,14 +1,23 @@
 ##############################
 # Global variables/constants #
 ##############################
-const ADBACKEND = Ref(:forward_diff)
+const ADBACKEND = Ref(:forwarddiff)
 setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
 function setadbackend(::Val{:forward_diff})
+    Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
+    setadbackend(Val(:forwarddiff))
+end
+function setadbackend(::Val{:forwarddiff})
     CHUNKSIZE[] == 0 && setchunksize(40)
-    ADBACKEND[] = :forward_diff
+    ADBACKEND[] = :forwarddiff
 end
+
 function setadbackend(::Val{:reverse_diff})
-    ADBACKEND[] = :reverse_diff
+    Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend)
+    setadbackend(Val(:tracker))
+end
+function setadbackend(::Val{:tracker})
+    ADBACKEND[] = :tracker
 end

 const ADSAFE = Ref(false)

@@ -37,8 +46,8 @@ struct TrackerAD <: ADBackend end
 ADBackend() = ADBackend(ADBACKEND[])
 ADBackend(T::Symbol) = ADBackend(Val(T))

-ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]}
-ADBackend(::Val{:reverse_diff}) = TrackerAD
+ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
+ADBackend(::Val{:tracker}) = TrackerAD
 ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")

 """
test/Turing/core/compat/reversediff.jl

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+using .ReverseDiff: compile, GradientTape
+using .ReverseDiff.DiffResults: GradientResult
+
+struct ReverseDiffAD{cache} <: ADBackend end
+const RDCache = Ref(false)
+setrdcache(b::Bool) = setrdcache(Val(b))
+setrdcache(::Val{false}) = RDCache[] = false
+setrdcache(::Val) = throw("Memoization.jl is not loaded. Please load it before setting the cache to true.")
+function emptyrdcache end
+
+getrdcache() = RDCache[]
+ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
+function setadbackend(::Val{:reversediff})
+    ADBACKEND[] = :reversediff
+end
+
+function gradient_logp(
+    backend::ReverseDiffAD{false},
+    θ::AbstractVector{<:Real},
+    vi::VarInfo,
+    model::Model,
+    sampler::AbstractSampler = SampleFromPrior(),
+)
+    T = typeof(getlogp(vi))
+
+    # Specify objective function.
+    function f(θ)
+        new_vi = VarInfo(vi, sampler, θ)
+        model(new_vi, sampler)
+        return getlogp(new_vi)
+    end
+    tp, result = taperesult(f, θ)
+    ReverseDiff.gradient!(result, tp, θ)
+    l = DiffResults.value(result)
+    ∂l∂θ::typeof(θ) = DiffResults.gradient(result)
+
+    return l, ∂l∂θ
+end
+
+tape(f, x) = GradientTape(f, x)
+function taperesult(f, x)
+    return tape(f, x), GradientResult(x)
+end
+
+@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
+    setrdcache(::Val{true}) = RDCache[] = true
+    function emptyrdcache()
+        for k in keys(Memoization.caches)
+            if k[1] === typeof(memoized_taperesult)
+                pop!(Memoization.caches, k)
+            end
+        end
+    end
+    function gradient_logp(
+        backend::ReverseDiffAD{true},
+        θ::AbstractVector{<:Real},
+        vi::VarInfo,
+        model::Model,
+        sampler::AbstractSampler = SampleFromPrior(),
+    )
+        T = typeof(getlogp(vi))
+
+        # Specify objective function.
+        function f(θ)
+            new_vi = VarInfo(vi, sampler, θ)
+            model(new_vi, sampler)
+            return getlogp(new_vi)
+        end
+        ctp, result = memoized_taperesult(f, θ)
+        ReverseDiff.gradient!(result, ctp, θ)
+        l = DiffResults.value(result)
+        ∂l∂θ = DiffResults.gradient(result)
+
+        return l, ∂l∂θ
+    end
+
+    # This makes sure we generate a single tape per Turing model and sampler
+    struct RDTapeKey{F, Tx}
+        f::F
+        x::Tx
+    end
+    function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
+        key = keys[1][1]
+        return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
+    end
+    memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
+    Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
+        return compiledtape(k.f, k.x), GradientResult(k.x)
+    end
+    memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
+    Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
+    compiledtape(f, x) = compile(GradientTape(f, x))
+end
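
End-to-end, enabling the new backend looks roughly like the following; `ReverseDiff` must be loaded so the `@require` block in Core.jl fires, and `Memoization` before turning the compiled-tape cache on (otherwise `setrdcache(true)` hits the `throw` fallback above):

using Turing
using ReverseDiff     # triggers the @require block that defines ReverseDiffAD
using Memoization     # needed before enabling the compiled-tape cache

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)   # memoize one compiled tape per model/sampler pair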
