diff --git a/Project.toml b/Project.toml
index 716609900c..eca8770db9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.12.0"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
+BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
diff --git a/src/contrib/inference/AdvancedSMCExtensions.jl b/src/contrib/inference/AdvancedSMCExtensions.jl
index 7adc20eacd..8516d99d5c 100644
--- a/src/contrib/inference/AdvancedSMCExtensions.jl
+++ b/src/contrib/inference/AdvancedSMCExtensions.jl
@@ -39,7 +39,7 @@ end
 
 PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), ())
 
-function Sampler(alg::PMMH, model::Model, s::Selector)
+function Sampler(alg::PMMH, model::Model, s::Selector; specialize_after=1)
     info = Dict{Symbol, Any}()
     spl = Sampler(alg, info, s)
@@ -118,7 +118,8 @@ function sample( model::Model,
                  alg::PMMH;
                  save_state=false,         # flag for state saving
                  resume_from=nothing,      # chain to continue
-                 reuse_spl_n=0             # flag for spl re-using
+                 reuse_spl_n=0,            # flag for spl re-using
+                 specialize_after=1
                )
 
     spl = Sampler(alg, model)
@@ -140,7 +141,7 @@ function sample( model::Model,
 
     # Init parameters
     vi = if resume_from === nothing
-        vi_ = VarInfo(model)
+        vi_ = VarInfo(model, specialize_after)
     else
         resume_from.info[:vi]
     end
@@ -279,7 +280,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first:
     VarInfos[nodes_permutation]
 end
 
-function sample(model::Model, alg::IPMCMC)
+function sample(model::Model, alg::IPMCMC; specialize_after=1)
 
     spl = Sampler(alg)
@@ -295,7 +296,7 @@ function sample(model::Model, alg::IPMCMC)
     end
 
     # Init parameters
-    vi = empty!(VarInfo(model))
+    vi = empty!(VarInfo(model, specialize_after))
     VarInfos = Array{VarInfo}(undef, spl.alg.n_nodes)
     for j in 1:spl.alg.n_nodes
         VarInfos[j] = deepcopy(vi)
diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl
index 17d1221d9e..0cdac493b5 100644
--- a/src/contrib/inference/dynamichmc.jl
+++ b/src/contrib/inference/dynamichmc.jl
@@ -41,7 +41,7 @@ function DynamicNUTS{AD}(space::Symbol...) where AD
     DynamicNUTS{AD, space}()
 end
 
-mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState
+mutable struct DynamicNUTSState{V<:AbstractVarInfo, D} <: AbstractSamplerState
     vi::V
     draws::Vector{D}
 end
@@ -53,29 +53,16 @@ function AbstractMCMC.sample_init!(
     model::Model,
     spl::Sampler{<:DynamicNUTS},
     N::Integer;
-    kwargs...
+    kwargs...,
 )
     # Set up lp function.
     function _lp(x)
-        gradient_logp(x, spl.state.vi, model, spl)
+        gradient_logp(x, link(spl.state.vi), model, spl)
     end
 
     # Set the parameters to a starting value.
    initialize_parameters!(spl; kwargs...)
-
-    model(spl.state.vi, SampleFromUniform())
-    link!(spl.state.vi, spl)
-    l, dl = _lp(spl.state.vi[spl])
-    while !isfinite(l) || !isfinite(dl)
-        model(spl.state.vi, SampleFromUniform())
-        link!(spl.state.vi, spl)
-        l, dl = _lp(spl.state.vi[spl])
-    end
-
-    if spl.selector.tag == :default && !islinked(spl.state.vi, spl)
-        link!(spl.state.vi, spl)
-        model(spl.state.vi, spl)
-    end
+    link!(spl.state.vi, spl, model)
 
     results = mcmc_with_warmup(
         rng,
@@ -99,7 +86,8 @@ function AbstractMCMC.step!(
 )
     # Pop the next draw off the vector.
     draw = popfirst!(spl.state.draws)
-    spl.state.vi[spl] = draw
+    link(spl.state.vi)[spl] = draw
+    invlink!(spl.state.vi, spl, model)
 
     return Transition(spl)
 end
@@ -109,7 +97,7 @@ function Sampler(
     s::Selector=Selector()
 )
     # Construct a state, using a default function.
-    state = DynamicNUTSState(VarInfo(model), [])
+    state = DynamicNUTSState(DynamicPPL.TypedVarInfo(model), [])
 
     # Return a new sampler.
     return Sampler(alg, Dict{Symbol,Any}(), s, state)
@@ -118,7 +106,7 @@ end
 # Disable the progress logging for DynamicHMC, since it has its own progress meter.
 function AbstractMCMC.sample(
     rng::AbstractRNG,
-    model::AbstractModel,
+    model::DynamicPPL.AbstractModel,
     alg::DynamicNUTS,
     N::Integer;
     chain_type=MCMCChains.Chains,
@@ -127,7 +115,7 @@ end
     kwargs...
 )
     if progress
-        @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
+        @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
     if resume_from === nothing
         return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
@@ -139,7 +127,7 @@ end
 
 function AbstractMCMC.sample(
     rng::AbstractRNG,
-    model::AbstractModel,
+    model::DynamicPPL.AbstractModel,
     alg::DynamicNUTS,
     parallel::AbstractMCMC.AbstractMCMCParallel,
     N::Integer,
@@ -149,7 +137,7 @@ function AbstractMCMC.sample(
     kwargs...
 )
     if progress
-        @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
+        @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
     return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
                                chain_type=chain_type, progress=false, kwargs...)
diff --git a/src/contrib/inference/sghmc.jl b/src/contrib/inference/sghmc.jl
index 83c4886132..b1f9dda050 100644
--- a/src/contrib/inference/sghmc.jl
+++ b/src/contrib/inference/sghmc.jl
@@ -61,13 +61,9 @@ function step(
     is_first::Val{true};
     kwargs...
 )
-    spl.selector.tag != :default && link!(vi, spl)
-
     # Initialize velocity
     v = zeros(Float64, size(vi[spl]))
     spl.info[:v] = v
-
-    spl.selector.tag != :default && invlink!(vi, spl)
 
     return vi, true
 end
@@ -84,13 +80,12 @@ function step(
     Turing.DEBUG && @debug "X-> R..."
     if spl.selector.tag != :default
-        link!(vi, spl)
-        model(vi, spl)
+        model(initlink(vi), spl)
     end
 
     Turing.DEBUG && @debug "recording old variables..."
     θ, v = vi[spl], spl.info[:v]
-    _, grad = gradient_logp(θ, vi, model, spl)
+    _, grad = gradient_logp(θ, link(vi), model, spl)
     verifygrad(grad)
 
     # Implements the update equations from (15) of Chen et al. (2014).
@@ -197,7 +192,7 @@ function step(
     Turing.DEBUG && @debug "X-> R..."
     if spl.selector.tag != :default
-        link!(vi, spl)
+        link!(vi, spl, model)
         model(vi, spl)
     end
diff --git a/src/core/ad.jl b/src/core/ad.jl
index b896fccdda..7d86f76afa 100644
--- a/src/core/ad.jl
+++ b/src/core/ad.jl
@@ -60,7 +60,7 @@ getADbackend(spl::Sampler) = getADbackend(spl.alg)
 """
     gradient_logp(
         θ::AbstractVector{<:Real},
-        vi::VarInfo,
+        vi::AbstractVarInfo,
         model::Model,
         sampler::AbstractSampler=SampleFromPrior(),
     )
@@ -71,7 +71,7 @@ tool is currently active.
 """
 function gradient_logp(
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::Sampler
 )
@@ -82,7 +82,7 @@ end
     gradient_logp(
         backend::ADBackend,
         θ::AbstractVector{<:Real},
-        vi::VarInfo,
+        vi::AbstractVarInfo,
         model::Model,
         sampler::AbstractSampler = SampleFromPrior(),
     )
@@ -93,7 +93,7 @@ specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{
 function gradient_logp(
     ::ForwardDiffAD,
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler=SampleFromPrior(),
 )
@@ -120,7 +120,7 @@ end
 function gradient_logp(
     ::TrackerAD,
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
diff --git a/src/core/compat/reversediff.jl b/src/core/compat/reversediff.jl
index 5ccfbbebd5..f3822d35ee 100644
--- a/src/core/compat/reversediff.jl
+++ b/src/core/compat/reversediff.jl
@@ -17,7 +17,7 @@ end
 function gradient_logp(
     backend::ReverseDiffAD{false},
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
@@ -54,7 +54,7 @@ end
 function gradient_logp(
     backend::ReverseDiffAD{true},
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
diff --git a/src/core/compat/zygote.jl b/src/core/compat/zygote.jl
index 3c56a1922c..dc18fa0f8f 100644
--- a/src/core/compat/zygote.jl
+++ b/src/core/compat/zygote.jl
@@ -7,7 +7,7 @@ end
 function gradient_logp(
     backend::ZygoteAD,
     θ::AbstractVector{<:Real},
-    vi::VarInfo,
+    vi::AbstractVarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
 )
diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl
index 77a4c70909..ff4c2ee1bf 100644
--- a/src/inference/AdvancedSMC.jl
+++ b/src/inference/AdvancedSMC.jl
@@ -22,6 +22,27 @@ struct ParticleTransition{T, F<:AbstractFloat}
     weight::F
 end
 
+function Base.promote_type(
+    ::Type{ParticleTransition{T1, F1}},
+    ::Type{ParticleTransition{T2, F2}},
+) where {T1, F1, T2, F2}
+    return ParticleTransition{
+        Union{T1, T2},
+        promote_type(F1, F2),
+    }
+end
+function Base.convert(
+    ::Type{ParticleTransition{T, F}},
+    t::ParticleTransition,
+) where {T, F}
+    return ParticleTransition{T, F}(
+        convert(T, t.θ),
+        convert(F, t.lp),
+        convert(F, t.le),
+        convert(F, t.weight),
+    )
+end
+
 function additional_parameters(::Type{<:ParticleTransition})
     return [:lp,:le, :weight]
 end
@@ -69,23 +90,23 @@ SMC(threshold::Real, space::Tuple = ()) = SMC(resample_systematic, threshold, sp
 SMC(space::Symbol...) = SMC(space)
 SMC(space::Tuple) = SMC(Turing.Core.ResampleWithESSThreshold(), space)
 
-mutable struct SMCState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
+mutable struct SMCState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState
     vi                   ::  V
     # The logevidence after aggregating all samples together.
     average_logevidence  ::  F
     particles            ::  ParticleContainer
 end
 
-function SMCState(model::Model)
-    vi = VarInfo(model)
+function SMCState(model::Model; specialize_after=1)
+    vi = VarInfo(model, specialize_after)
     particles = ParticleContainer(Trace[])
 
     return SMCState(vi, 0.0, particles)
 end
 
-function Sampler(alg::SMC, model::Model, s::Selector)
+function Sampler(alg::SMC, model::Model, s::Selector; specialize_after=1)
     dict = Dict{Symbol, Any}()
-    state = SMCState(model)
+    state = SMCState(model; specialize_after=specialize_after)
     return Sampler(alg, dict, s, state)
 end
@@ -203,17 +224,21 @@ function PG(nparticles::Int, space::Tuple)
     return PG(nparticles, Turing.Core.ResampleWithESSThreshold(), space)
 end
 
-mutable struct PGState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
+mutable struct PGState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState
     vi                   ::  V
    # The logevidence after aggregating all samples together.
     average_logevidence  ::  F
 end
 
-function PGState(model::Model)
-    vi = VarInfo(model)
+function PGState(model::Model; specialize_after=1)
+    vi = VarInfo(model, specialize_after)
     return PGState(vi, 0.0)
 end
 
+function replace_varinfo(s::PGState, vi::AbstractVarInfo)
+    return PGState(vi, s.average_logevidence)
+end
+
 const CSMC = PG # type alias of PG as Conditional SMC
 
 """
@@ -221,9 +246,9 @@ const CSMC = PG # type alias of PG as Conditional SMC
 
 Return a `Sampler` object for the PG algorithm.
 """
-function Sampler(alg::PG, model::Model, s::Selector)
+function Sampler(alg::PG, model::Model, s::Selector; specialize_after=1)
     info = Dict{Symbol, Any}()
-    state = PGState(model)
+    state = PGState(model; specialize_after=specialize_after)
     return Sampler(alg, info, s, state)
 end
@@ -319,23 +344,27 @@ function DynamicPPL.assume(
             r = rand(dist)
             push!(vi, vn, r, dist, spl)
         elseif is_flagged(vi, vn, "del")
-            unset_flag!(vi, vn, "del")
             r = rand(dist)
-            vi[vn] = vectorize(dist, r)
-            setgid!(vi, spl.selector, vn)
-            setorder!(vi, vn, get_num_produce(vi))
+            if length(vi[vn, dist]) == length(r)
+                vi[vn, dist] = r
+                unset_flag!(vi, vn, "del")
+                updategid!(vi, vn, spl)
+            else
+                DynamicPPL.removedel!(vi)
+                push!(vi, vn, r, dist, spl)
+            end
         else
             updategid!(vi, vn, spl)
-            r = vi[vn]
+            r = vi[vn, dist]
         end
     else # vn belongs to other sampler <=> conditionning on vn
         if haskey(vi, vn)
-            r = vi[vn]
+            r = vi[vn, dist]
        else
             r = rand(dist)
             push!(vi, vn, r, dist, Selector(:invalid))
         end
-        lp = logpdf_with_trans(dist, r, istrans(vi, vn))
+        lp = logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))
         acclogp!(vi, lp)
     end
     return r, 0
diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 9510d96857..3c7a767802 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -3,9 +3,9 @@ module Inference
 using ..Core
 using ..Core: logZ
 using ..Utilities
-using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo,
-    islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
-    settrans!, _getvns, getdist, CACHERESET, AbstractSampler,
+using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, set_namedtuple!,
+    islinked_and_trans, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
+    settrans!, getvns, getinitdist, CACHERESET, AbstractSampler,
     Model, Sampler, SampleFromPrior, SampleFromUniform,
     Selector, AbstractSamplerState, DefaultContext, PriorContext,
     LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist,
@@ -20,13 +20,14 @@ using DynamicPPL
 using AbstractMCMC: AbstractModel, AbstractSampler
 using Bijectors: _debug
 using DocStringExtensions: TYPEDEF, TYPEDFIELDS
+import BangBang
 import AbstractMCMC
 import AdvancedHMC; const AHMC = AdvancedHMC
 import AdvancedMH; const AMH = AdvancedMH
 import ..Core: getchunksize, getADbackend
 import DynamicPPL: get_matching_type,
-    VarName, _getranges, _getindex, getval, _getvns
+    VarName, getval, getvns
 import EllipticalSliceSampling
 import Random
 import MCMCChains
@@ -74,6 +75,12 @@ getADbackend(::Hamiltonian{AD}) where AD = AD()
 # Algorithm for sampling from the prior
 struct Prior <: InferenceAlgorithm end
 
+getindex(vi::MixedVarInfo, spl::Sampler{<:Hamiltonian}) = vi.tvi[spl]
+function setindex!(vi::MixedVarInfo, val, spl::Sampler{<:Hamiltonian})
+    vi.tvi[spl] = val
+    return vi
+end
+
 """
     mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)
@@ -101,6 +108,22 @@ struct Transition{T, F<:AbstractFloat}
     lp :: F
 end
 
+function Base.promote_type(
+    ::Type{Transition{T1, F1}},
+    ::Type{Transition{T2, F2}},
+) where {T1, F1, T2, F2}
+    return Transition{
+        Union{T1, T2},
+        promote_type(F1, F2),
+    }
+end
+function Base.convert(
+    ::Type{Transition{T, F}},
+    t::Transition,
+) where {T, F}
+    return Transition{T, F}(convert(T, t.θ), convert(F, t.lp))
+end
+
 function Transition(spl::Sampler, nt::NamedTuple=NamedTuple())
     theta = merge(tonamedtuple(spl.state.vi), nt)
     lp = getlogp(spl.state.vi)
@@ -156,9 +179,10 @@ function AbstractMCMC.sample(
     model::AbstractModel,
     alg::InferenceAlgorithm,
     N::Integer;
+    specialize_after=1,
     kwargs...
 )
-    return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
+    return AbstractMCMC.sample(rng, model, Sampler(alg, model; specialize_after=specialize_after), N; kwargs...)
 end
@@ -216,9 +240,10 @@ function AbstractMCMC.sample(
     parallel::AbstractMCMC.AbstractMCMCParallel,
     N::Integer,
     n_chains::Integer;
+    specialize_after=1,
     kwargs...
 )
-    return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
+    return AbstractMCMC.sample(rng, model, Sampler(alg, model; specialize_after=specialize_after), parallel, N, n_chains;
                                kwargs...)
 end
@@ -246,10 +271,21 @@ function AbstractMCMC.sample(
     n_chains::Integer;
     chain_type=MCMCChains.Chains,
     progress=PROGRESS[],
+    specialize_after=1,
     kwargs...
 )
-    return AbstractMCMC.sample(rng, model, SampleFromPrior(), parallel, N, n_chains;
-                               chain_type=chain_type, progress=progress, kwargs...)
+    vi = VarInfo(model, specialize_after)
+    return AbstractMCMC.sample(
+        rng,
+        model,
+        SampleFromPrior(vi),
+        parallel,
+        N,
+        n_chains;
+        chain_type=chain_type,
+        progress=progress,
+        kwargs...,
+    )
 end
 
 function AbstractMCMC.sample_init!(
@@ -272,7 +308,6 @@ function initialize_parameters!(
     verbose::Bool=false,
     kwargs...
 )
-    islinked(spl.state.vi, spl) && invlink!(spl.state.vi, spl)
     # Get `init_theta`
     if init_theta !== nothing
         verbose && @info "Using passed-in initial variable values" init_theta
@@ -329,7 +364,8 @@ function flatten_namedtuple(nt::NamedTuple)
         if length(v) == 1
             return [(string(k), v)]
         else
-            return mapreduce(vcat, zip(v[1], v[2])) do (vnval, vn)
+            init = Tuple{String, eltype(v[1])}[]
+            return mapreduce(vcat, zip(v[1], v[2]); init = init) do (vnval, vn)
                 return collect(FlattenIterator(vn, vnval))
             end
         end
@@ -520,9 +556,12 @@ end
 """
 A blank `AbstractSamplerState` that contains only `VarInfo` information.
 """
-mutable struct SamplerState{VIType<:VarInfo} <: AbstractSamplerState
+mutable struct SamplerState{VIType<:AbstractVarInfo} <: AbstractSamplerState
     vi :: VIType
 end
+function replace_varinfo(::SamplerState, vi::AbstractVarInfo)
+    return SamplerState(vi)
+end
 
 #######################################
 # Concrete algorithm implementations. #
@@ -548,6 +587,7 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
 end
 
 floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T))
+floatof(::Type{Real}) = Real
 floatof(::Type) = Real # fallback if type inference failed
 
 function get_matching_type(
@@ -578,6 +618,7 @@ function get_matching_type(
 ) where {T, N, TV <: Array{T, N}}
     return Array{get_matching_type(spl, vi, T), N}
 end
+#=
 function get_matching_type(
     spl::Sampler{<:Union{PG, SMC}},
     vi,
@@ -585,6 +626,7 @@ function get_matching_type(
 ) where {T, N, TV <: Array{T, N}}
     return TArray{T, N}
 end
+=#
 
 ##############
 # Utilities  #
diff --git a/src/inference/ess.jl b/src/inference/ess.jl
index 27a3b3f541..238a9f5b79 100644
--- a/src/inference/ess.jl
+++ b/src/inference/ess.jl
@@ -25,19 +25,22 @@ struct ESS{space} <: InferenceAlgorithm end
 ESS() = ESS{()}()
 ESS(space::Symbol) = ESS{(space,)}()
 
-mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState
+mutable struct ESSState{V<:AbstractVarInfo} <: AbstractSamplerState
     vi::V
 end
+function replace_varinfo(::ESSState, vi::AbstractVarInfo)
+    return ESSState(vi)
+end
 
-function Sampler(alg::ESS, model::Model, s::Selector)
+function Sampler(alg::ESS, model::Model, s::Selector; specialize_after=1)
     # sanity check
-    vi = VarInfo(model)
+    vi = VarInfo(model, specialize_after)
     space = getspace(alg)
-    vns = _getvns(vi, s, Val(space))
+    vns = getvns(vi, s, Val(space))
     length(vns) == 1 ||
         error("[ESS] does only support one variable ($(length(vns)) variables specified)")
     for vn in vns[1]
-        dist = getdist(vi, vn)
+        dist = getinitdist(vi, vn)
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("[ESS] only supports Gaussian prior distributions")
     end
@@ -102,9 +105,9 @@ end
 
 function ESSModel(model::Model, spl::Sampler{<:ESS})
     vi = spl.state.vi
-    vns = _getvns(vi, spl)
+    vns = getvns(vi, spl)
     μ = mapreduce(vcat, vns[1]) do vn
-        dist = getdist(vi, vn)
+        dist = getinitdist(vi, vn)
         vectorize(dist, mean(dist))
     end
@@ -115,7 +118,7 @@ end
 function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel)
     spl = model.spl
     vi = spl.state.vi
-    vns = _getvns(vi, spl)
+    vns = getvns(vi, spl)
     set_flag!(vi, vns[1][1], "del")
     model.model(vi, spl)
     return vi[spl]
diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl
index 4b82cb934a..46db2cf078 100644
--- a/src/inference/gibbs.jl
+++ b/src/inference/gibbs.jl
@@ -42,21 +42,25 @@ function Gibbs(algs::GibbsComponent...)
 end
 
 """
-    GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}}
+    GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}}
 
 Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that
 the `Gibbs` sampler iterates through for each `step!`.
 """
-mutable struct GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState
+mutable struct GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState
     vi::V
     samplers::S
 end
 
-function GibbsState(model::Model, samplers::Tuple{Vararg{Sampler}})
-    return GibbsState(VarInfo(model), samplers)
+function GibbsState(model::Model, samplers::Tuple{Vararg{Sampler}}; specialize_after=1)
+    return GibbsState(VarInfo(model, specialize_after), samplers)
 end
 
-function Sampler(alg::Gibbs, model::Model, s::Selector)
+function replace_varinfo(s::GibbsState, vi::AbstractVarInfo)
+    return GibbsState(vi, s.samplers)
+end
+
+function Sampler(alg::Gibbs, model::Model, s::Selector; specialize_after=1)
     # sanity check for space
     space = getspace(alg)
     # create tuple of samplers
@@ -70,30 +74,26 @@ function Sampler(alg::Gibbs, model::Model, s::Selector)
         end
         rerun = !(_alg isa MH) || prev_alg isa PG || prev_alg isa ESS
         selector = Selector(Symbol(typeof(_alg)), rerun)
-        Sampler(_alg, model, selector)
+        Sampler(_alg, model, selector; specialize_after=specialize_after)
+    end
+    varinfo = samplers[1].state.vi
+    samplers = map(samplers) do sampler
+        Sampler(
+            sampler.alg,
+            sampler.info,
+            sampler.selector,
+            replace_varinfo(sampler.state, varinfo),
+        )
     end
 
     # create a state variable
-    state = GibbsState(model, samplers)
+    state = GibbsState(varinfo, samplers)
 
     # create the sampler
     info = Dict{Symbol, Any}()
     spl = Sampler(alg, info, s, state)
 
     # add Gibbs to gids for all variables
-    vi = spl.state.vi
-    for sym in keys(vi.metadata)
-        vns = getfield(vi.metadata, sym).vns
-
-        for vn in vns
-            # update the gid for the Gibbs sampler
-            DynamicPPL.updategid!(vi, vn, spl)
-
-            # try to store each subsampler's gid in the VarInfo
-            for local_spl in samplers
-                DynamicPPL.updategid!(vi, vn, local_spl)
-            end
-        end
-    end
+    DynamicPPL.updategid!(varinfo, (spl, samplers...))
 
     return spl
 end
@@ -112,6 +112,27 @@ struct GibbsTransition{T,F,S<:AbstractVector}
     transitions::S
 end
 
+function Base.promote_type(
+    ::Type{GibbsTransition{T1, F1, S1}},
+    ::Type{GibbsTransition{T2, F2, S2}},
+) where {T1, F1, S1, T2, F2, S2}
+    return GibbsTransition{
+        Union{T1, T2},
+        promote_type(F1, F2),
+        promote_type(S1, S2),
+    }
+end
+function Base.convert(
+    ::Type{GibbsTransition{T, F, S}},
+    t::GibbsTransition,
+) where {T, F, S}
+    return GibbsTransition{T, F, S}(
+        convert(T, t.θ),
+        convert(F, t.lp),
+        convert(S, t.transitions),
+    )
+end
+
 function GibbsTransition(spl::Sampler{<:Gibbs}, transitions::AbstractVector)
     theta = tonamedtuple(spl.state.vi)
     lp = getlogp(spl.state.vi)
@@ -188,25 +209,26 @@ function AbstractMCMC.step!(
 end
 
 # Do not store transitions of subsamplers
-function AbstractMCMC.transitions_init(
+function AbstractMCMC.transitions(
     transition::GibbsTransition,
     ::Model,
     ::Sampler{<:Gibbs},
     N::Integer;
     kwargs...
 )
-    return Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, N)
+    ts = Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, 0)
+    sizehint!(ts, N)
+    return ts
 end
 
-function AbstractMCMC.transitions_save!(
+function AbstractMCMC.save!!(
     transitions::Vector{<:Transition},
-    iteration::Integer,
     transition::GibbsTransition,
+    iteration::Integer,
     ::Model,
     ::Sampler{<:Gibbs},
     ::Integer;
     kwargs...
 )
-    transitions[iteration] = Transition(transition.θ, transition.lp)
-    return
+    return BangBang.push!!(transitions, Transition(transition.θ, transition.lp))
 end
diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl
index 3b8b95c95e..1d827005a2 100644
--- a/src/inference/hmc.jl
+++ b/src/inference/hmc.jl
@@ -3,7 +3,7 @@
 ###
 
 mutable struct HMCState{
-    TV <: TypedVarInfo,
+    TV <: AbstractVarInfo,
     TTraj<:AHMC.AbstractTrajectory,
     TAdapt<:AHMC.Adaptation.AbstractAdaptor,
     PhType <: AHMC.PhasePoint
@@ -17,6 +17,18 @@ mutable struct HMCState{
     z       :: PhType
 end
 
+function replace_varinfo(s::HMCState, vi::AbstractVarInfo)
+    return HMCState(
+        vi,
+        s.eval_num,
+        s.i,
+        s.traj,
+        s.h,
+        s.adaptor,
+        s.z,
+    )
+end
+
 ##########################
 # Hamiltonian Transition #
 ##########################
@@ -27,6 +39,27 @@ struct HamiltonianTransition{T, NT<:NamedTuple, F<:AbstractFloat}
     stat :: NT
 end
 
+function Base.promote_type(
+    ::Type{HamiltonianTransition{T1, NT1, F1}},
+    ::Type{HamiltonianTransition{T2, NT2, F2}},
+) where {T1, NT1, F1, T2, NT2, F2}
+    return HamiltonianTransition{
+        Union{T1, T2},
+        promote_type(NT1, NT2),
+        promote_type(F1, F2),
+    }
+end
+function Base.convert(
+    ::Type{HamiltonianTransition{T, NT, F}},
+    t::HamiltonianTransition,
+) where {T, NT, F}
+    return HamiltonianTransition{T, NT, F}(
+        convert(T, t.θ),
+        convert(F, t.lp),
+        convert(NT, t.stat)
+    )
+end
+
 function HamiltonianTransition(spl::Sampler{<:Hamiltonian}, t::AHMC.Transition)
     theta = tonamedtuple(spl.state.vi)
     lp = getlogp(spl.state.vi)
@@ -102,8 +135,8 @@ end
 
 function update_hamiltonian!(spl, model, n)
     metric = gen_metric(n, spl)
-    ℓπ = gen_logπ(spl.state.vi, spl, model)
-    ∂ℓπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model)
+    ℓπ = gen_logπ(link(spl.state.vi), spl, model)
+    ∂ℓπ∂θ = gen_∂logπ∂θ(link(spl.state.vi), spl, model)
     spl.state.h = AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
     return spl
 end
@@ -125,24 +158,24 @@ function AbstractMCMC.sample_init!(
     initialize_parameters!(spl; verbose=verbose, kwargs...)
     if init_theta !== nothing
         # Doesn't support dynamic models
-        link!(spl.state.vi, spl)
-        model(spl.state.vi, spl)
-        theta = spl.state.vi[spl]
+        link!(spl.state.vi, spl, model)
+        model(link(spl.state.vi), spl)
+        theta = link(spl.state.vi)[spl]
         update_hamiltonian!(spl, model, length(theta))
         # Refresh the internal cache phase point z's hamiltonian energy.
         spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h)
-    else
+    elseif resume_from === nothing
         # Samples new values and sets trans to true, then computes the logp
         model(empty!(spl.state.vi), SampleFromUniform())
-        link!(spl.state.vi, spl)
-        theta = spl.state.vi[spl]
+        link!(spl.state.vi, spl, model)
+        theta = link(spl.state.vi)[spl]
         update_hamiltonian!(spl, model, length(theta))
         # Refresh the internal cache phase point z's hamiltonian energy.
         spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h)
         while !isfinite(spl.state.z.ℓπ.value) || !isfinite(spl.state.z.ℓπ.gradient)
             model(empty!(spl.state.vi), SampleFromUniform())
-            link!(spl.state.vi, spl)
-            theta = spl.state.vi[spl]
+            link!(spl.state.vi, spl, model)
+            theta = link(spl.state.vi)[spl]
             update_hamiltonian!(spl, model, length(theta))
             # Refresh the internal cache phase point z's hamiltonian energy.
             spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h)
@@ -166,19 +199,19 @@ function AbstractMCMC.sample_init!(
             spl.alg.n_adapts = 0
         end
     end
-    
+
     # Convert to transformed space if we're using
     # non-Gibbs sampling.
-    if !islinked(spl.state.vi, spl) && spl.selector.tag == :default
-        link!(spl.state.vi, spl)
-        model(spl.state.vi, spl)
-    elseif islinked(spl.state.vi, spl) && spl.selector.tag != :default
-        invlink!(spl.state.vi, spl)
+    if spl.selector.tag == :default
+        link!(spl.state.vi, spl, model)
+        model(link(spl.state.vi), spl)
+    else
+        invlink!(spl.state.vi, spl, model)
         model(spl.state.vi, spl)
     end
 end
 
-function AbstractMCMC.transitions_init(
+function AbstractMCMC.transitions(
     transition,
     ::AbstractModel,
     sampler::Sampler{<:Hamiltonian},
@@ -191,28 +224,26 @@ function AbstractMCMC.transitions_init(
     else
         n = N
     end
-    return Vector{typeof(transition)}(undef, n)
+    ts = Vector{typeof(transition)}(undef, 0)
+    sizehint!(ts, n)
+    return ts
 end
 
-function AbstractMCMC.transitions_save!(
+function AbstractMCMC.save!!(
     transitions::AbstractVector,
-    iteration::Integer,
     transition,
+    iteration::Integer,
     ::AbstractModel,
     sampler::Sampler{<:Hamiltonian},
     ::Integer;
     discard_adapt = true,
     kwargs...
 )
-    if discard_adapt && isdefined(sampler.alg, :n_adapts)
-        if iteration > sampler.alg.n_adapts
-            transitions[iteration - sampler.alg.n_adapts] = transition
-        end
-        return
+    if discard_adapt && isdefined(sampler.alg, :n_adapts) && iteration <= sampler.alg.n_adapts
+        return transitions
+    else
+        return BangBang.push!!(transitions, transition)
     end
-
-    transitions[iteration] = transition
-    return
 end
 
 """
@@ -371,11 +402,13 @@ end
 function Sampler(
     alg::Union{StaticHamiltonian, AdaptiveHamiltonian},
     model::Model,
-    s::Selector=Selector()
+    s::Selector=Selector();
+    specialize_after=1
 )
     info = Dict{Symbol, Any}()
     # Create an empty sampler state that just holds a typed VarInfo.
-    initial_state = SamplerState(VarInfo(model))
+    varinfo = getspace(alg) === () && specialize_after > 0 ? TypedVarInfo(model) : VarInfo(model, specialize_after)
+    initial_state = SamplerState(varinfo)
 
     # Create an initial sampler, to get all the initialization out of the way.
     initial_spl = Sampler(alg, info, s, initial_state)
@@ -411,20 +444,22 @@ function AbstractMCMC.step!(
     spl.state.eval_num = 0
 
     Turing.DEBUG && @debug "current ϵ: $ϵ"
+    updategid!(spl.state.vi, spl)
     # When a Gibbs component
     if spl.selector.tag != :default
         # Transform the space
         Turing.DEBUG && @debug "X-> R..."
-        link!(spl.state.vi, spl)
-        model(spl.state.vi, spl)
+        link!(spl.state.vi, spl, model)
+        model(link(spl.state.vi), spl)
     end
 
     # Get position and log density before transition
-    θ_old, log_density_old = spl.state.vi[spl], getlogp(spl.state.vi)
+    θ_old, θ_old_trans = spl.state.vi[spl], link(spl.state.vi)[spl]
+    log_density_old = getlogp(spl.state.vi)
     if spl.selector.tag != :default
-        update_hamiltonian!(spl, model, length(θ_old))
+        update_hamiltonian!(spl, model, length(θ_old_trans))
         resize!(spl.state.z.θ, length(θ_old))
-        spl.state.z.θ .= θ_old
+        spl.state.z.θ .= θ_old_trans
     end
 
     # Transition
@@ -443,17 +478,19 @@ function AbstractMCMC.step!(
 
     # Update `vi` based on acceptance
     if t.stat.is_accept
-        spl.state.vi[spl] = t.z.θ
+        link(spl.state.vi)[spl] = t.z.θ
+        invlink!(spl.state.vi, spl, model)
         setlogp!(spl.state.vi, t.stat.log_density)
     else
         spl.state.vi[spl] = θ_old
+        link(spl.state.vi)[spl] = θ_old_trans
         setlogp!(spl.state.vi, log_density_old)
+        DynamicPPL.setsynced!(spl.state.vi, true)
     end
 
     # Gibbs component specified cares
     # Transform the space back
     Turing.DEBUG && @debug "R -> X..."
-    spl.selector.tag != :default && invlink!(spl.state.vi, spl)
 
     return HamiltonianTransition(spl, t)
 end
@@ -514,13 +551,13 @@ function DynamicPPL.assume(
 )
     Turing.DEBUG && _debug("assuming...")
     updategid!(vi, vn, spl)
-    r = vi[vn]
-    # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
+    r = vi[vn, dist]
+    # acclogp!(vi, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)))
     # r
     Turing.DEBUG && _debug("dist = $dist")
     Turing.DEBUG && _debug("vn = $vn")
     Turing.DEBUG && _debug("r = $r, typeof(r)=$(typeof(r))")
-    return r, logpdf_with_trans(dist, r, istrans(vi, vn))
+    return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))
 end
 
 function DynamicPPL.dot_assume(
@@ -532,9 +569,9 @@ function DynamicPPL.dot_assume(
 )
     @assert length(dist) == size(var, 1)
     updategid!.(Ref(vi), vns, Ref(spl))
-    r = vi[vns]
+    r = vi[vns, dist]
     var .= r
-    return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
+    return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1])))
 end
 function DynamicPPL.dot_assume(
     spl::Sampler{<:Hamiltonian},
@@ -544,9 +581,9 @@ function DynamicPPL.dot_assume(
     vi,
 )
     updategid!.(Ref(vi), vns, Ref(spl))
-    r = reshape(vi[vec(vns)], size(var))
+    r = vi[vns, dists]
     var .= r
-    return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
+    return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1])))
 end
 
 function DynamicPPL.observe(
@@ -606,11 +643,11 @@ function HMCState(
     vi = spl.state.vi
 
     # Link everything if needed.
-    !islinked(vi, spl) && link!(vi, spl)
+    link!(vi, spl, model)
 
     # Get the initial log pdf and gradient functions.
-    ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
-    logπ = gen_logπ(vi, spl, model)
+    ∂logπ∂θ = gen_∂logπ∂θ(link(vi), spl, model)
+    logπ = gen_logπ(link(vi), spl, model)
 
     # Get the metric type.
     metricT = getmetricT(spl.alg)
@@ -635,7 +672,7 @@ function HMCState(
     h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.
 
     # Unlink everything.
-    invlink!(vi, spl)
+    invlink!(vi, spl, model)
 
     return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
 end
diff --git a/src/inference/is.jl b/src/inference/is.jl
index a7f515e364..dcd4cad5b6 100644
--- a/src/inference/is.jl
+++ b/src/inference/is.jl
@@ -31,7 +31,7 @@ struct IS{space} <: InferenceAlgorithm end
 
 IS() = IS{()}()
 
-mutable struct ISState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
+mutable struct ISState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState
     vi                 ::  V
     final_logevidence  ::  F
 end
diff --git a/src/inference/mh.jl b/src/inference/mh.jl
index 764c8959b0..993f90b95f 100644
--- a/src/inference/mh.jl
+++ b/src/inference/mh.jl
@@ -36,13 +36,14 @@ function Sampler(
     alg::MH,
     model::Model,
-    s::Selector=Selector()
+    s::Selector=Selector();
+    specialize_after=1,
 )
     # Set up info dict.
     info = Dict{Symbol, Any}()
 
     # Set up state struct.
-    state = SamplerState(VarInfo(model))
+    state = SamplerState(VarInfo(model, specialize_after))
 
     # Generate a sampler.
     return Sampler(alg, info, s, state)
@@ -54,49 +55,12 @@ alg_str(::Sampler{<:MH}) = "MH"
 #####################
 # Utility functions #
 #####################
 
-"""
-    set_namedtuple!(vi::VarInfo, nt::NamedTuple)
-
-Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
-"""
-function set_namedtuple!(vi::VarInfo, nt::NamedTuple)
-    for (n, vals) in pairs(nt)
-        vns = vi.metadata[n].vns
-
-        n_vns = length(vns)
-        n_vals = length(vals)
-        v_isarr = vals isa AbstractArray
-
-        if v_isarr && n_vals == 1 && n_vns > 1
-            for (vn, val) in zip(vns, vals[1])
-                vi[vn] = val isa AbstractArray ? val : [val]
-            end
-        elseif v_isarr && n_vals > 1 && n_vns == 1
-            vi[vns[1]] = vals
-        elseif v_isarr && n_vals == n_vns > 1
-            for (vn, val) in zip(vns, vals)
-                vi[vn] = [val]
-            end
-        elseif v_isarr && n_vals == 1 && n_vns == 1
-            if vals[1] isa AbstractArray
-                vi[vns[1]] = vals[1]
-            else
-                vi[vns[1]] = [vals[1]]
-            end
-        elseif !(v_isarr)
-            vi[vns[1]] = [vals]
-        else
-            error("Cannot assign `NamedTuple` to `VarInfo`")
-        end
-    end
-end
-
 """
     MHLogDensityFunction
 
 A log density function for the MH sampler.
 
-This variant uses the `set_namedtuple!` function to update the `VarInfo`.
+This variant uses the `set_namedtuple!` function to update the variables.
 """
 struct MHLogDensityFunction{M<:Model,S<:Sampler{<:MH}} <: Function # Relax AMH.DensityModel?
     model::M
@@ -135,21 +99,21 @@ Returns two `NamedTuples`. The first `NamedTuple` has symbols as keys and distri
 The second `NamedTuple` has model symbols as keys and their stored values as values.
 """
 function dist_val_tuple(spl::Sampler{<:MH})
-    vi = spl.state.vi
-    vns = _getvns(vi, spl)
+    vi = TypedVarInfo(spl.state.vi)
+    vns = getvns(vi, spl)
     dt = _dist_tuple(spl.alg.proposals, vi, vns)
     vt = _val_tuple(vi, vns)
     return dt, vt
 end
 
 @generated function _val_tuple(
-    vi::VarInfo,
+    vi,
     vns::NamedTuple{names}
 ) where {names}
     isempty(names) === 0 && return :(NamedTuple())
     expr = Expr(:tuple)
     expr.args = Any[
-        :($name = reconstruct(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)),
+        :($name = reconstruct(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)),
                               DynamicPPL.getval(vi, vns.$name)))
         for name in names]
     return expr
@@ -157,7 +121,7 @@ end
 @generated function _dist_tuple(
     props::NamedTuple{propnames},
-    vi::VarInfo,
+    vi,
     vns::NamedTuple{names}
 ) where {names,propnames}
     isempty(names) === 0 && return :(NamedTuple())
@@ -168,7 +132,7 @@ end
             :($name = props.$name)
         else
             # Otherwise, use the default proposal.
-            :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name))))
+            :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name))))
         end for name in names]
     return expr
 end
@@ -229,8 +193,8 @@ function DynamicPPL.assume(
     vi,
 )
     updategid!(vi, vn, spl)
-    r = vi[vn]
-    return r, logpdf_with_trans(dist, r, istrans(vi, vn))
+    r = vi[vn, dist]
+    return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))
 end
 
 function DynamicPPL.dot_assume(
@@ -244,9 +208,9 @@ function DynamicPPL.dot_assume(
     getvn = i -> VarName(vn, vn.indexing * "[:,$i]")
     vns = getvn.(1:size(var, 2))
     updategid!.(Ref(vi), vns, Ref(spl))
-    r = vi[vns]
+    r = vi[vns, dist]
     var .= r
-    return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
+    return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1])))
 end
 function DynamicPPL.dot_assume(
     spl::Sampler{<:MH},
@@ -258,9 +222,9 @@ function DynamicPPL.dot_assume(
     getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
     vns = getvn.(CartesianIndices(var))
     updategid!.(Ref(vi), vns, Ref(spl))
-    r = reshape(vi[vec(vns)], size(var))
+    r = vi[vns, dists]
     var .= r
-    return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
+    return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1])))
 end
 
 function DynamicPPL.observe(
diff --git a/src/variational/advi.jl b/src/variational/advi.jl
index c8e58e6c50..26c861a090 100644
--- a/src/variational/advi.jl
+++ b/src/variational/advi.jl
@@ -62,7 +62,7 @@ Creates a mean-field approximation with multivariate normal as underlying distri
 meanfield(model::Model) = meanfield(GLOBAL_RNG, model)
 function meanfield(rng::AbstractRNG, model::Model)
     # setup
-    varinfo = Turing.VarInfo(model)
+    varinfo = DynamicPPL.TypedVarInfo(model)
     num_params = sum([size(varinfo.metadata[sym].vals, 1)
                       for sym ∈ keys(varinfo.metadata)])
diff --git a/test/core/ad.jl b/test/core/ad.jl
index 77cb06e0c6..21074b777c 100644
--- a/test/core/ad.jl
+++ b/test/core/ad.jl
@@ -1,5 +1,5 @@
 using ForwardDiff, Distributions, FiniteDifferences, Tracker, Random, LinearAlgebra
-using PDMats, Zygote
+using PDMats, Zygote, ReverseDiff
 using Turing: Turing, invlink, link, SampleFromPrior, TrackerAD, ZygoteAD
 using DynamicPPL: getval
@@ -18,8 +18,8 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
     ad_test_f = gdemo_default
     vi = Turing.VarInfo(ad_test_f)
     ad_test_f(vi, SampleFromPrior())
-    svn = vi.metadata.s.vns[1]
-    mvn = vi.metadata.m.vns[1]
+    svn = vi.tvi.metadata.s.vns[1]
+    mvn = vi.tvi.metadata.m.vns[1]
     _s = getval(vi, svn)[1]
     _m = getval(vi, mvn)[1]
diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl
index 7320f703db..ecd6a1d376 100644
--- a/test/inference/Inference.jl
+++ b/test/inference/Inference.jl
@@ -7,6 +7,19 @@ using Random
 dir = splitdir(splitdir(pathof(Turing))[1])[1]
 include(dir*"/test/test_utils/AllUtils.jl")
 
+struct DynamicDist <: DiscreteMultivariateDistribution end
+function Distributions.logpdf(::DynamicDist, dsl_numeric::AbstractVector{Int})
+    return sum([log(0.5) * 0.5^i for i in 1:length(dsl_numeric)])
+end
+function Random.rand(rng::Random.AbstractRNG, ::DynamicDist)
+    fst = rand(rng, [0, 1])
+    dsl_numeric = [fst]
+    while rand() < 0.5
+        push!(dsl_numeric, rand(rng, [0, 1]))
+    end
+    return dsl_numeric
+end
+
 @testset "io.jl" begin
     # Only test threading if 1.3+.
     if VERSION > v"1.2"
@@ -26,7 +39,7 @@ include(dir*"/test/test_utils/AllUtils.jl")
 
             # run sampler: progress logging should be disabled and
             # it should return a Chains object
-            sampler = Sampler(HMC(0.1, 7), gdemo_default)
+            sampler = Turing.Sampler(HMC(0.1, 7), gdemo_default)
             chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4)
             @test chains isa MCMCChains.Chains
         end
@@ -56,10 +69,10 @@ include(dir*"/test/test_utils/AllUtils.jl")
         chn2_contd = sample(gdemo_default, alg2, 1000; resume_from=chn2)
         check_gdemo(chn2_contd)
 
-        chn3 = sample(gdemo_default, alg3, 1000; save_state=true)
+        chn3 = sample(gdemo_default, alg3, 5000; save_state=true)
         check_gdemo(chn3)
 
-        chn3_contd = sample(gdemo_default, alg3, 1000; resume_from=chn3)
+        chn3_contd = sample(gdemo_default, alg3, 5000; resume_from=chn3)
         check_gdemo(chn3_contd)
     end
     @testset "Contexts" begin
@@ -114,4 +127,26 @@ include(dir*"/test/test_utils/AllUtils.jl")
         @test mean(x[:s][1] for x in chains) ≈ 3 atol=0.1
         @test mean(x[:m][1] for x in chains) ≈ 0 atol=0.1
     end
+    @testset "stochastic control flow" begin
+        @model demo(p) = begin
+            x ~ Categorical(p)
+            if x == 1
+                y ~ Normal()
+            elseif x == 2
+                z ~ Normal()
+            else
+                k ~ Normal()
+            end
+        end
+        chain = sample(demo(fill(1/3, 3)), PG(4), 7000)
+        check_numerical(chain, [:x, :y, :z, :k], [2, 0, 0, 0], atol=0.05, skip_missing=true)
+
+        chain = sample(demo(fill(1/3, 3)), Gibbs(PG(4, :x, :y), PG(4, :z, :k)), 7000)
+        check_numerical(chain, [:x, :y, :z, :k], [2, 0, 0, 0], atol=0.05, skip_missing=true)
+
+        @model function mwe()
+            dsl ~ DynamicDist()
+        end
+        chain = sample(mwe(), PG(10), 500)
+    end
 end
diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl
index 295d12c3d9..391fc63b81 100644
--- a/test/inference/gibbs.jl
+++ b/test/inference/gibbs.jl
@@ -40,7 +40,8 @@ include(dir*"/test/test_utils/AllUtils.jl")
         Random.seed!(100)
         alg = Gibbs(
             CSMC(10, :s),
-            HMC(0.2, 4, :m))
+            HMC(0.2, 4, :m),
+        )
         chain = sample(gdemo(1.5, 2.0), alg, 3000)
         check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)
@@ -48,13 +49,15 @@ include(dir*"/test/test_utils/AllUtils.jl")
 
         alg = Gibbs(
             MH(:s),
-            HMC(0.2, 4, :m))
+            HMC(0.2, 4, :m),
+        )
         chain = sample(gdemo(1.5, 2.0), alg, 5000)
         check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)
 
         alg = Gibbs(
             CSMC(15, :s),
-            ESS(:m))
+            ESS(:m),
+        )
         chain = sample(gdemo(1.5, 2.0), alg, 10_000)
         check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)
@@ -67,7 +70,8 @@ include(dir*"/test/test_utils/AllUtils.jl")
         Random.seed!(200)
         gibbs = Gibbs(
             PG(10, :z1, :z2, :z3, :z4),
-            HMC(0.15, 3, :mu1, :mu2))
+            HMC(0.15, 3, :mu1, :mu2),
+        )
         chain = sample(MoGtest_default, gibbs, 1500)
         check_MoGtest_default(chain, atol = 0.15)
@@ -76,7 +80,8 @@ include(dir*"/test/test_utils/AllUtils.jl")
         Random.seed!(200)
         gibbs = Gibbs(
             PG(10, :z1, :z2, :z3, :z4),
-            ESS(:mu1), ESS(:mu2))
+            ESS(:mu1), ESS(:mu2),
+        )
         chain = sample(MoGtest_default, gibbs, 1500)
         check_MoGtest_default(chain, atol = 0.15)
     end
@@ -134,7 +139,6 @@ include(dir*"/test/test_utils/AllUtils.jl")
             end
         end
         model = imm(randn(100), 1.0);
-        sample(model, Gibbs(MH(10, :z), HMC(0.01, 4, :m)), 100);
         sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m)), 100);
     end
 end
diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl
index 6273e2ed47..563eb42bff 100644
--- a/test/inference/hmc.jl
+++ b/test/inference/hmc.jl
@@ -193,4 +193,70 @@ include(dir*"/test/test_utils/AllUtils.jl")
 
         @test sample(mwe(), HMC(0.2, 4), 1_000) isa Chains
     end
+
+    @turing_testset "Stochastic support" begin
+        n = 10
+        m = 10
+        k = 4
+        theta = randn(n)
+        b = zeros(k,m)
+        for i in 1:m
+            b[1,i] = randn()
+            for j in 2:k
+                dd = truncated(Normal(), b[j-1,i], Inf)
+                b[j,i] = rand(dd)
+            end
+        end
+
+        logit = x -> log(x / (1 - x))
+        invlogit = x -> exp(x)/(1 + exp(x))
+        y = zeros(m,n)
+        probs = zeros(k,m,n)
+        for p in 1:n
+            for i in 1:m
+                probs[1,i,p] = 1.0
+                for j in 1:(k-1)
+                    Q = invlogit(theta[p] - b[j,i])
+                    probs[j,i,p] -= Q
+                    probs[j+1,i,p] = Q
+                end
+                y[i,p] = rand(Categorical(probs[:,i,p]))
+            end
+        end
+
+        # Graded Response Model
+        @model function grm(y, n, m, k, ::Type{TC}=Array{Float64,3}, ::Type{TM}=Array{Float64,2}, ::Type{TV}=Vector{Float64}) where {TC, TM, TV}
+            b = TM(undef, k, m)
+            for i in 1:m
+                b[1,i] ~ Normal(0,1)
+                for j in 2:k
+                    b[j,i] ~ truncated(Normal(0,1), b[j-1,i], Inf)
+                end
+            end
+            probs = TC(undef, k, m, n)
+            theta = TV(undef, n)
+            for p in 1:n
+                theta[p] ~ Normal(0,1)
+                for i in 1:m
+                    probs[1,i,p] = 1.0
+                    for j in 1:(k-1)
+                        Q = invlogit(theta[p] - b[j,i])
+                        probs[j,i,p] -= Q
+                        probs[j+1,i,p] = Q
+                    end
+                    probs[:,i,p] ./= sum(probs[:,i,p])
+                    y[i,p] ~ Categorical(probs[:,i,p], check_args=false)
+                end
+            end
+            return theta, b
+        end;
+        chn = sample(grm(y, n, m, k), HMC(0.05, 1), 100)
+        for c in 1:100
+            for i in 1:m
+                for j in 2:k
+                    @test chn["b[$j,$i]"].value[c] > chn["b[$(j-1),$i]"].value[c]
+                end
+            end
+        end
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 5776651bb6..f119d12224 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,37 +14,40 @@ include("test_utils/AllUtils.jl")
         include("core/ad.jl")
         include("core/container.jl")
     end
-
-    test_adbackends = if VERSION >= v"1.2"
-        [:forwarddiff, :tracker, :reversediff]
-    else
-        [:forwarddiff, :tracker]
-    end
-    Turing.setrdcache(false)
-    for adbackend in test_adbackends
-        Turing.setadbackend(adbackend)
-        @testset "inference: $adbackend" begin
-            @testset "samplers" begin
-                include("inference/gibbs.jl")
-                include("inference/hmc.jl")
-                include("inference/is.jl")
-                include("inference/mh.jl")
-                include("inference/ess.jl")
-                include("inference/AdvancedSMC.jl")
-                include("inference/Inference.jl")
-                include("contrib/inference/dynamichmc.jl")
+    Turing.setadbackend(:forwarddiff)
+    @testset "inference" begin
+        @testset "samplers" begin
+            include("inference/gibbs.jl")
+            include("inference/is.jl")
+            include("inference/mh.jl")
+            include("inference/ess.jl")
+            include("inference/AdvancedSMC.jl")
+            include("inference/Inference.jl")
+            test_adbackends = if VERSION >= v"1.2"
+                [:forwarddiff, :tracker, :reversediff]
+            else
+                [:forwarddiff, :tracker]
+            end
+            Turing.setrdcache(false)
+            for adbackend in test_adbackends
+                @testset "hmc: $adbackend" begin
+                    Turing.setadbackend(adbackend)
+                    include("inference/hmc.jl")
+                    include("contrib/inference/dynamichmc.jl")
+                end
+                @testset "variational algorithms : $adbackend" begin
+                    include("variational/advi.jl")
+                end
             end
-        end
-
-        @testset "variational algorithms : $adbackend" begin
-            include("variational/advi.jl")
         end
     end
+
+    Turing.setadbackend(:forwarddiff)
+
     @testset "variational optimisers" begin
         include("variational/optimisers.jl")
     end
 
-    Turing.setadbackend(:forwarddiff)
     @testset "stdlib" begin
         include("stdlib/distributions.jl")
         include("stdlib/RandomMeasures.jl")
diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl
index bd684918fb..0582082304 100644
--- a/test/test_utils/ad_utils.jl
+++ b/test/test_utils/ad_utils.jl
@@ -78,7 +78,7 @@ function test_model_ad(model, f, syms::Vector{Symbol})
     vnvals = Vector{Float64}()
     for i in 1:length(syms)
         s = syms[i]
-        vnms[i] = getfield(vi.metadata, s).vns[1]
+        vnms[i] = getfield(vi.tvi.metadata, s).vns[1]
 
         vals = getval(vi, vnms[i])
         for i in eachindex(vals)
diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl
index 88e2b4e9e7..0765a0a27a 100644
--- a/test/test_utils/numerical_tests.jl
+++ b/test/test_utils/numerical_tests.jl
@@ -41,12 +41,19 @@ end
 function check_numerical(chain,
                          symbols::Vector,
                          exact_vals::Vector;
+                         skip_missing=false,
                          atol=0.2,
                          rtol=0.0)
     for (sym, val) in zip(symbols, exact_vals)
-        E = val isa Real ?
-            mean(chain[sym].value) :
-            vec(mean(chain[sym].value, dims=[1]))
+        if skip_missing
+            E = val isa Real ?
+                mean(skipmissing(chain[sym].value)) :
+                vec(mean(skipmissing(chain[sym].value), dims=[1]))
+        else
+            E = val isa Real ?
+                mean(chain[sym].value) :
+                vec(mean(chain[sym].value, dims=[1]))
+        end
         @info (symbol=sym, exact=val, evaluated=E)
         @test E ≈ val atol=atol rtol=rtol
     end
diff --git a/test/test_utils/testing_functions.jl b/test/test_utils/testing_functions.jl
index a7b22eaf57..5678364900 100644
--- a/test/test_utils/testing_functions.jl
+++ b/test/test_utils/testing_functions.jl
@@ -19,7 +19,7 @@ function randr(vi::Turing.VarInfo,
     else
         if count Turing.checkindex(vn, vi, spl) end
         Turing.updategid!(vi, vn, spl)
-        return vi[vn]
+        return vi[vn, dist]
     end
 end
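---
Reviewer note (illustration only, not part of the patch): the new `specialize_after` keyword is threaded from `sample` through `Sampler(alg, model; specialize_after=...)` into `VarInfo(model, specialize_after)`, and appears to control when the sampler may switch to a type-specialized VarInfo; models whose set of variables changes between runs are the motivating case. A minimal usage sketch under that reading — the model mirrors the new "stochastic control flow" test, and the exact call, model, and numbers are illustrative assumptions rather than an API guarantee:

    using Turing

    # Which of y/z exists depends on the sampled value of x,
    # so the set of variables is not fixed across model runs.
    @model function demo(p)
        x ~ Categorical(p)
        if x == 1
            y ~ Normal()
        else
            z ~ Normal()
        end
    end

    # `specialize_after` (default 1) is forwarded to the sampler's VarInfo
    # construction; passing it at the `sample` call site is assumed here to
    # be the intended public entry point introduced by this patch.
    chain = sample(demo(fill(1/2, 2)), PG(4), 1_000; specialize_after=1)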