Skip to content

Commit 7b4e3cd

Browse files
authored
Merge pull request #41 from TuringLang/update
Add AbstractMCMC 0.5 and update Turing
2 parents add2f5e + 62351c1 commit 7b4e3cd

21 files changed

+794
-584
lines changed

Project.toml

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
33
authors = ["mohamed82008 <[email protected]>"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -10,35 +10,36 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111

1212
[compat]
13-
AbstractMCMC = "0.4"
13+
AbstractMCMC = "0.4, 0.5"
1414
Bijectors = "0.5.2"
1515
Distributions = "0.22"
1616
MacroTools = "0.5.1"
1717
julia = "1"
1818

1919
[extras]
20-
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
2120
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
22-
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
23-
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
21+
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
22+
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
2423
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
2524
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2625
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2726
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2827
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
28+
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2929
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
30-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
3130
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
3231
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
33-
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
32+
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
3433
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3534
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3635
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
3736
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3837
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
38+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3939
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
4040
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4141
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
42+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4243

4344
[targets]
44-
test = ["AbstractMCMC", "AdvancedHMC", "Bijectors", "Distributions", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "MCMCChains", "MacroTools", "Markdown", "PDMats", "ProgressMeter", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsFuns", "Test", "Tracker"]
45+
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]

test/Turing/Turing.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ module Turing
1010

1111
using Requires, Reexport, ForwardDiff
1212
using Bijectors, StatsFuns, SpecialFunctions
13-
using Statistics, LinearAlgebra, ProgressMeter
13+
using Statistics, LinearAlgebra
1414
using Markdown, Libtask, MacroTools
15-
using AbstractMCMC: sample, psample
1615
@reexport using Distributions, MCMCChains, Libtask
1716
using Tracker: Tracker
1817

@@ -21,7 +20,7 @@ import DynamicPPL: getspace, runmodel!
2120

2221
const PROGRESS = Ref(true)
2322
function turnprogress(switch::Bool)
24-
@info("[Turing]: global PROGRESS is set as $switch")
23+
@info "[Turing]: progress logging is $(switch ? "enabled" : "disabled") globally"
2524
PROGRESS[] = switch
2625
end
2726

@@ -50,10 +49,14 @@ using .Variational
5049
# end
5150

5251
@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin
53-
using Pkg;
54-
Pkg.installed()["DynamicHMC"] < v"2.0" && error("Please upgdate your DynamicHMC, v1.x is no longer supported")
55-
using ..Turing.DynamicHMC: DynamicHMC, mcmc_with_warmup
56-
include("contrib/inference/dynamichmc.jl")
52+
import ..DynamicHMC
53+
54+
if isdefined(DynamicHMC, :mcmc_with_warmup)
55+
using ..DynamicHMC: mcmc_with_warmup
56+
include("contrib/inference/dynamichmc.jl")
57+
else
58+
error("Please update DynamicHMC, v1.x is no longer supported")
59+
end
5760
end
5861

5962
###########
@@ -69,6 +72,7 @@ export @model, # modelling
6972
DynamicPPL,
7073

7174
MH, # classic sampling
75+
RWMH,
7276
ESS,
7377
Gibbs,
7478

@@ -109,4 +113,8 @@ export @model, # modelling
109113
LogPoisson,
110114
NamedDist
111115

116+
# Reexports
117+
using AbstractMCMC: sample, psample
118+
export sample, psample
119+
112120
end

test/Turing/contrib/inference/AdvancedSMCExtensions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first:
255255

256256
# Run SMC & CSMC nodes
257257
for j in 1:spl.alg.n_nodes
258-
reset_num_produce!(VarInfos[j])
258+
VarInfos[j].num_produce = 0
259259
VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j])[1]
260260
log_zs[j] = spl.info[:samplers][j].info[:logevidence][end]
261261
end

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using AbstractMCMC: NoCallback
2-
31
###
42
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
53
###
@@ -110,13 +108,41 @@ function Sampler(
110108
return Sampler(alg, Dict{Symbol,Any}(), s, state)
111109
end
112110

113-
# Disable the callback for DynamicHMC, since it has it's own progress meter.
114-
function AbstractMCMC.init_callback(
111+
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
112+
function AbstractMCMC.sample(
115113
rng::AbstractRNG,
116-
model::Model,
117-
s::Sampler{<:DynamicNUTS},
114+
model::AbstractModel,
115+
alg::DynamicNUTS,
118116
N::Integer;
117+
chain_type=Chains,
118+
resume_from=nothing,
119+
progress=PROGRESS[],
119120
kwargs...
120121
)
121-
return NoCallback()
122+
if progress
123+
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
124+
end
125+
if resume_from === nothing
126+
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
127+
chain_type=chain_type, progress=false, kwargs...)
128+
else
129+
return resume(resume_from, N; chain_type=chain_type, progress=false, kwargs...)
130+
end
122131
end
132+
133+
function AbstractMCMC.psample(
134+
rng::AbstractRNG,
135+
model::AbstractModel,
136+
alg::DynamicNUTS,
137+
N::Integer,
138+
n_chains::Integer;
139+
chain_type=Chains,
140+
progress=PROGRESS[],
141+
kwargs...
142+
)
143+
if progress
144+
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
145+
end
146+
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
147+
chain_type=chain_type, progress=false, kwargs...)
148+
end

