1 change: 1 addition & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.12.0"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
+BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
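
Note on the new dependency: BangBang.jl provides "mutate-or-widen" counterparts of Julia's mutating functions (`push!!`, `setindex!!`, ...) that mutate a container in place when its type allows and otherwise return a suitably widened copy. A minimal sketch of that pattern (generic BangBang usage, not code from this PR):

    using BangBang

    v = push!!([1, 2], 3)    # element type fits: mutates the Vector{Int} in place
    w = push!!([1, 2], 0.5)  # element type too narrow: returns a new Vector{Float64}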
11 changes: 6 additions & 5 deletions src/contrib/inference/AdvancedSMCExtensions.jl
@@ -39,7 +39,7 @@ end
 
 PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), ())
 
-function Sampler(alg::PMMH, model::Model, s::Selector)
+function Sampler(alg::PMMH, model::Model, s::Selector; specialize_after=1)
     info = Dict{Symbol, Any}()
     spl = Sampler(alg, info, s)
 
@@ -118,7 +118,8 @@ function sample( model::Model,
                  alg::PMMH;
                  save_state=false,          # flag for state saving
                  resume_from=nothing,       # chain to continue
-                 reuse_spl_n=0              # flag for spl re-using
+                 reuse_spl_n=0,             # flag for spl re-using
+                 specialize_after=1
                )
 
     spl = Sampler(alg, model)
@@ -140,7 +141,7 @@
 
     # Init parameters
     vi = if resume_from === nothing
-        vi_ = VarInfo(model)
+        vi_ = VarInfo(model, specialize_after)
     else
         resume_from.info[:vi]
     end
@@ -279,7 +280,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first:
     VarInfos[nodes_permutation]
 end
 
-function sample(model::Model, alg::IPMCMC)
+function sample(model::Model, alg::IPMCMC; specialize_after=1)
 
     spl = Sampler(alg)
 
@@ -295,7 +296,7 @@ end
     end
 
     # Init parameters
-    vi = empty!(VarInfo(model))
+    vi = empty!(VarInfo(model, specialize_after))
     VarInfos = Array{VarInfo}(undef, spl.alg.n_nodes)
     for j in 1:spl.alg.n_nodes
         VarInfos[j] = deepcopy(vi)
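
The pattern in this file threads a new `specialize_after` keyword from the user-facing `sample`/`Sampler` methods down to `VarInfo(model, specialize_after)`, which (per this PR) controls how many evaluations run on the untyped `VarInfo` before it is specialized to a typed one. A hypothetical user-facing call, assuming a model `gdemo` is defined (the IPMCMC constructor arguments below are illustrative, not taken from this diff):

    using Turing

    @model gdemo(x) = begin
        s ~ InverseGamma(2, 3)
        m ~ Normal(0, sqrt(s))
        x ~ Normal(m, sqrt(s))
    end

    # Run two untyped evaluations before specializing the VarInfo.
    chain = sample(gdemo(1.5), IPMCMC(100, 100, 4, 2); specialize_after=2)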
34 changes: 11 additions & 23 deletions src/contrib/inference/dynamichmc.jl
@@ -41,7 +41,7 @@ function DynamicNUTS{AD}(space::Symbol...) where AD
     DynamicNUTS{AD, space}()
 end
 
-mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState
+mutable struct DynamicNUTSState{V<:AbstractVarInfo, D} <: AbstractSamplerState
     vi::V
     draws::Vector{D}
 end
@@ -53,29 +53,16 @@ function AbstractMCMC.sample_init!(
     model::Model,
     spl::Sampler{<:DynamicNUTS},
     N::Integer;
-    kwargs...
+    kwargs...,
 )
     # Set up lp function.
     function _lp(x)
-        gradient_logp(x, spl.state.vi, model, spl)
+        gradient_logp(x, link(spl.state.vi), model, spl)
     end
 
     # Set the parameters to a starting value.
     initialize_parameters!(spl; kwargs...)
 
-    model(spl.state.vi, SampleFromUniform())
-    link!(spl.state.vi, spl)
-    l, dl = _lp(spl.state.vi[spl])
-    while !isfinite(l) || !isfinite(dl)
-        model(spl.state.vi, SampleFromUniform())
-        link!(spl.state.vi, spl)
-        l, dl = _lp(spl.state.vi[spl])
-    end
-
-    if spl.selector.tag == :default && !islinked(spl.state.vi, spl)
-        link!(spl.state.vi, spl)
-        model(spl.state.vi, spl)
-    end
+    link!(spl.state.vi, spl, model)
 
     results = mcmc_with_warmup(
         rng,
@@ -99,7 +86,8 @@ function AbstractMCMC.step!(
 )
     # Pop the next draw off the vector.
     draw = popfirst!(spl.state.draws)
-    spl.state.vi[spl] = draw
+    link(spl.state.vi)[spl] = draw
+    invlink!(spl.state.vi, spl, model)
     return Transition(spl)
 end
 
@@ -109,7 +97,7 @@ function Sampler(
     s::Selector=Selector()
 )
     # Construct a state, using a default function.
-    state = DynamicNUTSState(VarInfo(model), [])
+    state = DynamicNUTSState(DynamicPPL.TypedVarInfo(model), [])
 
     # Return a new sampler.
     return Sampler(alg, Dict{Symbol,Any}(), s, state)
@@ -118,7 +106,7 @@ end
 # Disable the progress logging for DynamicHMC, since it has its own progress meter.
 function AbstractMCMC.sample(
     rng::AbstractRNG,
-    model::AbstractModel,
+    model::DynamicPPL.AbstractModel,
     alg::DynamicNUTS,
     N::Integer;
     chain_type=MCMCChains.Chains,
@@ -127,7 +115,7 @@
     kwargs...
 )
     if progress
-        @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
+        @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
     if resume_from === nothing
         return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
@@ -139,7 +127,7 @@ end
 
 function AbstractMCMC.sample(
     rng::AbstractRNG,
-    model::AbstractModel,
+    model::DynamicPPL.AbstractModel,
     alg::DynamicNUTS,
     parallel::AbstractMCMC.AbstractMCMCParallel,
     N::Integer,
@@ -149,7 +137,7 @@ function AbstractMCMC.sample(
     kwargs...
 )
     if progress
-        @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
+        @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
     return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
         chain_type=chain_type, progress=false, kwargs...)
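
The changes above replace hand-rolled `link!`/`model(...)` re-linking with single `link!(vi, spl, model)` / `invlink!(vi, spl, model)` calls and read unconstrained values through a `link(spl.state.vi)` view. The underlying transform is the standard constrained-to-unconstrained map; a self-contained sketch using Bijectors.jl directly (generic usage, not this PR's internals):

    using Bijectors, Distributions

    # Gradient-based samplers work on all of R^n, so draws from a constrained
    # support are "linked" to unconstrained space and back.
    d = InverseGamma(2, 3)
    x = rand(d)                   # constrained: x > 0
    y = Bijectors.link(d, x)      # unconstrained (a log transform for this d)
    x2 = Bijectors.invlink(d, y)  # back on the original support; x2 ≈ x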
11 changes: 3 additions & 8 deletions src/contrib/inference/sghmc.jl
@@ -61,13 +61,9 @@ function step(
     is_first::Val{true};
     kwargs...
 )
-    spl.selector.tag != :default && link!(vi, spl)
-
     # Initialize velocity
     v = zeros(Float64, size(vi[spl]))
     spl.info[:v] = v
 
-    spl.selector.tag != :default && invlink!(vi, spl)
-
     return vi, true
 end
 
@@ -84,13 +80,12 @@ function step(
 
     Turing.DEBUG && @debug "X-> R..."
     if spl.selector.tag != :default
-        link!(vi, spl)
-        model(vi, spl)
+        model(initlink(vi), spl)
     end
 
     Turing.DEBUG && @debug "recording old variables..."
     θ, v = vi[spl], spl.info[:v]
-    _, grad = gradient_logp(θ, vi, model, spl)
+    _, grad = gradient_logp(θ, link(vi), model, spl)
     verifygrad(grad)
 
     # Implements the update equations from (15) of Chen et al. (2014).
@@ -197,7 +192,7 @@ function step(
 
     Turing.DEBUG && @debug "X-> R..."
     if spl.selector.tag != :default
-        link!(vi, spl)
+        link!(vi, spl, model)
         model(vi, spl)
     end
10 changes: 5 additions & 5 deletions src/core/ad.jl
@@ -60,7 +60,7 @@ getADbackend(spl::Sampler) = getADbackend(spl.alg)
 """
     gradient_logp(
         θ::AbstractVector{<:Real},
-        vi::VarInfo,
+        vi::AbstractVarInfo,
         model::Model,
         sampler::AbstractSampler=SampleFromPrior(),
     )
@@ -71,7 +71,7 @@ tool is currently active.
 """
 function gradient_logp(
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::Sampler
 )
@@ -82,7 +82,7 @@ end
     gradient_logp(
         backend::ADBackend,
         θ::AbstractVector{<:Real},
-        vi::VarInfo,
+        vi::AbstractVarInfo,
         model::Model,
         sampler::AbstractSampler = SampleFromPrior(),
     )
@@ -93,7 +93,7 @@ specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{
 function gradient_logp(
     ::ForwardDiffAD,
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler=SampleFromPrior(),
 )
@@ -120,7 +120,7 @@ end
 function gradient_logp(
     ::TrackerAD,
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
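
These signature changes only widen `vi` from the concrete `VarInfo` to `AbstractVarInfo` so that the delayed-specialization variants are accepted; the computation itself is unchanged. For orientation, `gradient_logp` returns the log joint and its gradient at `θ`; a generic stand-in with a toy log density and plain ForwardDiff (not Turing's internals):

    using ForwardDiff

    logjoint(θ) = -sum(abs2, θ) / 2          # stand-in for a model's log joint
    θ = randn(3)
    l = logjoint(θ)                          # value of the log joint at θ
    g = ForwardDiff.gradient(logjoint, θ)    # its gradient, as ForwardDiffAD computes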
4 changes: 2 additions & 2 deletions src/core/compat/reversediff.jl
@@ -17,7 +17,7 @@ end
 function gradient_logp(
     backend::ReverseDiffAD{false},
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
@@ -54,7 +54,7 @@ end
 function gradient_logp(
     backend::ReverseDiffAD{true},
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
2 changes: 1 addition & 1 deletion src/core/compat/zygote.jl
@@ -7,7 +7,7 @@ end
 function gradient_logp(
     backend::ZygoteAD,
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
63 changes: 46 additions & 17 deletions src/inference/AdvancedSMC.jl
@@ -22,6 +22,27 @@ struct ParticleTransition{T, F<:AbstractFloat}
     weight::F
 end
 
+function Base.promote_type(
+    ::Type{ParticleTransition{T1, F1}},
+    ::Type{ParticleTransition{T2, F2}},
+) where {T1, F1, T2, F2}
+    return ParticleTransition{
+        Union{T1, T2},
+        promote_type(F1, F2),
+    }
+end
+function Base.convert(
+    ::Type{ParticleTransition{T, F}},
+    t::ParticleTransition,
+) where {T, F}
+    return ParticleTransition{T, F}(
+        convert(T, t.θ),
+        convert(F, t.lp),
+        convert(F, t.le),
+        convert(F, t.weight),
+    )
+end
+
 function additional_parameters(::Type{<:ParticleTransition})
     return [:lp,:le, :weight]
 end
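
These two methods presumably exist so that a vector of transitions can widen when consecutive iterations produce different `θ` types (untyped vs. specialized `VarInfo`): BangBang's `push!!` and Base's `vcat` pick the widened element type via `promote_type` and materialize it via `convert`. A hedged sketch of the effect:

    using BangBang

    ts = [ParticleTransition((m = 0.0,), -1.0, 0.0, 1.0)]
    # A transition with a different θ type widens the vector instead of
    # erroring; the element type becomes ParticleTransition{Union{...}, Float64}.
    ts = push!!(ts, ParticleTransition((m = 0.0, s = 1.0), -1.2, 0.0, 1.0))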
@@ -69,23 +90,23 @@ SMC(threshold::Real, space::Tuple = ()) = SMC(resample_systematic, threshold, sp
 SMC(space::Symbol...) = SMC(space)
 SMC(space::Tuple) = SMC(Turing.Core.ResampleWithESSThreshold(), space)
 
-mutable struct SMCState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
+mutable struct SMCState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState
     vi                  :: V
     # The logevidence after aggregating all samples together.
     average_logevidence :: F
     particles           :: ParticleContainer
 end
 
-function SMCState(model::Model)
-    vi = VarInfo(model)
+function SMCState(model::Model; specialize_after=1)
+    vi = VarInfo(model, specialize_after)
     particles = ParticleContainer(Trace[])
 
     return SMCState(vi, 0.0, particles)
 end
 
-function Sampler(alg::SMC, model::Model, s::Selector)
+function Sampler(alg::SMC, model::Model, s::Selector; specialize_after=1)
     dict = Dict{Symbol, Any}()
-    state = SMCState(model)
+    state = SMCState(model; specialize_after=specialize_after)
     return Sampler(alg, dict, s, state)
 end
 
@@ -203,27 +224,31 @@ function PG(nparticles::Int, space::Tuple)
     return PG(nparticles, Turing.Core.ResampleWithESSThreshold(), space)
 end
 
-mutable struct PGState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
+mutable struct PGState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState
    vi                  :: V
    # The logevidence after aggregating all samples together.
    average_logevidence :: F
 end
 
-function PGState(model::Model)
-    vi = VarInfo(model)
+function PGState(model::Model; specialize_after=1)
+    vi = VarInfo(model, specialize_after)
     return PGState(vi, 0.0)
 end
 
+function replace_varinfo(s::PGState, vi::AbstractVarInfo)
+    return PGState(vi, s.average_logevidence)
+end
+
 const CSMC = PG # type alias of PG as Conditional SMC
 
 """
     Sampler(alg::PG, model::Model, s::Selector)
 
 Return a `Sampler` object for the PG algorithm.
 """
-function Sampler(alg::PG, model::Model, s::Selector)
+function Sampler(alg::PG, model::Model, s::Selector; specialize_after=1)
     info = Dict{Symbol, Any}()
-    state = PGState(model)
+    state = PGState(model; specialize_after=specialize_after)
     return Sampler(alg, info, s, state)
 end

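`replace_varinfo` rebuilds a `PGState` around a new `VarInfo` while keeping the accumulated log-evidence, presumably so the specialized `VarInfo` can be swapped in once `specialize_after` evaluations have run. A hypothetical call:

    # Hypothetical: replace the untyped VarInfo with its typed counterpart.
    new_state = replace_varinfo(spl.state, DynamicPPL.TypedVarInfo(spl.state.vi))
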
@@ -319,23 +344,27 @@ function DynamicPPL.assume(
             r = rand(dist)
             push!(vi, vn, r, dist, spl)
         elseif is_flagged(vi, vn, "del")
-            unset_flag!(vi, vn, "del")
             r = rand(dist)
-            vi[vn] = vectorize(dist, r)
-            setgid!(vi, spl.selector, vn)
             setorder!(vi, vn, get_num_produce(vi))
+            if length(vi[vn, dist]) == length(r)
+                vi[vn, dist] = r
+                unset_flag!(vi, vn, "del")
+                updategid!(vi, vn, spl)
+            else
+                DynamicPPL.removedel!(vi)
+                push!(vi, vn, r, dist, spl)
+            end
         else
             updategid!(vi, vn, spl)
-            r = vi[vn]
+            r = vi[vn, dist]
         end
     else # vn belongs to other sampler <=> conditioning on vn
         if haskey(vi, vn)
-            r = vi[vn]
+            r = vi[vn, dist]
         else
            r = rand(dist)
            push!(vi, vn, r, dist, Selector(:invalid))
        end
-        lp = logpdf_with_trans(dist, r, istrans(vi, vn))
+        lp = logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))
         acclogp!(vi, lp)
     end
     return r, 0