diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c7db4873f..0b28b8a91 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -28,6 +28,7 @@ import Base: Symbol, export AbstractVarInfo, VarInfo, UntypedVarInfo, + MixedVarInfo, getlogp, setlogp!, acclogp!, @@ -111,6 +112,7 @@ include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") include("threadsafe.jl") +include("mixedvarinfo.jl") include("context_implementations.jl") include("compiler.jl") include("prob_macro.jl") diff --git a/src/mixedvarinfo.jl b/src/mixedvarinfo.jl new file mode 100644 index 000000000..f350f9d38 --- /dev/null +++ b/src/mixedvarinfo.jl @@ -0,0 +1,206 @@ +struct MixedVarInfo{ + Ttvi <: Union{TypedVarInfo, Nothing}, + Tuvi <: UntypedVarInfo, +} <: AbstractVarInfo + tvi::Ttvi + uvi::Tuvi + is_uvi_empty::Base.RefValue{Bool} +end +MixedVarInfo(vi::TypedVarInfo) = MixedVarInfo(vi, VarInfo(), Ref(true)) +function MixedVarInfo(vi::UntypedVarInfo) + MixedVarInfo(TypedVarInfo(vi), empty!(deepcopy(vi)), Ref(true)) +end +function VarInfo(model::Model, ctx = DefaultContext()) + vi = VarInfo() + model(vi, SampleFromPrior(), ctx) + return MixedVarInfo(TypedVarInfo(vi)) +end +function VarInfo(old_vi::MixedVarInfo, spl, x::AbstractVector) + new_tvi = VarInfo(old_vi.tvi, spl, x) + return MixedVarInfo(new_tvi, old_vi.uvi, old_vi.is_uvi_empty) +end +function TypedVarInfo(vi::MixedVarInfo) + return VarInfo( + merge(vi.tvi.metadata, TypedVarInfo(vi.uvi).metadata), + Ref(getlogp(vi.tvi)), + Ref(get_num_produce(vi.tvi)), + ) +end + +_getvns(vi::MixedVarInfo, s::Selector, space) = _getvns(vi.tvi, s, space) + +function getmetadata(vi::MixedVarInfo, vn::VarName) + if haskey(vi.tvi, vn) + return getmetadata(vi.tvi, vn) + else + return getmetadata(vi.uvi, vn) + end +end +function Base.show(io::IO, vi::MixedVarInfo) + print(io, "Instance of MixedVarInfo") +end + +function fullyinspace(spl::AbstractSampler, vi::TypedVarInfo) + space = getspace(spl) + return space !== () && all(haskey.(Ref(vi.metadata), space)) +end + +acclogp!(vi::MixedVarInfo, logp) = acclogp!(vi.tvi, logp) +getlogp(vi::MixedVarInfo) = getlogp(vi.tvi) +resetlogp!(vi::MixedVarInfo) = resetlogp!(vi.tvi) +setlogp!(vi::MixedVarInfo, logp) = setlogp!(vi.tvi, logp) + +get_num_produce(vi::MixedVarInfo) = get_num_produce(vi.tvi) +increment_num_produce!(vi::MixedVarInfo) = increment_num_produce!(vi.tvi) +reset_num_produce!(vi::MixedVarInfo) = reset_num_produce!(vi.tvi) +set_num_produce!(vi::MixedVarInfo, n::Int) = set_num_produce!(vi.tvi, n) + +syms(vi::MixedVarInfo) = (syms(vi.tvi)..., syms(vi.uvi)...) + +function setgid!(vi::MixedVarInfo, gid::Selector, vn::VarName) + hassymbol(vi.tvi, vn) ? setgid!(vi.tvi, gid, vn) : setgid!(vi.uvi, gid, vn) + return vi +end +function setorder!(vi::MixedVarInfo, vn::VarName, index::Int) + hassymbol(vi.tvi, vn) ? setorder!(vi.tvi, vn, index) : setorder!(vi.uvi, vn, index) + return vi +end +function setval!(vi::MixedVarInfo, val, vn::VarName) + hassymbol(vi.tvi, vn) ? setval!(vi.tvi, val, vn) : setval!(vi.uvi, val, vn) + return vi +end + +function haskey(vi::MixedVarInfo, vn::VarName) + return hassymbol(vi.tvi, vn) ? 
+        haskey(vi.tvi, vn) : haskey(vi.uvi, vn)
+end
+
+function link!(vi::MixedVarInfo, spl::AbstractSampler)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        link!(vi.tvi, spl)
+    else
+        link!(vi.tvi, spl)
+        link!(vi.uvi, spl)
+    end
+    return vi
+end
+function invlink!(vi::MixedVarInfo, spl::AbstractSampler)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        invlink!(vi.tvi, spl)
+    else
+        invlink!(vi.tvi, spl)
+        invlink!(vi.uvi, spl)
+    end
+    return vi
+end
+function islinked(vi::MixedVarInfo, spl::AbstractSampler)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        return islinked(vi.tvi, spl)
+    else
+        return islinked(vi.tvi, spl) || islinked(vi.uvi, spl)
+    end
+end
+
+function getindex(vi::MixedVarInfo, vn::VarName)
+    return hassymbol(vi.tvi, vn) ? getindex(vi.tvi, vn) : getindex(vi.uvi, vn)
+end
+# All the VarNames have the same symbol
+function getindex(vi::MixedVarInfo, vns::Vector{<:VarName{s}}) where {s}
+    return hassymbol(vi.tvi, vns[1]) ? getindex(vi.tvi, vns) : getindex(vi.uvi, vns)
+end
+
+for splT in (:SampleFromPrior, :SampleFromUniform, :AbstractSampler)
+    @eval begin
+        function getindex(vi::MixedVarInfo, spl::$splT)
+            if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+                return vi.tvi[spl]
+            else
+                return vcat(vi.tvi[spl], vi.uvi[spl])
+            end
+        end
+
+        function setindex!(vi::MixedVarInfo, val, spl::$splT)
+            if fullyinspace(spl, vi.tvi)
+                setindex!(vi.tvi, val, spl)
+            else
+                # TODO: define length(vi::TypedVarInfo, spl)
+                n = length(vi.tvi[spl])
+                setindex!(vi.tvi, val[1:n], spl)
+                if n < length(val)
+                    setindex!(vi.uvi, val[n+1:end], spl)
+                end
+            end
+            return vi
+        end
+    end
+end
+
+function getall(vi::MixedVarInfo)
+    if vi.is_uvi_empty[]
+        return getall(vi.tvi)
+    else
+        return vcat(getall(vi.tvi), getall(vi.uvi))
+    end
+end
+
+function set_retained_vns_del_by_spl!(vi::MixedVarInfo, spl::Sampler)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        set_retained_vns_del_by_spl!(vi.tvi, spl)
+    else
+        set_retained_vns_del_by_spl!(vi.tvi, spl)
+        set_retained_vns_del_by_spl!(vi.uvi, spl)
+    end
+    return vi
+end
+
+isempty(vi::MixedVarInfo) = isempty(vi.tvi) && vi.is_uvi_empty[]
+function empty!(vi::MixedVarInfo)
+    empty!(vi.tvi)
+    vi.is_uvi_empty[] || empty!(vi.uvi)
+    vi.is_uvi_empty[] = true
+    return vi
+end
+
+function push!(
+    vi::MixedVarInfo,
+    vn::VarName,
+    r,
+    dist::Distribution,
+    gidset::Set{Selector}
+)
+    if hassymbol(vi.tvi, vn)
+        push!(vi.tvi, vn, r, dist, gidset)
+    else
+        push!(vi.uvi, vn, r, dist, gidset)
+        vi.is_uvi_empty[] = false
+    end
+    return vi
+end
+
+function unset_flag!(vi::MixedVarInfo, vn::VarName, flag::String)
+    hassymbol(vi.tvi, vn) ? unset_flag!(vi.tvi, vn, flag) : unset_flag!(vi.uvi, vn, flag)
+    return vi
+end
+function is_flagged(vi::MixedVarInfo, vn::VarName, flag::String)
+    if hassymbol(vi.tvi, vn)
+        return is_flagged(vi.tvi, vn, flag)
+    else
+        return is_flagged(vi.uvi, vn, flag)
+    end
+end
+
+function updategid!(vi::MixedVarInfo, spls::Tuple{Vararg{AbstractSampler}})
+    foreach(spls) do spl
+        if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+            updategid!(vi.tvi, (spl,))
+        else
+            updategid!(vi.uvi, (spl,))
+        end
+    end
+    return vi
+end
+
+function tonamedtuple(vi::MixedVarInfo)
+    t1 = tonamedtuple(vi.tvi)
+    return vi.is_uvi_empty[] ? t1 : merge(t1, tonamedtuple(vi.uvi))
+end
+set_namedtuple!(vi::MixedVarInfo, nt::NamedTuple) = set_namedtuple!(vi.tvi, nt)
diff --git a/src/prob_macro.jl b/src/prob_macro.jl
index 78ab465aa..1e259b4a6 100644
--- a/src/prob_macro.jl
+++ b/src/prob_macro.jl
@@ -43,11 +43,11 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
     end
     if isdefined(ntr.chain.info, :vi)
         _vi = ntr.chain.info.vi
-        @assert _vi isa VarInfo
+        @assert _vi isa AbstractVarInfo
         vi = TypedVarInfo(_vi)
     elseif isdefined(ntr, :varinfo)
         _vi = ntr.varinfo
-        @assert _vi isa VarInfo
+        @assert _vi isa AbstractVarInfo
         vi = TypedVarInfo(_vi)
     else
         vi = nothing
@@ -62,7 +62,7 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
     modelgen = ntr.model
     if isdefined(ntr, :varinfo)
         _vi = ntr.varinfo
-        @assert _vi isa VarInfo
+        @assert _vi isa AbstractVarInfo
         vi = TypedVarInfo(_vi)
     else
         vi = nothing
@@ -115,6 +115,8 @@ end
 missing_arg_error_msg(arg, ::Missing) = """Variable $arg has a value of `missing`, or is not defined and its default value is `missing`. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`."""
 missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has no default value. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`."""
 
+warn_msg(arg::Symbol) = "Argument $arg is not defined. A value of `nothing` is used."
+
 function logprior(
     left::NamedTuple,
     right::NamedTuple,
@@ -134,7 +136,7 @@ function logprior(
     # When all of model args are on the lhs of |, this is also equal to the logjoint.
     model = make_prior_model(left, right, modelgen)
-    vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
+    vi = _vi === nothing ? TypedVarInfo(deepcopy(model), PriorContext()) : TypedVarInfo(_vi)
     foreach(keys(vi.metadata)) do n
         @assert n in keys(left) "Variable $n is not defined."
     end
@@ -173,8 +175,6 @@ end
     end
 end
 
-warn_msg(arg) = "Argument $arg is not defined. A value of `nothing` is used."
-
 function loglikelihood(
     left::NamedTuple,
     right::NamedTuple,
@@ -182,7 +182,7 @@ function loglikelihood(
     _vi::Union{Nothing, VarInfo},
 )
     model = make_likelihood_model(left, right, modelgen)
-    vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi
+    vi = _vi === nothing ? TypedVarInfo(deepcopy(model)) : TypedVarInfo(_vi)
     if isdefined(right, :chain)
         # Element-wise likelihood for each value in chain
         chain = right.chain
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 077ef1a95..049384cbf 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -97,7 +97,7 @@ Note: It is the user's responsibility to ensure that each "symbol" is visited at
once whenever the model is called, regardless of any stochastic branching. Each symbol refers to
a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and
`x[2] ~ ...` both have the same symbol `x`.
""" -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta <: Union{Metadata, NamedTuple}, Tlogp} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} @@ -105,7 +105,7 @@ end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} -function VarInfo(model::Model, ctx = DefaultContext()) +function TypedVarInfo(model::Model, ctx = DefaultContext()) vi = VarInfo() model(vi, SampleFromPrior(), ctx) return TypedVarInfo(vi) @@ -234,18 +234,18 @@ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) """ - getidx(vi::VarInfo, vn::VarName) + getidx(vi::AbstractVarInfo, vn::VarName) Return the index of `vn` in the metadata of `vi` corresponding to `vn`. """ -getidx(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).idcs[vn] +getidx(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).idcs[vn] """ getrange(vi::VarInfo, vn::VarName) Return the index range of `vn` in the metadata of `vi`. """ -getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] +getrange(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] """ getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) @@ -257,11 +257,11 @@ function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) end """ - getdist(vi::VarInfo, vn::VarName) + getdist(vi::AbstractVarInfo, vn::VarName) Return the distribution from which `vn` was sampled in `vi`. """ -getdist(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] +getdist(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] """ getval(vi::VarInfo, vn::VarName) @@ -270,7 +270,7 @@ Return the value(s) of `vn`. The values may or may not be transformed to Euclidean space. """ -getval(vi::VarInfo, vn::VarName) = view(getmetadata(vi, vn).vals, getrange(vi, vn)) +getval(vi::AbstractVarInfo, vn::VarName) = view(getmetadata(vi, vn).vals, getrange(vi, vn)) """ setval!(vi::VarInfo, val, vn::VarName) @@ -279,7 +279,7 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val +setval!(vi::AbstractVarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val """ getval(vi::VarInfo, vns::Vector{<:VarName}) @@ -335,7 +335,7 @@ end Return the set of sampler selectors associated with `vn` in `vi`. """ -getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] +getgid(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ settrans!(vi::VarInfo, trans::Bool, vn::VarName) @@ -474,7 +474,7 @@ end Set `vn`'s value for `flag` to `true` in `vi`. """ -function set_flag!(vi::VarInfo, vn::VarName, flag::String) +function set_flag!(vi::AbstractVarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true end @@ -485,7 +485,7 @@ end # VarInfo -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) +VarInfo(meta::Metadata=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) """ TypedVarInfo(vi::UntypedVarInfo) @@ -578,7 +578,7 @@ keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) Add `gid` to the set of sampler selectors associated with `vn` in `vi`. 
""" -setgid!(vi::VarInfo, gid::Selector, vn::VarName) = push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) """ istrans(vi::VarInfo, vn::VarName) @@ -896,7 +896,7 @@ variables `x` would return (x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) ``` """ -function tonamedtuple(vi::VarInfo) +function tonamedtuple(vi::TypedVarInfo) return tonamedtuple(vi.metadata, vi) end @generated function tonamedtuple(metadata::NamedTuple{names}, vi::VarInfo) where {names} @@ -907,6 +907,7 @@ end end return expr end +tonamedtuple(vi::UntypedVarInfo) = tonamedtuple(TypedVarInfo(vi)) @inline function findvns(vi, f_vns) if length(f_vns) == 0 @@ -915,7 +916,7 @@ end return map(vn -> vi[vn], f_vns) end -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler, SampleFromPrior}) +function Base.eltype(vi::AbstractVarInfo, spl::AbstractSampler) return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi), typeof(spl)})) end @@ -931,6 +932,14 @@ function haskey(vi::TypedVarInfo, vn::VarName) return getsym(vn) in fieldnames(Tmeta) && haskey(getmetadata(vi, vn).idcs, vn) end +""" + hassymbol(vi::VarInfo, vn::VarName) + +Check whether the symbol of `vn` has been sampled in `vi`. +""" +hassymbol(vi::VarInfo, vn::VarName) = haskey(vi, vn) +hassymbol(vi::TypedVarInfo, vn::VarName) = haskey(vi.metadata, getsym(vn)) + function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) vi_str = """ /======================================================================= @@ -1139,3 +1148,93 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) setgid!(vi, spl.selector, vn) end end +function updategid!(vi::TypedVarInfo, spls::Tuple{Vararg{AbstractSampler}}) + foreach(keys(vi.metadata)) do sym + for vn in vi.metadata[sym].vns + updategid!.(Ref(vi), Ref(vn), spls) + end + end + return vi +end +function updategid!(vi::UntypedVarInfo, spls::Tuple{Vararg{AbstractSampler}}) + for vn in vi.metadata.vns + updategid!.(Ref(vi), Ref(vn), spls) + end + return vi +end + +#= +""" + set_namedtuple!(vi::AbstractVarInfo, nt::NamedTuple) + +Places the values of a `NamedTuple` into the relevant places of `vi`. +""" +function set_namedtuple!(vi::UntypedVarInfo, nt::NamedTuple) + for (n, vals) in pairs(nt) + vns = vi.metadata[n].vns + + n_vns = length(vns) + n_vals = length(vals) + v_isarr = vals isa AbstractArray + + if v_isarr && n_vals == 1 && n_vns > 1 + for (vn, val) in zip(vns, vals[1]) + vi[vn] = val isa AbstractArray ? val : [val] + end + elseif v_isarr && n_vals > 1 && n_vns == 1 + vi[vns[1]] = vals + elseif v_isarr && n_vals == n_vns > 1 + for (vn, val) in zip(vns, vals) + vi[vn] = [val] + end + elseif v_isarr && n_vals == 1 && n_vns == 1 + if vals[1] isa AbstractArray + vi[vns[1]] = vals[1] + else + vi[vns[1]] = [vals[1]] + end + elseif !(v_isarr) + vi[vns[1]] = [vals] + else + error("Cannot assign `NamedTuple` to `VarInfo`") + end + end +end +=# + +""" + set_namedtuple!(vi, nt::NamedTuple) + +Places the values of a `NamedTuple` into the relevant places of `vi`. +""" +function set_namedtuple!(vi::TypedVarInfo, nt::NamedTuple) + for (n, vals) in pairs(nt) + vns = vi.metadata[n].vns + + n_vns = length(vns) + n_vals = length(vals) + v_isarr = vals isa AbstractArray + + if v_isarr && n_vals == 1 && n_vns > 1 + for (vn, val) in zip(vns, vals[1]) + vi[vn] = val isa AbstractArray ? 
val : [val] + end + elseif v_isarr && n_vals > 1 && n_vns == 1 + vi[vns[1]] = vals + elseif v_isarr && n_vals == n_vns > 1 + for (vn, val) in zip(vns, vals) + vi[vn] = [val] + end + elseif v_isarr && n_vals == 1 && n_vns == 1 + if vals[1] isa AbstractArray + vi[vns[1]] = vals[1] + else + vi[vns[1]] = [vals[1]] + end + elseif !(v_isarr) + vi[vns[1]] = [vals] + else + error("Cannot assign `NamedTuple` to `VarInfo`") + end + end +end diff --git a/test/Turing/Turing.jl b/test/Turing/Turing.jl index c83f663c6..f24822c47 100644 --- a/test/Turing/Turing.jl +++ b/test/Turing/Turing.jl @@ -11,12 +11,11 @@ module Turing using Requires, Reexport, ForwardDiff using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions using Statistics, LinearAlgebra -using Markdown, Libtask, MacroTools -@reexport using Distributions, MCMCChains, Libtask +using Libtask +@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC using Tracker: Tracker -import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex! -import DynamicPPL: getspace +import DynamicPPL: getspace, NoDist, NamedDist const PROGRESS = Ref(true) function turnprogress(switch::Bool) @@ -68,6 +67,8 @@ export @model, # modelling @varname, DynamicPPL, + Prior, # Sampling from the prior + MH, # classic sampling RWMH, ESS, @@ -90,7 +91,6 @@ export @model, # modelling ADVI, sample, # inference - psample, setchunksize, resume, @logprob_str, @@ -105,15 +105,10 @@ export @model, # modelling Flat, FlatPos, BinomialLogit, - VecBinomialLogit, + BernoulliLogit, OrderedLogistic, LogPoisson, NamedDist, filldist, arraydist - -# Reexports -using AbstractMCMC: sample, psample -export sample, psample - end diff --git a/test/Turing/contrib/inference/dynamichmc.jl b/test/Turing/contrib/inference/dynamichmc.jl index 16eda6d84..e778a03d4 100644 --- a/test/Turing/contrib/inference/dynamichmc.jl +++ b/test/Turing/contrib/inference/dynamichmc.jl @@ -41,12 +41,12 @@ function DynamicNUTS{AD}(space::Symbol...) where AD DynamicNUTS{AD, space}() end -mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState +mutable struct DynamicNUTSState{V<:AbstractVarInfo, D} <: AbstractSamplerState vi::V draws::Vector{D} end -getspace(::DynamicNUTS{<:Any, space}) where {space} = space +DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space function AbstractMCMC.sample_init!( rng::AbstractRNG, @@ -61,8 +61,15 @@ function AbstractMCMC.sample_init!( end model(spl.state.vi, SampleFromUniform()) + link!(spl.state.vi, spl) + l, dl = _lp(spl.state.vi[spl]) + while !isfinite(l) || !isfinite(dl) + model(spl.state.vi, SampleFromUniform()) + link!(spl.state.vi, spl) + l, dl = _lp(spl.state.vi[spl]) + end - if spl.selector.tag == :default + if spl.selector.tag == :default && !islinked(spl.state.vi, spl) link!(spl.state.vi, spl) model(spl.state.vi, spl) end @@ -114,7 +121,7 @@ end model::AbstractModel, alg::DynamicNUTS, N::Integer; - chain_type=Chains, + chain_type=MCMCChains.Chains, resume_from=nothing, progress=PROGRESS[], kwargs... @@ -130,19 +137,20 @@ end end end -function AbstractMCMC.psample( +function AbstractMCMC.sample( rng::AbstractRNG, model::AbstractModel, alg::DynamicNUTS, + parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, n_chains::Integer; - chain_type=Chains, + chain_type=MCMCChains.Chains, progress=PROGRESS[], kwargs... 
) if progress @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end - return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; - chain_type=chain_type, progress=false, kwargs...) + return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; + chain_type=chain_type, progress=false, kwargs...) end diff --git a/test/Turing/contrib/inference/sghmc.jl b/test/Turing/contrib/inference/sghmc.jl index 0940a1f57..83c488613 100644 --- a/test/Turing/contrib/inference/sghmc.jl +++ b/test/Turing/contrib/inference/sghmc.jl @@ -172,7 +172,7 @@ function step( spl.selector.tag != :default && link!(vi, spl) mssa = AHMC.Adaptation.ManualSSAdaptor(AHMC.Adaptation.MSSState(spl.alg.ϵ)) - spl.info[:adaptor] = AHMC.NaiveHMCAdaptor(AHMC.UnitPreconditioner(), mssa) + spl.info[:adaptor] = AHMC.NaiveHMCAdaptor(AHMC.UnitMassMatrix(), mssa) spl.selector.tag != :default && invlink!(vi, spl) return vi, true diff --git a/test/Turing/core/Core.jl b/test/Turing/core/Core.jl index 0227a6498..5f2c8812d 100644 --- a/test/Turing/core/Core.jl +++ b/test/Turing/core/Core.jl @@ -1,13 +1,12 @@ module Core using DistributionsAD, Bijectors -using MacroTools, Libtask, ForwardDiff, Random +using Libtask, ForwardDiff, Random using Distributions, LinearAlgebra using ..Utilities, Reexport using Tracker: Tracker using ..Turing: Turing -using DynamicPPL: Model, - AbstractSampler, Sampler, SampleFromPrior +using DynamicPPL: Model, AbstractSampler, Sampler, SampleFromPrior using LinearAlgebra: copytri! using Bijectors: PDMatDistribution import Bijectors: link, invlink @@ -17,9 +16,15 @@ using Requires include("container.jl") include("ad.jl") -@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD +function __init__() + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("compat/zygote.jl") + export ZygoteAD + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("compat/reversediff.jl") + export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache + end end export @model, @@ -37,8 +42,8 @@ export @model, current_trace, getweights, effectiveSampleSize, - increase_logweight, - inrease_logevidence, + increase_logweight!, + propagate!, resample!, ResampleWithESSThreshold, ADBackend, diff --git a/test/Turing/core/ad.jl b/test/Turing/core/ad.jl index 7bf81b195..7d86f76af 100644 --- a/test/Turing/core/ad.jl +++ b/test/Turing/core/ad.jl @@ -1,14 +1,23 @@ ############################## # Global variables/constants # ############################## -const ADBACKEND = Ref(:forward_diff) +const ADBACKEND = Ref(:forwarddiff) setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) function setadbackend(::Val{:forward_diff}) + Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) + setadbackend(Val(:forwarddiff)) +end +function setadbackend(::Val{:forwarddiff}) CHUNKSIZE[] == 0 && setchunksize(40) - ADBACKEND[] = :forward_diff + ADBACKEND[] = :forwarddiff end + function setadbackend(::Val{:reverse_diff}) - ADBACKEND[] = :reverse_diff + Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. 
To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) + setadbackend(Val(:tracker)) +end +function setadbackend(::Val{:tracker}) + ADBACKEND[] = :tracker end const ADSAFE = Ref(false) @@ -37,8 +46,8 @@ struct TrackerAD <: ADBackend end ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) -ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:reverse_diff}) = TrackerAD +ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} +ADBackend(::Val{:tracker}) = TrackerAD ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") """ @@ -51,7 +60,7 @@ getADbackend(spl::Sampler) = getADbackend(spl.alg) """ gradient_logp( θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler=SampleFromPrior(), ) @@ -62,7 +71,7 @@ tool is currently active. """ function gradient_logp( θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::Sampler ) @@ -73,7 +82,7 @@ end gradient_logp( backend::ADBackend, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) @@ -84,7 +93,7 @@ specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{ function gradient_logp( ::ForwardDiffAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler=SampleFromPrior(), ) @@ -111,7 +120,7 @@ end function gradient_logp( ::TrackerAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/test/Turing/core/compat/reversediff.jl b/test/Turing/core/compat/reversediff.jl new file mode 100644 index 000000000..f3822d35e --- /dev/null +++ b/test/Turing/core/compat/reversediff.jl @@ -0,0 +1,93 @@ +using .ReverseDiff: compile, GradientTape +using .ReverseDiff.DiffResults: GradientResult + +struct ReverseDiffAD{cache} <: ADBackend end +const RDCache = Ref(false) +setrdcache(b::Bool) = setrdcache(Val(b)) +setrdcache(::Val{false}) = RDCache[] = false +setrdcache(::Val) = throw("Memoization.jl is not loaded. Please load it before setting the cache to true.") +function emptyrdcache end + +getrdcache() = RDCache[] +ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} +function setadbackend(::Val{:reversediff}) + ADBACKEND[] = :reversediff +end + +function gradient_logp( + backend::ReverseDiffAD{false}, + θ::AbstractVector{<:Real}, + vi::AbstractVarInfo, + model::Model, + sampler::AbstractSampler = SampleFromPrior(), +) + T = typeof(getlogp(vi)) + + # Specify objective function. 
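+    # The closure `f` rebinds the parameter vector `θ` in a new `VarInfo` and
+    # re-evaluates the model, so ReverseDiff records the full log-density
+    # computation on its tape.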
+ function f(θ) + new_vi = VarInfo(vi, sampler, θ) + model(new_vi, sampler) + return getlogp(new_vi) + end + tp, result = taperesult(f, θ) + ReverseDiff.gradient!(result, tp, θ) + l = DiffResults.value(result) + ∂l∂θ::typeof(θ) = DiffResults.gradient(result) + + return l, ∂l∂θ +end + +tape(f, x) = GradientTape(f, x) +function taperesult(f, x) + return tape(f, x), GradientResult(x) +end + +@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin + setrdcache(::Val{true}) = RDCache[] = true + function emptyrdcache() + for k in keys(Memoization.caches) + if k[1] === typeof(memoized_taperesult) + pop!(Memoization.caches, k) + end + end + end + function gradient_logp( + backend::ReverseDiffAD{true}, + θ::AbstractVector{<:Real}, + vi::AbstractVarInfo, + model::Model, + sampler::AbstractSampler = SampleFromPrior(), + ) + T = typeof(getlogp(vi)) + + # Specify objective function. + function f(θ) + new_vi = VarInfo(vi, sampler, θ) + model(new_vi, sampler) + return getlogp(new_vi) + end + ctp, result = memoized_taperesult(f, θ) + ReverseDiff.gradient!(result, ctp, θ) + l = DiffResults.value(result) + ∂l∂θ = DiffResults.gradient(result) + + return l, ∂l∂θ + end + + # This makes sure we generate a single tape per Turing model and sampler + struct RDTapeKey{F, Tx} + f::F + x::Tx + end + function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any}) + key = keys[1][1] + return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x))) + end + memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x)) + Memoization.@memoize function memoized_taperesult(k::RDTapeKey) + return compiledtape(k.f, k.x), GradientResult(k.x) + end + memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x)) + Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x) + compiledtape(f, x) = compile(GradientTape(f, x)) +end diff --git a/test/Turing/core/compat/zygote.jl b/test/Turing/core/compat/zygote.jl index 3c56a1922..dc18fa0f8 100644 --- a/test/Turing/core/compat/zygote.jl +++ b/test/Turing/core/compat/zygote.jl @@ -7,7 +7,7 @@ end function gradient_logp( backend::ZygoteAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/test/Turing/core/container.jl b/test/Turing/core/container.jl index 774aaaf15..6cecf7a6d 100644 --- a/test/Turing/core/container.jl +++ b/test/Turing/core/container.jl @@ -74,19 +74,16 @@ Data structure for particle filters - normalise!(pc::ParticleContainer) - consume(pc::ParticleContainer): return incremental likelihood """ -mutable struct ParticleContainer{T<:Particle, F} - model::F +mutable struct ParticleContainer{T<:Particle} + "Particles." vals::Vector{T} - # logarithmic weights (Trace) or incremental log-likelihoods (ParticleContainer) + "Unnormalized logarithmic weights." logWs::Vector{Float64} - # log model evidence - logE::Float64 - # helpful for rejuvenation steps, e.g. 
in SMC2 - n_consume::Int end -ParticleContainer(model, particles::Vector{<:Particle}) = - ParticleContainer(model, particles, zeros(length(particles)), 0.0, 0) +function ParticleContainer(particles::Vector{<:Particle}) + return ParticleContainer(particles, zeros(length(particles))) +end Base.collect(pc::ParticleContainer) = pc.vals Base.length(pc::ParticleContainer) = length(pc.vals) @@ -107,51 +104,56 @@ function Base.copy(pc::ParticleContainer) # copy weights logWs = copy(pc.logWs) - ParticleContainer(pc.model, vals, logWs, pc.logE, pc.n_consume) + ParticleContainer(vals, logWs) end -# run particle filter for one step, return incremental likelihood -function Libtask.consume(pc :: ParticleContainer) +""" + propagate!(pc::ParticleContainer) + +Run particle filter for one step and check if the final time step is reached. +""" +function propagate!(pc::ParticleContainer) # normalisation factor: 1/N - z1 = logZ(pc) n = length(pc) particles = collect(pc) - num_done = 0 - for i=1:n + numdone = 0 + for i in 1:n p = particles[i] score = Libtask.consume(p) if score isa Real score += getlogp(p.vi) resetlogp!(p.vi) - increase_logweight(pc, i, Float64(score)) + increase_logweight!(pc, i, Float64(score)) elseif score == Val{:done} - num_done += 1 + numdone += 1 else error("[consume]: error in running particle filter.") end end - if num_done == n - res = Val{:done} - elseif num_done != 0 - error("[consume]: mis-aligned execution traces, num_particles= $(n), num_done=$(num_done).") - else - # update incremental likelihoods - z2 = logZ(pc) - res = increase_logevidence(pc, z2 - z1) - pc.n_consume += 1 - # res = increase_loglikelihood(pc, z2 - z1) + # Check if all particles are propagated to the final time point. + numdone == n && return true + + # The posterior for models with random number of observations is not well-defined. + if numdone != 0 + error("mis-aligned execution traces: # particles = ", n, + " # completed trajectories = ", numdone, + ". Please make sure the number of observations is NOT random.") end - res + return false end # compute the normalized weights getweights(pc::ParticleContainer) = softmax(pc.logWs) -# compute the log-likelihood estimate, ignoring constant term ``- \log num_particles`` -logZ(pc::ParticleContainer) = logsumexp(pc.logWs) +""" + logZ(pc::ParticleContainer) + +Return the estimate of the log-likelihood ``p(y_t | y_{1:(t-1)}, \\theta)``. 
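+
+The sum of the unnormalized particle weights is normalized by the number of
+particles, i.e. the returned estimate equals `logsumexp(pc.logWs) - log(length(pc))`.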
+""" +logZ(pc::ParticleContainer) = logsumexp(pc.logWs) - log(length(pc)) # compute the effective sample size ``1 / ∑ wᵢ²``, where ``wᵢ```are the normalized weights function effectiveSampleSize(pc :: ParticleContainer) @@ -159,12 +161,7 @@ function effectiveSampleSize(pc :: ParticleContainer) return inv(sum(abs2, Ws)) end -increase_logweight(pc :: ParticleContainer, t :: Int, logw :: Float64) = - (pc.logWs[t] += logw) - -increase_logevidence(pc :: ParticleContainer, logw :: Float64) = - (pc.logE += logw) - +increase_logweight!(pc::ParticleContainer, t::Int, logw::Float64) = (pc.logWs[t] += logw) function resample!( pc :: ParticleContainer, @@ -225,10 +222,9 @@ struct ResampleWithESSThreshold{R, T<:Real} threshold::T end -function ResampleWithESSThreshold() - ResampleWithESSThreshold(Turing.Inference.resample_systematic) +function ResampleWithESSThreshold(resampler = Turing.Inference.resample_systematic) + ResampleWithESSThreshold(resampler, 0.5) end -ResampleWithESSThreshold(resampler) = ResampleWithESSThreshold(resampler, 0.5) function resample!( pc::ParticleContainer, diff --git a/test/Turing/inference/AdvancedSMC.jl b/test/Turing/inference/AdvancedSMC.jl index 1f35de07c..1f05880c6 100644 --- a/test/Turing/inference/AdvancedSMC.jl +++ b/test/Turing/inference/AdvancedSMC.jl @@ -26,35 +26,50 @@ function additional_parameters(::Type{<:ParticleTransition}) return [:lp,:le, :weight] end +DynamicPPL.getlogp(t::ParticleTransition) = t.lp + #### #### Generic Sequential Monte Carlo sampler. #### """ - SMC() +$(TYPEDEF) Sequential Monte Carlo sampler. -Note that this method is particle-based, and arrays of variables -must be stored in a [`TArray`](@ref) object. - -Usage: +# Fields -```julia -SMC() -``` +$(TYPEDFIELDS) """ struct SMC{space, R} <: ParticleInference resampler::R end +""" + SMC(space...) + SMC([resampler = ResampleWithESSThreshold(), space = ()]) + SMC([resampler = resample_systematic, ]threshold[, space = ()]) + +Create a sequential Monte Carlo sampler of type [`SMC`](@ref) for the variables in `space`. + +If the algorithm for the resampling step is not specified explicitly, systematic resampling +is performed if the estimated effective sample size per particle drops below 0.5. +""" function SMC(resampler = Turing.Core.ResampleWithESSThreshold(), space::Tuple = ()) - SMC{space, typeof(resampler)}(resampler) + return SMC{space, typeof(resampler)}(resampler) +end + +# Convenient constructors with ESS threshold +function SMC(resampler, threshold::Real, space::Tuple = ()) + return SMC(Turing.Core.ResampleWithESSThreshold(resampler, threshold), space) end -SMC(::Tuple{}) = SMC() +SMC(threshold::Real, space::Tuple = ()) = SMC(resample_systematic, threshold, space) + +# If only the space is defined SMC(space::Symbol...) = SMC(space) +SMC(space::Tuple) = SMC(Turing.Core.ResampleWithESSThreshold(), space) -mutable struct SMCState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct SMCState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V # The logevidence after aggregating all samples together. 
    average_logevidence :: F
@@ -63,7 +78,7 @@ end
 
 function SMCState(model::Model)
     vi = VarInfo(model)
-    particles = ParticleContainer(model, Trace[])
+    particles = ParticleContainer(Trace[])
 
     return SMCState(vi, 0.0, particles)
 end
@@ -96,11 +111,19 @@ function AbstractMCMC.sample_init!(
     particles = T[Trace(model, spl, vi) for _ in 1:N]
 
     # create a new particle container
-    spl.state.particles = pc = ParticleContainer(model, particles)
+    spl.state.particles = pc = ParticleContainer(particles)
 
-    while consume(pc) !== Val{:done}
+    # Run particle filter.
+    logevidence = zero(spl.state.average_logevidence)
+    isdone = false
+    while !isdone
         resample!(pc, spl.alg.resampler)
+        isdone = propagate!(pc)
+        logevidence += logZ(pc)
     end
+    spl.state.average_logevidence = logevidence
+
+    return
 end
 
 function AbstractMCMC.step!(
@@ -124,7 +147,7 @@
     params = tonamedtuple(particle.vi)
     lp = getlogp(particle.vi)
 
-    return ParticleTransition(params, lp, pc.logE, Ws[iteration])
+    return ParticleTransition(params, lp, spl.state.average_logevidence, Ws[iteration])
 end
 
 ####
@@ -157,7 +180,7 @@ function PG(n1::Int, space::Symbol...)
     PG(n1, resample_systematic, space)
 end
 
-mutable struct PGState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
+mutable struct PGState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState
     vi :: V
     # The logevidence after aggregating all samples together.
     average_logevidence :: F
 end
@@ -212,11 +235,15 @@ function AbstractMCMC.step!(
     end
 
     # create a new particle container
-    pc = ParticleContainer(model, particles)
+    pc = ParticleContainer(particles)
 
     # run the particle filter
-    while consume(pc) !== Val{:done}
-        resample!(pc, spl.alg.resampler, ref_particle)
+    logevidence = zero(spl.state.average_logevidence)
+    isdone = false
+    while !isdone
+        resample!(pc, spl.alg.resampler, ref_particle)
+        isdone = propagate!(pc)
+        logevidence += logZ(pc)
     end
 
     # pick a particle to be retained.
@@ -229,7 +256,7 @@
     lp = getlogp(spl.state.vi)
 
     # update the master vi.
-    return ParticleTransition(params, lp, pc.logE, 1.0)
+    return ParticleTransition(params, lp, logevidence, 1.0)
 end
 
 function AbstractMCMC.sample_end!(
@@ -246,14 +273,14 @@
     loge = mean(t.le for t in ts)
 
     # If we already had a chain, grab the logevidence.
-    if resume_from isa Chains
+    if resume_from isa MCMCChains.Chains
         # pushfirst!(samples, resume_from.info[:samples]...)
         pre_loge = resume_from.logevidence
         # Calculate new log-evidence
         pre_n = length(resume_from)
         loge = (pre_loge * pre_n + loge * N) / (pre_n + N)
     elseif resume_from !== nothing
-        error("keyword argument `resume_from` has to be `nothing` or a `Chains` object")
+        error("keyword argument `resume_from` has to be `nothing` or a `MCMCChains.Chains` object")
     end
 
     # Store the logevidence.
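+    # (`getlogevidence` reads this state field back when `bundle_samples`
+    # constructs the chain.)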
@@ -264,10 +291,10 @@ function DynamicPPL.assume( spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, - vi + ::Any ) vi = current_trace().vi - if DynamicPPL.inspace(vn, getspace(spl)) + if inspace(vn, spl) if ~haskey(vi, vn) r = rand(dist) push!(vi, vn, r, dist, spl) diff --git a/test/Turing/inference/Inference.jl b/test/Turing/inference/Inference.jl index ddd989b73..389d59242 100644 --- a/test/Turing/inference/Inference.jl +++ b/test/Turing/inference/Inference.jl @@ -1,12 +1,15 @@ module Inference -using ..Core, ..Utilities +using ..Core +using ..Core: logZ +using ..Utilities using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, - settrans!, _getvns, getdist, CACHERESET, AbstractSampler, + settrans!, _getvns, getdist, set_namedtuple!, CACHERESET, AbstractSampler, Model, Sampler, SampleFromPrior, SampleFromUniform, Selector, AbstractSamplerState, DefaultContext, PriorContext, - LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist + LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist, + getspace, inspace using Distributions, Libtask, Bijectors using DistributionsAD: VectorOfMultivariate using LinearAlgebra @@ -16,16 +19,17 @@ using Random: GLOBAL_RNG, AbstractRNG, randexp using DynamicPPL using AbstractMCMC: AbstractModel, AbstractSampler using Bijectors: _debug -using MCMCChains: Chains +using DocStringExtensions: TYPEDEF, TYPEDFIELDS import AbstractMCMC import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import ..Core: getchunksize, getADbackend -import DynamicPPL: tilde, dot_tilde, getspace, get_matching_type, +import DynamicPPL: get_matching_type, VarName, _getranges, _getindex, getval, _getvns import EllipticalSliceSampling import Random +import MCMCChains export InferenceAlgorithm, Hamiltonian, @@ -47,6 +51,7 @@ export InferenceAlgorithm, SMC, CSMC, PG, + Prior, assume, dot_assume, observe, @@ -66,6 +71,15 @@ abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end getchunksize(::Type{<:Hamiltonian{AD}}) where AD = getchunksize(AD) getADbackend(::Hamiltonian{AD}) where AD = AD() +# Algorithm for sampling from the prior +struct Prior <: InferenceAlgorithm end + +getindex(vi::MixedVarInfo, spl::Sampler{<:Hamiltonian}) = vi.tvi[spl] +function setindex!(vi::MixedVarInfo, val, spl::Sampler{<:Hamiltonian}) + vi.tvi[spl] = val + return vi +end + """ mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real) @@ -103,6 +117,8 @@ function additional_parameters(::Type{<:Transition}) return [:lp] end +DynamicPPL.getlogp(t::Transition) = t.lp + ########################################## # Internal variable names for MCMCChains # ########################################## @@ -146,47 +162,106 @@ function AbstractMCMC.sample( model::AbstractModel, alg::InferenceAlgorithm, N::Integer; - chain_type=Chains, + kwargs... +) + return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::AbstractModel, + sampler::Sampler{<:InferenceAlgorithm}, + N::Integer; + chain_type=MCMCChains.Chains, resume_from=nothing, progress=PROGRESS[], kwargs... ) if resume_from === nothing - return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; - chain_type=chain_type, progress=progress, kwargs...) + return AbstractMCMC.mcmcsample(rng, model, sampler, N; + chain_type=chain_type, progress=progress, kwargs...) 
else return resume(resume_from, N; chain_type=chain_type, progress=progress, kwargs...) end end -function AbstractMCMC.psample( +function AbstractMCMC.sample( + rng::AbstractRNG, + model::AbstractModel, + alg::Prior, + N::Integer; + chain_type=MCMCChains.Chains, + resume_from=nothing, + progress=PROGRESS[], + kwargs... +) + if resume_from === nothing + return AbstractMCMC.mcmcsample(rng, model, SampleFromPrior(), N; + chain_type=chain_type, progress=progress, kwargs...) + else + return resume(resume_from, N; chain_type=chain_type, progress=progress, kwargs...) + end +end + +function AbstractMCMC.sample( model::AbstractModel, alg::InferenceAlgorithm, + parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, n_chains::Integer; kwargs... ) - return AbstractMCMC.psample(Random.GLOBAL_RNG, model, alg, N, n_chains; kwargs...) + return AbstractMCMC.sample(Random.GLOBAL_RNG, model, alg, parallel, N, n_chains; + kwargs...) end -function AbstractMCMC.psample( +function AbstractMCMC.sample( rng::AbstractRNG, model::AbstractModel, alg::InferenceAlgorithm, + parallel::AbstractMCMC.AbstractMCMCParallel, + N::Integer, + n_chains::Integer; + kwargs... +) + return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; + kwargs...) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::AbstractModel, + sampler::Sampler{<:InferenceAlgorithm}, + parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, n_chains::Integer; - chain_type=Chains, + chain_type=MCMCChains.Chains, progress=PROGRESS[], kwargs... ) - return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; - chain_type=chain_type, progress=progress, kwargs...) + return AbstractMCMC.mcmcsample(rng, model, sampler, parallel, N, n_chains; + chain_type=chain_type, progress=progress, kwargs...) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::AbstractModel, + alg::Prior, + parallel::AbstractMCMC.AbstractMCMCParallel, + N::Integer, + n_chains::Integer; + chain_type=MCMCChains.Chains, + progress=PROGRESS[], + kwargs... +) + return AbstractMCMC.sample(rng, model, SampleFromPrior(), parallel, N, n_chains; + chain_type=chain_type, progress=progress, kwargs...) end function AbstractMCMC.sample_init!( ::AbstractRNG, - model::Model, - spl::Sampler, + model::AbstractModel, + spl::Sampler{<:InferenceAlgorithm}, N::Integer; kwargs... ) @@ -197,17 +272,6 @@ function AbstractMCMC.sample_init!( initialize_parameters!(spl; kwargs...) end -function AbstractMCMC.sample_end!( - ::AbstractRNG, - ::Model, - ::Sampler, - ::Integer, - ::Vector; - kwargs... -) - # Silence the default API function. -end - function initialize_parameters!( spl::Sampler; init_theta::Union{Nothing,Vector}=nothing, @@ -238,11 +302,19 @@ end # Chain making utilities # ########################## -function _params_to_array(ts::Vector, spl::Sampler) +""" + getparams(t) + +Return a named tuple of parameters. +""" +getparams(t) = t.θ +getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t)) + +function _params_to_array(ts) names_set = Set{String}() # Extract the parameter names and values from each transition. 
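+    # `getparams` above normalizes the different transition types (custom
+    # transition structs and raw `VarInfo`s) to a named tuple before flattening.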
    dicts = map(ts) do t
-        nms, vs = flatten_namedtuple(t.θ)
+        nms, vs = flatten_namedtuple(getparams(t))
         for nm in nms
             push!(names_set, nm)
         end
@@ -250,7 +322,7 @@
         return Dict(nms[j] => vs[j] for j in 1:length(vs))
     end
     names = collect(names_set)
-    vals = [get(dicts[i], key, missing) for i in eachindex(dicts), 
+    vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
                     (j, key) in enumerate(names)]
 
     return names, vals
@@ -270,7 +342,12 @@ function flatten_namedtuple(nt::NamedTuple)
     return [vn[1] for vn in names_vals], [vn[2] for vn in names_vals]
 end
 
-function get_transition_extras(ts::Vector)
+function get_transition_extras(ts::AbstractVector{<:VarInfo})
+    valmat = reshape([getlogp(t) for t in ts], :, 1)
+    return ["lp"], valmat
+end
+
+function get_transition_extras(ts::AbstractVector)
     # Get the extra field names from the sampler state type.
     # This handles things like :lp or :weight.
     extra_params = additional_parameters(eltype(ts))
@@ -310,56 +387,55 @@
     return extra_names, valmat
 end
 
-# Default Chains constructor.
+getlogevidence(sampler) = missing
+function getlogevidence(sampler::Sampler)
+    if isdefined(sampler.state, :average_logevidence)
+        return sampler.state.average_logevidence
+    elseif isdefined(sampler.state, :final_logevidence)
+        return sampler.state.final_logevidence
+    else
+        return missing
+    end
+end
+
+# Default MCMCChains.Chains constructor.
+# This is type piracy (at least for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
     rng::AbstractRNG,
-    model::Model,
-    spl::Sampler,
+    model::AbstractModel,
+    spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
     N::Integer,
     ts::Vector,
-    chain_type::Type{Chains};
-    discard_adapt::Bool=true,
-    save_state=false,
+    chain_type::Type{MCMCChains.Chains};
+    save_state = false,
     kwargs...
 )
-    # Check if we have adaptation samples.
-    if discard_adapt && :n_adapts in fieldnames(typeof(spl.alg))
-        ts = ts[(spl.alg.n_adapts+1):end]
-    end
-
     # Convert transitions to array format.
     # Also retrieve the variable names.
-    nms, vals = _params_to_array(ts, spl)
+    nms, vals = _params_to_array(ts)
 
-    # Get the values of the extra parameters in each Transition struct.
+    # Get the values of the extra parameters in each transition.
     extra_params, extra_values = get_transition_extras(ts)
 
     # Extract names & construct param array.
     nms = [nms; extra_params]
     parray = hcat(vals, extra_values)
 
-    # If the state field has average_logevidence or final_logevidence, grab that.
-    le = missing
-    if :average_logevidence in fieldnames(typeof(spl.state))
-        le = getproperty(spl.state, :average_logevidence)
-    elseif :final_logevidence in fieldnames(typeof(spl.state))
-        le = getproperty(spl.state, :final_logevidence)
-    end
-
-    # Check whether to invlink! the varinfo
-    if islinked(spl.state.vi, spl)
-        invlink!(spl.state.vi, spl)
-    end
+    # Get the average or final log evidence, if it exists.
+    le = getlogevidence(spl)
 
     # Set up the info tuple.
     if save_state
-        info = (range = rng, model = model, spl = spl, vi = spl.state.vi)
+        info = (range = rng, model = model, spl = spl)
     else
         info = NamedTuple()
     end
 
+    # Concretize the array before giving it to MCMCChains.
+    parray = MCMCChains.concretize(parray)
+
     # Chain construction.
-    return Chains(
+    return MCMCChains.Chains(
         parray,
         string.(nms),
         deepcopy(TURING_INTERNAL_VARS);
@@ -369,10 +445,11 @@
     )
 end
 
+# This is type piracy (for SampleFromPrior).
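+# Hypothetical usage sketch (not part of this PR's test suite): with the `Prior`
+# algorithm added above, draws can be collected as plain named tuples instead of
+# an `MCMCChains.Chains` object:
+#
+#     # given some `model` constructed with `@model`:
+#     nts = sample(model, Prior(), 100; chain_type = Vector{NamedTuple})
+#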
function AbstractMCMC.bundle_samples( rng::AbstractRNG, - model::Model, - spl::Sampler, + model::AbstractModel, + spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, N::Integer, ts::Vector, chain_type::Type{Vector{NamedTuple}}; @@ -382,39 +459,46 @@ function AbstractMCMC.bundle_samples( ) nts = Vector{NamedTuple}(undef, N) - for (i,t) in enumerate(ts) - k = collect(keys(t.θ)) + for (i, t) in enumerate(ts) + params = getparams(t) + + k = collect(keys(params)) vs = [] - for v in values(t.θ) + for v in values(params) push!(vs, v[1]) end push!(k, :lp) - - - nts[i] = NamedTuple{tuple(k...)}(tuple(vs..., t.lp)) + + nts[i] = NamedTuple{tuple(k...)}(tuple(vs..., getlogp(t))) end return map(identity, nts) end -function save(c::Chains, spl::Sampler, model, vi, samples) +function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples) nt = NamedTuple{(:spl, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples)) return setinfo(c, merge(nt, c.info)) end -function resume(c::Chains, n_iter::Int; chain_type=Chains, progress=PROGRESS[], kwargs...) +function resume( + c::MCMCChains.Chains, + n_iter::Int; + chain_type=MCMCChains.Chains, + progress=PROGRESS[], + kwargs... +) @assert !isempty(c.info) "[Turing] cannot resume from a chain without state info" # Sample a new chain. - newchain = AbstractMCMC.sample( + newchain = AbstractMCMC.mcmcsample( c.info[:range], c.info[:model], c.info[:spl], n_iter; resume_from=c, reuse_spl_n=n_iter, - chain_type=Chains, + chain_type=MCMCChains.Chains, progress=progress, kwargs... ) @@ -425,7 +509,7 @@ end function set_resume!( s::Sampler; - resume_from::Union{Chains, Nothing}=nothing, + resume_from::Union{MCMCChains.Chains, Nothing}=nothing, kwargs... ) # If we're resuming, grab the sampler info. @@ -441,7 +525,7 @@ end """ A blank `AbstractSamplerState` that contains only `VarInfo` information. """ -mutable struct SamplerState{VIType<:VarInfo} <: AbstractSamplerState +mutable struct SamplerState{VIType<:AbstractVarInfo} <: AbstractSamplerState vi :: VIType end @@ -462,10 +546,10 @@ include("../contrib/inference/sghmc.jl") ################ for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs) - @eval getspace(::$alg{space}) where {space} = space + @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) - @eval getspace(::$alg{<:Any, space}) where {space} = space + @eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space end floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T)) @@ -480,27 +564,20 @@ function get_matching_type( end function get_matching_type( spl::AbstractSampler, - vi, - ::Type{<:AbstractFloat}, -) - return floatof(eltype(vi, spl)) -end -function get_matching_type( - spl::Sampler{<:Hamiltonian}, - vi, + vi, ::Type{<:Union{Missing, AbstractFloat}}, ) return Union{Missing, floatof(eltype(vi, spl))} end function get_matching_type( - spl::Sampler{<:Hamiltonian}, - vi, + spl::AbstractSampler, + vi, ::Type{<:AbstractFloat}, ) return floatof(eltype(vi, spl)) end function get_matching_type( - spl::Sampler{<:Hamiltonian}, + spl::AbstractSampler, vi, ::Type{TV}, ) where {T, N, TV <: Array{T, N}} @@ -518,12 +595,7 @@ end # Utilities # ############## -getspace(spl::Sampler) = getspace(spl.alg) -function ambiguity_error_msg() - return "Ambiguous `lhs .~ rhs` or `@. lhs ~ rhs` syntax. The broadcasting can either be - column-wise following the convention of Distributions.jl or element-wise following - Julia's general broadcasting semantics. 
Please make sure that the element type of `lhs` - is not a supertype of the support type of `AbstractVector` to eliminate ambiguity." -end +DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg) +DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg)) end # module diff --git a/test/Turing/inference/ess.jl b/test/Turing/inference/ess.jl index 58c0f090f..1fa53d194 100644 --- a/test/Turing/inference/ess.jl +++ b/test/Turing/inference/ess.jl @@ -25,7 +25,7 @@ struct ESS{space} <: InferenceAlgorithm end ESS() = ESS{()}() ESS(space::Symbol) = ESS{(space,)}() -mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState +mutable struct ESSState{V<:AbstractVarInfo} <: AbstractSamplerState vi::V end @@ -145,7 +145,7 @@ function Distributions.loglikelihood(model::ESSModel, f) end function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi) - if DynamicPPL.inspace(vn, getspace(sampler)) + if inspace(vn, sampler) return DynamicPPL.tilde(LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) else return DynamicPPL.tilde(ctx, SampleFromPrior(), right, vn, inds, vi) @@ -157,7 +157,7 @@ function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, l end function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi) - if DynamicPPL.inspace(vn, getspace(sampler)) + if inspace(vn, sampler) return DynamicPPL.dot_tilde(LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi) else return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vn, inds, vi) diff --git a/test/Turing/inference/gibbs.jl b/test/Turing/inference/gibbs.jl index 61e7d7f6c..9482ced94 100644 --- a/test/Turing/inference/gibbs.jl +++ b/test/Turing/inference/gibbs.jl @@ -42,12 +42,12 @@ function Gibbs(algs::GibbsComponent...) end """ - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} + GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that the `Gibbs` sampler iterates through for each `step!`. """ -mutable struct GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState +mutable struct GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState vi::V samplers::S end @@ -81,23 +81,37 @@ function Sampler(alg::Gibbs, model::Model, s::Selector) # add Gibbs to gids for all variables vi = spl.state.vi - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns + DynamicPPL.updategid!(vi, (spl, samplers...)) - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) + return spl +end - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end - end - end +""" + GibbsTransition - return spl +Fields: +- `θ`: The parameters for any given sample. +- `lp`: The log pdf for the sample's parameters. +- `transitions`: The transitions of the samplers. +""" +struct GibbsTransition{T,F,S<:AbstractVector} + θ::T + lp::F + transitions::S end +function GibbsTransition(spl::Sampler{<:Gibbs}, transitions::AbstractVector) + theta = tonamedtuple(spl.state.vi) + lp = getlogp(spl.state.vi) + return GibbsTransition(theta, lp, transitions) +end + +function additional_parameters(::Type{<:GibbsTransition}) + return [:lp] +end + +DynamicPPL.getlogp(t::GibbsTransition) = t.lp + # Initialize the Gibbs sampler. 
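+# `GibbsTransition` stores the transitions of all component samplers so that
+# each component can resume from its own previous transition in `step!` below.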
function AbstractMCMC.sample_init!( rng::AbstractRNG, @@ -132,39 +146,55 @@ function AbstractMCMC.step!( model::Model, spl::Sampler{<:Gibbs}, N::Integer, - transition; + transition::Union{Nothing,GibbsTransition}; kwargs... ) Turing.DEBUG && @debug "Gibbs stepping..." - time_elapsed = 0.0 - # Iterate through each of the samplers. - for local_spl in spl.state.samplers + transitions = map(enumerate(spl.state.samplers)) do (i, local_spl) Turing.DEBUG && @debug "$(typeof(local_spl)) stepping..." - Turing.DEBUG && @debug "recording old θ..." - # Update the sampler's VarInfo. local_spl.state.vi = spl.state.vi # Step through the local sampler. - time_elapsed_thin = - @elapsed trans = AbstractMCMC.step!(rng, model, local_spl, N, transition; kwargs...) + if transition === nothing + trans = AbstractMCMC.step!(rng, model, local_spl, N, nothing; kwargs...) + else + trans = AbstractMCMC.step!(rng, model, local_spl, N, transition.transitions[i]; + kwargs...) + end # After the step, update the master varinfo. spl.state.vi = local_spl.state.vi - # Uncomment when developing thinning functionality. - # Retrieve symbol to store this subsample. - # symbol_id = Symbol(local_spl.selector.gid) - # - # # Store the subsample. - # spl.state.subsamples[symbol_id][] = trans - - # Record elapsed time. - time_elapsed += time_elapsed_thin + trans end - return Transition(spl) + return GibbsTransition(spl, transitions) +end + +# Do not store transitions of subsamplers +function AbstractMCMC.transitions_init( + transition::GibbsTransition, + ::Model, + ::Sampler{<:Gibbs}, + N::Integer; + kwargs... +) + return Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, N) +end + +function AbstractMCMC.transitions_save!( + transitions::Vector{<:Transition}, + iteration::Integer, + transition::GibbsTransition, + ::Model, + ::Sampler{<:Gibbs}, + ::Integer; + kwargs... +) + transitions[iteration] = Transition(transition.θ, transition.lp) + return end diff --git a/test/Turing/inference/hmc.jl b/test/Turing/inference/hmc.jl index f7238e115..427acd3bd 100644 --- a/test/Turing/inference/hmc.jl +++ b/test/Turing/inference/hmc.jl @@ -3,7 +3,7 @@ ### mutable struct HMCState{ - TV <: TypedVarInfo, + TV <: AbstractVarInfo, TTraj<:AHMC.AbstractTrajectory, TAdapt<:AHMC.Adaptation.AbstractAdaptor, PhType <: AHMC.PhasePoint @@ -37,6 +37,7 @@ function additional_parameters(::Type{<:HamiltonianTransition}) return [:lp,:stat] end +DynamicPPL.getlogp(t::HamiltonianTransition) = t.lp ### ### Hamiltonian Monte Carlo samplers. @@ -78,7 +79,7 @@ end alg_str(::Sampler{<:Hamiltonian}) = "HMC" -HMC(args...) = HMC{ADBackend()}(args...) +HMC(args...; kwargs...) = HMC{ADBackend()}(args...; kwargs...) function HMC{AD}(ϵ::Float64, n_leapfrog::Int, ::Type{metricT}, space::Tuple) where {AD, metricT <: AHMC.AbstractMetric} return HMC{AD, space, metricT}(ϵ, n_leapfrog) end @@ -99,21 +100,52 @@ function HMC{AD}( return HMC{AD}(ϵ, n_leapfrog, metricT, space) end +function update_hamiltonian!(spl, model, n) + metric = gen_metric(n, spl) + ℓπ = gen_logπ(spl.state.vi, spl, model) + ∂ℓπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model) + spl.state.h = AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) + return spl +end + function AbstractMCMC.sample_init!( rng::AbstractRNG, - model::Model, + model::AbstractModel, spl::Sampler{<:Hamiltonian}, N::Integer; verbose::Bool=true, resume_from=nothing, + init_theta=nothing, kwargs... ) - # Resume the sampler. set_resume!(spl; resume_from=resume_from, kwargs...) # Get `init_theta` initialize_parameters!(spl; verbose=verbose, kwargs...) 
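+    # Set the initial phase point: either adopt the user-supplied parameters, or
+    # keep drawing with `SampleFromUniform` until the log density and its
+    # gradient are finite (see the loop below).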
+ if init_theta !== nothing + # Doesn't support dynamic models + link!(spl.state.vi, spl) + model(spl.state.vi, spl) + theta = spl.state.vi[spl] + spl.state.z.θ .= theta + else + # Sample new values and set trans to true, then compute the logp + model(empty!(spl.state.vi), SampleFromUniform()) + link!(spl.state.vi, spl) + theta = spl.state.vi[spl] + resize!(spl.state.z.θ, length(theta)) + spl.state.z.θ .= theta + update_hamiltonian!(spl, model, length(theta)) + while !isfinite(spl.state.z.ℓπ.value) || !all(isfinite, spl.state.z.ℓπ.gradient) + model(empty!(spl.state.vi), SampleFromUniform()) + link!(spl.state.vi, spl) + theta = spl.state.vi[spl] + resize!(spl.state.z.θ, length(theta)) + spl.state.z.θ .= theta + update_hamiltonian!(spl, model, length(theta)) + end + end # Set the default number of adaptations, if relevant. if spl.alg isa AdaptiveHamiltonian @@ -138,9 +170,49 @@ function AbstractMCMC.sample_init!( if !islinked(spl.state.vi, spl) && spl.selector.tag == :default link!(spl.state.vi, spl) model(spl.state.vi, spl) + elseif islinked(spl.state.vi, spl) && spl.selector.tag != :default + invlink!(spl.state.vi, spl) + model(spl.state.vi, spl) end end +function AbstractMCMC.transitions_init( + transition, + ::AbstractModel, + sampler::Sampler{<:Hamiltonian}, + N::Integer; + discard_adapt = true, + kwargs... +) + if discard_adapt && isdefined(sampler.alg, :n_adapts) + n = max(0, N - sampler.alg.n_adapts) + else + n = N + end + return Vector{typeof(transition)}(undef, n) +end + +function AbstractMCMC.transitions_save!( + transitions::AbstractVector, + iteration::Integer, + transition, + ::AbstractModel, + sampler::Sampler{<:Hamiltonian}, + ::Integer; + discard_adapt = true, + kwargs... +) + if discard_adapt && isdefined(sampler.alg, :n_adapts) + if iteration > sampler.alg.n_adapts + transitions[iteration - sampler.alg.n_adapts] = transition + end + return + end + + transitions[iteration] = transition + return +end + """ HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64=0.0) @@ -337,22 +409,21 @@ function AbstractMCMC.step!( spl.state.eval_num = 0 Turing.DEBUG && @debug "current ϵ: $ϵ" - - # Gibbs component specified cares + + # When used as a component of a Gibbs sampler if spl.selector.tag != :default # Transform the space Turing.DEBUG && @debug "X-> R..."
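The `discard_adapt` branches in the Hamiltonian `transitions_init`/`transitions_save!` above allocate only the post-warmup slots and shift the write index accordingly. The arithmetic in isolation:

```julia
N, n_adapts = 1_000, 200
n_saved = max(0, N - n_adapts)       # 800 slots allocated up front
for iteration in 1:N
    iteration > n_adapts || continue # warm-up draws are never stored
    idx = iteration - n_adapts       # runs 1, 2, ..., n_saved
end
```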
link!(spl.state.vi, spl) model(spl.state.vi, spl) - # Update Hamiltonian - metric = gen_metric(length(spl.state.vi[spl]), spl) - ∂logπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model) - logπ = gen_logπ(spl.state.vi, spl, model) - spl.state.h = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ) end - # Get position and log density before transition θ_old, log_density_old = spl.state.vi[spl], getlogp(spl.state.vi) + if spl.selector.tag != :default + update_hamiltonian!(spl, model, length(θ_old)) + resize!(spl.state.z.θ, length(θ_old)) + spl.state.z.θ .= θ_old + end # Transition t = AHMC.step(rng, spl.state.h, spl.state.traj, spl.state.z) @@ -499,8 +570,8 @@ end #### function AHMCAdaptor(alg::AdaptiveHamiltonian, metric::AHMC.AbstractMetric; ϵ=alg.ϵ) - pc = AHMC.Preconditioner(metric) - da = AHMC.NesterovDualAveraging(alg.δ, ϵ) + pc = AHMC.MassMatrixAdaptor(metric) + da = AHMC.StepSizeAdaptor(alg.δ, ϵ) if iszero(alg.n_adapts) adaptor = AHMC.Adaptation.NoAdaptation() @@ -549,7 +620,7 @@ function HMCState( # Find good eps if not provided one if spl.alg.ϵ == 0.0 - ϵ = AHMC.find_good_eps(h, θ_init) + ϵ = AHMC.find_good_stepsize(h, θ_init) @info "Found initial step size" ϵ else ϵ = spl.alg.ϵ diff --git a/test/Turing/inference/is.jl b/test/Turing/inference/is.jl index a7f515e36..dcd4cad5b 100644 --- a/test/Turing/inference/is.jl +++ b/test/Turing/inference/is.jl @@ -31,7 +31,7 @@ struct IS{space} <: InferenceAlgorithm end IS() = IS{()}() -mutable struct ISState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct ISState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V final_logevidence :: F end diff --git a/test/Turing/inference/mh.jl b/test/Turing/inference/mh.jl index ed1edb4fe..42c4b5b2f 100644 --- a/test/Turing/inference/mh.jl +++ b/test/Turing/inference/mh.jl @@ -12,8 +12,6 @@ function MH(space...) prop_syms = Symbol[] props = AMH.Proposal[] - check_support(dist) = insupport(dist, z) - for s in space if s isa Symbol push!(syms, s) @@ -23,9 +21,9 @@ function MH(space...) if s[2] isa AMH.Proposal push!(props, s[2]) elseif s[2] isa Distribution - push!(props, AMH.Proposal(AMH.Static(), s[2])) + push!(props, AMH.StaticProposal(s[2])) elseif s[2] isa Function - push!(props, AMH.Proposal(AMH.Static(), s[2])) + push!(props, AMH.StaticProposal(s[2])) end end end @@ -35,94 +33,62 @@ function MH(space...) return MH{tuple(syms...), typeof(proposals)}(proposals) end -alg_str(::Sampler{<:MH}) = "MH" +function Sampler( + alg::MH, + model::Model, + s::Selector=Selector() +) + # Set up info dict. + info = Dict{Symbol, Any}() -################# -# MH Transition # -################# + # Set up state struct. + state = SamplerState(VarInfo(model)) -struct MHTransition{T, F<:AbstractFloat, M<:AMH.Transition} - θ :: T - lp :: F - mh_trans :: M + # Generate a sampler. + return Sampler(alg, info, s, state) end -function MHTransition(spl::Sampler{<:MH}, mh_trans::AMH.Transition) - theta = tonamedtuple(spl.state.vi) - return MHTransition(theta, mh_trans.lp, mh_trans) -end +alg_str(::Sampler{<:MH}) = "MH" ##################### # Utility functions # ##################### """ - set_namedtuple!(vi::VarInfo, nt::NamedTuple) - -Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. 
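The adaptor construction above tracks AdvancedHMC's renamed API: `MassMatrixAdaptor` replaces `Preconditioner`, `StepSizeAdaptor` replaces `NesterovDualAveraging`, and `find_good_stepsize` replaces `find_good_eps`. A standalone sketch of the same construction, with example values for the target dimension, acceptance target δ, and initial step size ϵ:

```julia
using AdvancedHMC
const AHMC = AdvancedHMC

metric  = AHMC.DiagEuclideanMetric(2)      # e.g. a 2-dimensional target
pc      = AHMC.MassMatrixAdaptor(metric)   # was AHMC.Preconditioner(metric)
da      = AHMC.StepSizeAdaptor(0.65, 0.1)  # was AHMC.NesterovDualAveraging(δ, ϵ)
adaptor = AHMC.StanHMCAdaptor(pc, da)
```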
-""" -function set_namedtuple!(vi::VarInfo, nt::NamedTuple) - for (n, vals) in pairs(nt) - vns = vi.metadata[n].vns - - n_vns = length(vns) - n_vals = length(vals) - v_isarr = vals isa AbstractArray - - if v_isarr && n_vals == 1 && n_vns > 1 - for (vn, val) in zip(vns, vals[1]) - vi[vn] = val isa AbstractArray ? val : [val] - end - elseif v_isarr && n_vals > 1 && n_vns == 1 - vi[vns[1]] = vals - elseif v_isarr && n_vals == n_vns > 1 - for (vn, val) in zip(vns, vals) - vi[vn] = [val] - end - elseif v_isarr && n_vals == 1 && n_vns == 1 - if vals[1] isa AbstractArray - vi[vns[1]] = vals[1] - else - vi[vns[1]] = [vals[1]] - end - elseif !(v_isarr) - vi[vns[1]] = [vals] - else - error("Cannot assign `NamedTuple` to `VarInfo`") - end - end -end + MHLogDensityFunction -""" - gen_logπ_mh(vi, spl::Sampler, model) +A log density function for the MH sampler. -Generate a log density function -- this variant uses the -`set_namedtuple!` function to update the `VarInfo`. +This variant uses the `set_namedtuple!` function to update the variables. """ -function gen_logπ_mh(spl::Sampler, model) - function logπ(x)::Float64 - vi = spl.state.vi - x_old, lj_old = vi[spl], getlogp(vi) - # vi[spl] = x - set_namedtuple!(vi, x) - model(vi) - lj = getlogp(vi) - vi[spl] = x_old - setlogp!(vi, lj_old) - return lj - end - return logπ -end - -function scalar_map(vi, vns::Vector{V}) where V<:VarName - vals = getindex(vi, vns) - if length(vals) == length(vns) == 1 - # It's a scalar! - return vals[1] - else - # Go home, vector, you're drunk. - return vals - end +struct MHLogDensityFunction{M<:Model,S<:Sampler{<:MH}} <: Function # Relax AMH.DensityModel? + model::M + sampler::S +end + +function (f::MHLogDensityFunction)(x)::Float64 + sampler = f.sampler + vi = sampler.state.vi + x_old, lj_old = vi[sampler], getlogp(vi) + # vi[sampler] = x + set_namedtuple!(vi, x) + f.model(vi) + lj = getlogp(vi) + vi[sampler] = x_old + setlogp!(vi, lj_old) + return lj +end + +# unpack a vector if possible +unvectorize(dists::AbstractVector) = length(dists) == 1 ? first(dists) : dists + +# possibly unpack and reshape samples according to the prior distribution +reconstruct(dist::Distribution, val::AbstractVector) = DynamicPPL.reconstruct(dist, val) +function reconstruct( + dist::AbstractVector{<:UnivariateDistribution}, + val::AbstractVector +) + return val end """ @@ -132,80 +98,44 @@ Returns two `NamedTuples`. The first `NamedTuple` has symbols as keys and distri The second `NamedTuple` has model symbols as keys and their stored values as values. 
""" function dist_val_tuple(spl::Sampler{<:MH}) - vns = _getvns(spl.state.vi, spl) - dt = _dist_tuple(spl.state.vi.metadata, spl.alg.proposals, spl.state.vi, vns) - vt = _val_tuple(spl.state.vi.metadata, spl.state.vi, vns) + vi = spl.state.vi + vns = _getvns(vi, spl) + dt = _dist_tuple(spl.alg.proposals, vi, vns) + vt = _val_tuple(vi, vns) return dt, vt end -@generated function _val_tuple(metadata::NamedTuple, vi, vns::NamedTuple{names}) where {names} - length(names) === 0 && return :(NamedTuple()) +@generated function _val_tuple( + vi, + vns::NamedTuple{names} +) where {names} + isempty(names) === 0 && return :(NamedTuple()) expr = Expr(:tuple) - map(names) do f - push!(expr.args, Expr(:(=), f, :(scalar_map(vi, metadata.$f.vns)))) - end + expr.args = Any[ + :($name = reconstruct(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)), + DynamicPPL.getval(vi, vns.$name))) + for name in names] return expr end @generated function _dist_tuple( - metadata::NamedTuple, props::NamedTuple{propnames}, - vi, + vi, vns::NamedTuple{names} -) where {names, propnames} - length(names) === 0 && return :(NamedTuple()) +) where {names,propnames} + isempty(names) === 0 && return :(NamedTuple()) expr = Expr(:tuple) - map(names) do f - if f in propnames + expr.args = Any[ + if name in propnames # We've been given a custom proposal, use that instead. - push!(expr.args, Expr(:(=), f, :(props.$f))) + :($name = props.$name) else # Otherwise, use the default proposal. - push!(expr.args, Expr(:(=), f, :(AMH.Proposal(AMH.Static(), metadata.$f.dists[1])))) - end - end + :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)))) + end for name in names] return expr end -################# -# Sampler state # -################# - -mutable struct MHState{V<:VarInfo} <: AbstractSamplerState - vi :: V - density_model :: AMH.DensityModel -end - -############################### -# Static MH (from prior only) # -############################### - -function Sampler( - alg::MH, - model::Model, - s::Selector=Selector() -) - # Set up info dict. - info = Dict{Symbol, Any}() - - # Make a varinfo. - vi = VarInfo(model) - - # Make a density model. - dm = AMH.DensityModel(x -> 0.0) - - # Set up state struct. - state = MHState(vi, dm) - - # Generate a sampler. - spl = Sampler(alg, info, s, state) - - # Update the density model. - spl.state.density_model = AMH.DensityModel(gen_logπ_mh(spl, model)) - - return spl -end - function AbstractMCMC.sample_init!( rng::AbstractRNG, model::Model, @@ -242,7 +172,8 @@ function AbstractMCMC.step!( prev_trans = AMH.Transition(vt, getlogp(spl.state.vi)) # Make a new transition. - trans = AbstractMCMC.step!(rng, spl.state.density_model, mh_sampler, 1, prev_trans) + densitymodel = AMH.DensityModel(MHLogDensityFunction(model, spl)) + trans = AbstractMCMC.step!(rng, densitymodel, mh_sampler, 1, prev_trans) # Update the values in the VarInfo. set_namedtuple!(spl.state.vi, trans.params) diff --git a/test/Turing/stdlib/distributions.jl b/test/Turing/stdlib/distributions.jl index 3b76fdac7..63ae2ea77 100644 --- a/test/Turing/stdlib/distributions.jl +++ b/test/Turing/stdlib/distributions.jl @@ -47,16 +47,6 @@ struct BinomialLogit{T<:Real, I<:Integer} <: DiscreteUnivariateDistribution logitp::T end -""" - BinomialLogit(n<:Real, I<:Integer) - -A multivariate binomial logit distribution. 
-""" -struct VecBinomialLogit{T<:Real, I<:Integer} <: DiscreteUnivariateDistribution - n::Vector{I} - logitp::Vector{T} -end - function logpdf_binomial_logit(n, logitp, k) logcomb = -StatsFuns.log1p(n) - SpecialFunctions.logbeta(n - k + 1, k + 1) return logcomb + k * logitp - n * StatsFuns.log1pexp(logitp) @@ -66,8 +56,17 @@ function Distributions.logpdf(d::BinomialLogit{<:Real}, k::Int) return logpdf_binomial_logit(d.n, d.logitp, k) end -function Distributions.logpdf(d::VecBinomialLogit{<:Real}, ks::Vector{<:Integer}) - return sum(logpdf_binomial_logit.(d.n, d.logitp, ks)) +function Distributions.pdf(d::BinomialLogit{<:Real}, k::Int) + return exp(logpdf_binomial_logit(d.n, d.logitp, k)) +end + +""" + BernoulliLogit(p<:Real) + +A univariate logit-parameterised Bernoulli distribution. +""" +function BernoulliLogit(logitp::Real) + return BinomialLogit(1, logitp) end """ @@ -134,3 +133,10 @@ end function Distributions.logpdf(lp::LogPoisson, k::Int) return k * lp.logλ - exp(lp.logλ) - SpecialFunctions.loggamma(k + 1) end + +Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 +Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool) = 0 +function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool) + return zeros(Int, size(x, 2)) +end +Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool) = 0 diff --git a/test/Turing/utilities/Utilities.jl b/test/Turing/utilities/Utilities.jl index 4f41ad853..c3489f0d5 100644 --- a/test/Turing/utilities/Utilities.jl +++ b/test/Turing/utilities/Utilities.jl @@ -1,9 +1,9 @@ module Utilities using DynamicPPL: AbstractSampler, Sampler +using DynamicPPL: init, inittrans, reconstruct, reconstruct!, vectorize using Distributions, Bijectors using StatsFuns, SpecialFunctions -using MCMCChains: Chains, setinfo import Distributions: sample export vectorize, @@ -12,11 +12,9 @@ export vectorize, Sample, Chain, init, - vectorize, set_resume!, FlattenIterator -include("robustinit.jl") include("helper.jl") end # module diff --git a/test/Turing/utilities/robustinit.jl b/test/Turing/utilities/robustinit.jl deleted file mode 100644 index f78394b61..000000000 --- a/test/Turing/utilities/robustinit.jl +++ /dev/null @@ -1,33 +0,0 @@ -# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni() = Real(2rand()) -randrealuni(args...) = map(Real, 2rand(args...)) - -const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution} - - -################################# -# Single-sample initialisations # -################################# - -init(dist::Transformable) = inittrans(dist) -init(dist::Distribution) = rand(dist) - -inittrans(dist::UnivariateDistribution) = invlink(dist, randrealuni()) -inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist)[1])) -inittrans(dist::MatrixDistribution) = invlink(dist, randrealuni(size(dist)...)) - - -################################ -# Multi-sample initialisations # -################################ - -init(dist::Transformable, n::Int) = inittrans(dist, n) -init(dist::Distribution, n::Int) = rand(dist, n) - -inittrans(dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(n)) -function inittrans(dist::MultivariateDistribution, n::Int) - return invlink(dist, randrealuni(size(dist)[1], n)) -end -function inittrans(dist::MatrixDistribution, n::Int) - return invlink(dist, [randrealuni(size(dist)...) 
for _ in 1:n]) -end diff --git a/test/Turing/utilities/stan-interface.jl b/test/Turing/utilities/stan-interface.jl index c8c666d34..76e2430e5 100644 --- a/test/Turing/utilities/stan-interface.jl +++ b/test/Turing/utilities/stan-interface.jl @@ -70,8 +70,8 @@ end function AHMCAdaptor(adaptor) if :engaged in fieldnames(typeof(adaptor)) # CmdStan.Adapt adaptor.engaged ? spl.alg.n_adapts : 0, - AHMC.Preconditioner(metric), - AHMC.NesterovDualAveraging(adaptor.gamma, + AHMC.MassMatrixAdaptor(metric), + AHMC.StepSizeAdaptor(adaptor.gamma, adaptor.t0, adaptor.kappa, adaptor.δ, init_ϵ), adaptor.init_buffer, adaptor.term_buffer, @@ -79,8 +79,8 @@ function AHMCAdaptor(adaptor) else # default adaptor @warn "Invalid adaptor type: $(typeof(adaptor)). Default adaptor is used instead." adaptor = AHMC.StanHMCAdaptor( - AHMC.Preconditioner(:DiagEuclideanMetric), - AHMC.NesterovDualAveraging(spl.alg.δ, init_ϵ) + AHMC.MassMatrixAdaptor(:DiagEuclideanMetric), + AHMC.StepSizeAdaptor(spl.alg.δ, init_ϵ) ) AHMC.initialize!(adaptor, spl.alg.n_adapts) adaptor diff --git a/test/Turing/variational/VariationalInference.jl b/test/Turing/variational/VariationalInference.jl index 1192d37f4..cbdb9f2f6 100644 --- a/test/Turing/variational/VariationalInference.jl +++ b/test/Turing/variational/VariationalInference.jl @@ -38,10 +38,53 @@ function __init__() else - vo(alg, q(θ), model, args...) end - out .= Zygote.gradient(f, θ) + y, back = Tracker.pullback(f, θ) + dy = back(1.0) + DiffResults.value!(out, Tracker.data(y)) + DiffResults.gradient!(out, Tracker.data(dy[1])) return out end end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + function Variational.grad!( + vo, + alg::VariationalInference{<:Turing.ReverseDiffAD{false}}, + q, + model, + θ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + args... + ) + f(θ) = if (q isa VariationalPosterior) + - vo(alg, update(q, θ), model, args...) + else + - vo(alg, q(θ), model, args...) + end + tp = Turing.Core.tape(f, θ) + ReverseDiff.gradient!(out, tp, θ) + return out + end + @require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" begin + function Variational.grad!( + vo, + alg::VariationalInference{<:Turing.ReverseDiffAD{true}}, + q, + model, + θ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + args... + ) + f(θ) = if (q isa VariationalPosterior) + - vo(alg, update(q, θ), model, args...) + else + - vo(alg, q(θ), model, args...) + end + ctp = Turing.Core.memoized_tape(f, θ) + ReverseDiff.gradient!(out, ctp, θ) + return out + end + end + end end export diff --git a/test/Turing/variational/advi.jl b/test/Turing/variational/advi.jl index 715789934..c8e58e6c5 100644 --- a/test/Turing/variational/advi.jl +++ b/test/Turing/variational/advi.jl @@ -44,7 +44,7 @@ function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) whe idx += varinfo.metadata[sym].ranges[end][end] end - bs = inv.(bijector.(tuple(dists...))) + bs = bijector.(tuple(dists...)) if sym2ranges return Stacked(bs, ranges), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index 814a3d245..dcfb38412 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -1,3 +1,9 @@ +using DynamicPPL, .Turing, Distributions, ForwardDiff, Tracker, Zygote + +dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] +include(dir*"/test/test_utils/AllUtils.jl") +include(dir*"/test/test_util.jl") + @testset "ad.jl" begin @testset "logp" begin # Hand-written log probabilities for vector `x = [s, m]`.
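Two side notes on the hunks above. First, the `BernoulliLogit` shorthand in `distributions.jl` is sound because with `n = 1` the binomial log mass reduces to `k * logitp - log1pexp(logitp)`, i.e. a Bernoulli with success probability `logistic(logitp)`. A quick numerical check (assuming the definitions above are in scope):

```julia
using StatsFuns: logistic

d = BernoulliLogit(0.3)               # Bernoulli with p = logistic(0.3)
logpdf(d, 1) ≈ log(logistic(0.3))     # true
logpdf(d, 0) ≈ log(1 - logistic(0.3)) # true
```

Second, the `ReverseDiff` branches in `VariationalInference.jl` record a tape once and reuse it (optionally compiled, and memoized when Memoization.jl is loaded). `Turing.Core.tape`/`memoized_tape` are Turing internals; the underlying ReverseDiff pattern is roughly:

```julia
using ReverseDiff, DiffResults

f(θ) = sum(abs2, θ)                  # stand-in for the negated ELBO objective
θ   = randn(3)
tp  = ReverseDiff.GradientTape(f, θ) # record once
ctp = ReverseDiff.compile(tp)        # compile for cheap re-execution
out = DiffResults.GradientResult(θ)
ReverseDiff.gradient!(out, ctp, θ)   # fills value and gradient in-place
DiffResults.value(out), DiffResults.gradient(out)
```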
diff --git a/test/compiler.jl b/test/compiler.jl index 204ddeda3..a49123226 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -565,12 +565,12 @@ end vi1 = VarInfo(f1()) vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) - @test haskey(vi1.metadata, :y) - @test vi1.metadata.y.vns[1] == VarName(:y) - @test haskey(vi2.metadata, :y) - @test vi2.metadata.y.vns[1] == VarName(:y, ((2,), (Colon(), 1))) - @test haskey(vi3.metadata, :y) - @test vi3.metadata.y.vns[1] == VarName(:y, ((1,),)) + @test haskey(vi1.tvi.metadata, :y) + @test vi1.tvi.metadata.y.vns[1] == VarName(:y) + @test haskey(vi2.tvi.metadata, :y) + @test vi2.tvi.metadata.y.vns[1] == VarName(:y, ((2,), (Colon(), 1))) + @test haskey(vi3.tvi.metadata, :y) + @test vi3.tvi.metadata.y.vns[1] == VarName(:y, ((1,),)) end @testset "custom tilde" begin @model demo() = begin diff --git a/test/test_util.jl b/test/test_util.jl index d8a926656..05127ee2c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -1,5 +1,7 @@ +using DynamicPPL: TypedVarInfo + function test_model_ad(model, logp_manual) - vi = VarInfo(model) + vi = TypedVarInfo(model) model(vi, SampleFromPrior()) x = DynamicPPL.getall(vi) diff --git a/test/varinfo.jl b/test/varinfo.jl index 5ec939807..0d97a1828 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -468,25 +468,25 @@ include(dir*"/test/test_utils/AllUtils.jl") g_demo_f(vi, SampleFromPrior()) step!(Random.GLOBAL_RNG, g_demo_f, pg, 1) vi1 = pg.state.vi - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == + @test mapreduce(x -> x.gids, vcat, vi1.tvi.metadata) == [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()] @inferred g_demo_f(vi1, hmc) - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == + @test mapreduce(x -> x.gids, vcat, vi1.tvi.metadata) == [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])] g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f) pg, hmc = g.state.samplers - vi = empty!(TypedVarInfo(vi)) + vi = empty!(MixedVarInfo(vi)) @inferred g_demo_f(vi, SampleFromPrior()) pg.state.vi = vi step!(Random.GLOBAL_RNG, g_demo_f, pg, 1) vi = pg.state.vi @inferred g_demo_f(vi, hmc) - @test vi.metadata.x.gids[1] == Set([pg.selector]) - @test vi.metadata.y.gids[1] == Set([pg.selector]) - @test vi.metadata.z.gids[1] == Set([pg.selector]) - @test vi.metadata.w.gids[1] == Set([hmc.selector]) - @test vi.metadata.u.gids[1] == Set([hmc.selector]) + @test vi.tvi.metadata.x.gids[1] == Set([pg.selector]) + @test vi.tvi.metadata.y.gids[1] == Set([pg.selector]) + @test vi.tvi.metadata.z.gids[1] == Set([pg.selector]) + @test vi.tvi.metadata.w.gids[1] == Set([hmc.selector]) + @test vi.tvi.metadata.u.gids[1] == Set([hmc.selector]) end end
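With `VarInfo(model)` now returning a `MixedVarInfo`, the typed metadata these tests inspect lives on the wrapper's `tvi` field, hence the `vi.metadata` to `vi.tvi.metadata` rewrites above. A sketch of the access pattern, assuming a trivial model:

```julia
using DynamicPPL, Distributions

@model demo() = begin
    y ~ Normal()
end

vi = VarInfo(demo())                 # a MixedVarInfo after this change
haskey(vi.tvi.metadata, :y)          # true
vi.tvi.metadata.y.vns[1] == VarName(:y)
```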