test/Turing/core/Core.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module Core
22

3-
using Bijectors
43
using MacroTools, Libtask, ForwardDiff, Random
54
using Distributions, LinearAlgebra
65
using ..Utilities, Reexport
@@ -14,14 +13,9 @@ import Bijectors: link, invlink
1413
using DistributionsAD
1514
using StatsFuns: logsumexp, softmax
1615
@reexport using DynamicPPL
17-
using Requires
1816

1917
include("container.jl")
2018
include("ad.jl")
21-
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
22-
include("compat/zygote.jl")
23-
export ZygoteAD
24-
end
2519

2620
export @model,
2721
@varname,

test/Turing/core/ad.jl

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
##############################
22
# Global variables/constants #
33
##############################
4+
using Bijectors
45

56
const ADBACKEND = Ref(:forward_diff)
6-
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
7-
function setadbackend(::Val{:forward_diff})
8-
CHUNKSIZE[] == 0 && setchunksize(40)
9-
ADBACKEND[] = :forward_diff
10-
end
11-
function setadbackend(::Val{:reverse_diff})
12-
ADBACKEND[] = :reverse_diff
7+
function setadbackend(backend_sym)
8+
@assert backend_sym == :forward_diff || backend_sym == :reverse_diff
9+
backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40)
10+
ADBACKEND[] = backend_sym
11+
12+
Bijectors.setadbackend(backend_sym)
1313
end
1414

1515
const ADSAFE = Ref(false)
@@ -39,8 +39,7 @@ ADBackend() = ADBackend(ADBACKEND[])
3939
ADBackend(T::Symbol) = ADBackend(Val(T))
4040

4141
ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]}
42-
ADBackend(::Val{:reverse_diff}) = TrackerAD
43-
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
42+
ADBackend(::Val) = TrackerAD
4443

4544
"""
4645
getADtype(alg)
@@ -70,8 +69,8 @@ function gradient_logp(
7069
ad_type = getADtype(sampler)
7170
if ad_type <: ForwardDiffAD
7271
return gradient_logp_forward(θ, vi, model, sampler)
73-
else
74-
return gradient_logp_reverse(ad_type(), θ, vi, model, sampler)
72+
else ad_type <: TrackerAD
73+
return gradient_logp_reverse(θ, vi, model, sampler)
7574
end
7675
end
7776

@@ -114,22 +113,20 @@ end
114113

115114
"""
116115
gradient_logp_reverse(
117-
backend::ADBackend,
118116
θ::AbstractVector{<:Real},
119117
vi::VarInfo,
120118
model::Model,
121-
sampler::AbstractSampler = SampleFromPrior(),
119+
sampler::AbstractSampler=SampleFromPrior(),
122120
)
123121
124122
Computes the value of the log joint of `θ` and its gradient for the model
125-
specified by `(vi, sampler, model)` using reverse-mode AD from the specified `backend`, e.g. `TrackerAD()` which uses `Tracker.jl` or `ZygoteAD()` which uses `Zygote.jl`.
123+
specified by `(vi, sampler, model)` using reverse-mode AD from Tracker.jl.
126124
"""
127125
function gradient_logp_reverse(
128-
backend::TrackerAD,
129126
θ::AbstractVector{<:Real},
130127
vi::VarInfo,
131128
model::Model,
132-
sampler::AbstractSampler = SampleFromPrior(),
129+
sampler::AbstractSampler=SampleFromPrior(),
133130
)
134131
T = typeof(getlogp(vi))
135132

@@ -141,19 +138,10 @@ function gradient_logp_reverse(
141138

142139
# Compute forward and reverse passes.
143140
l_tracked, ȳ = Tracker.forward(f, θ)
144-
# Remove tracking info from variables in model (because mutable state).
145141
l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data((1)[1])
146-
142+
# Remove tracking info from variables in model (because mutable state).
147143
return l, ∂l∂θ
148144
end
149-
function gradient_logp_reverse(
150-
θ::AbstractVector{<:Real},
151-
vi::VarInfo,
152-
model::Model,
153-
sampler::AbstractSampler = SampleFromPrior(),
154-
)
155-
return gradient_logp_reverse(TrackerAD(), θ, vi, model, sampler)
156-
end
157145

158146
function verifygrad(grad::AbstractVector{<:Real})
159147
if any(isnan, grad) || any(isinf, grad)
@@ -165,7 +153,6 @@ function verifygrad(grad::AbstractVector{<:Real})
165153
end
166154
end
167155

168-
# Replace the adjoints below with Zygote ones
169156
for F in (:link, :invlink)
170157
@eval begin
171158
function $F(

test/Turing/core/compat/zygote.jl

Lines changed: 0 additions & 26 deletions
This file was deleted.

test/Turing/inference/AdvancedSMC.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function AbstractMCMC.step!(
109109
spl::Sampler{<:SMC},
110110
::Integer,
111111
transition;
112-
iteration = -1,
112+
iteration=-1,
113113
kwargs...
114114
)
115115
# check that we received a real iteration number
@@ -238,23 +238,22 @@ function AbstractMCMC.sample_end!(
238238
spl::Sampler{<:ParticleInference},
239239
N::Integer,
240240
ts::Vector{<:ParticleTransition};
241+
resume_from = nothing,
241242
kwargs...
242243
)
243-
# Set the default for resuming the sampler.
244-
resume_from = get(kwargs, :resume_from, nothing)
245-
246244
# Exponentiate the average log evidence.
247245
# loge = exp(mean([t.le for t in ts]))
248246
loge = mean(t.le for t in ts)
249247

250248
# If we already had a chain, grab the logevidence.
251-
if resume_from !== nothing # concat samples
252-
@assert resume_from isa Chains "resume_from needs to be a Chains object."
249+
if resume_from isa Chains
253250
# pushfirst!(samples, resume_from.info[:samples]...)
254251
pre_loge = resume_from.logevidence
255252
# Calculate new log-evidence
256253
pre_n = length(resume_from)
257254
loge = (pre_loge * pre_n + loge * N) / (pre_n + N)
255+
elseif resume_from !== nothing
256+
error("keyword argument `resume_from` has to be `nothing` or a `Chains` object")
258257
end
259258

260259
# Store the logevidence.

0 commit comments

Comments
 (0)