diff --git a/Project.toml b/Project.toml index c0179083b..b0c31a67f 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] @@ -22,6 +23,7 @@ julia = "1" [extras] AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] +test = ["AdvancedHMC", "AdvancedMH", "BangBang", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6f1ca2d7c..665420766 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -4,6 +4,7 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel using Distributions using Bijectors using MacroTools +using Requires import AbstractMCMC 
import ZygoteRules @@ -28,6 +29,7 @@ import Base: Symbol, export AbstractVarInfo, VarInfo, UntypedVarInfo, + MixedVarInfo, getlogp, setlogp!, acclogp!, @@ -116,8 +118,9 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") -include("varinfo.jl") +include("varinfo/varinfo.jl") include("threadsafe.jl") +include("mixedvarinfo.jl") include("context_implementations.jl") include("compiler.jl") include("prob_macro.jl") diff --git a/src/compiler.jl b/src/compiler.jl index 2e91afd52..ca21fe556 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -378,7 +378,7 @@ Convert the `value` to the correct type for the `sampler` and the `vi` object. function matchingvalue(sampler, vi, value) T = typeof(value) if hasmissing(T) - return get_matching_type(sampler, vi, T)(value) + return convert(get_matching_type(sampler, vi, T), value) else return value end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ebba1e088..3537cbcb5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -24,14 +24,14 @@ function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi) end function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi) if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds) settrans!(vi, false, vn) end return _tilde(sampler, right, vn, vi) end function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds) settrans!(vi, false, vn) end return _tilde(sampler, NoDist(right), vn, vi) @@ -127,11 +127,11 @@ function assume( if spl isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(dist, spl) - vi[vn] = vectorize(dist, r) + vi[vn, 
dist] = r settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else - r = vi[vn] + r = vi[vn, dist] end else r = init(dist, spl) @@ -297,12 +297,12 @@ function get_and_set_val!( r = init(dist, spl, n) for i in 1:n vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) + vi[vn, dist] = r[:, i] settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else - r = vi[vns] + r = vi[vns, dist] end else r = init(dist, spl, n) @@ -330,12 +330,12 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - vi[vn] = vectorize(dist, r[i]) + vi[vn, dist] = r[i] settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else - r = reshape(vi[vec(vns)], size(vns)) + r = vi[vns, dists] end else f = (vn, dist) -> init(dist, spl) @@ -354,7 +354,7 @@ function set_val!( ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) - vi[vn] = val[:,i] + vi[vn, dist] = val[:,i] end return val end @@ -367,7 +367,7 @@ function set_val!( @assert size(val) == size(vns) foreach(CartesianIndices(val)) do ind dist = dists isa AbstractArray ? 
dists[ind] : dists - vi[vns[ind]] = vectorize(dist, val[ind]) + vi[vns[ind], dist] = val[ind] end return val end
diff --git a/src/mixedvarinfo.jl b/src/mixedvarinfo.jl
new file mode 100644
index 000000000..2c3a6caed
--- /dev/null
+++ b/src/mixedvarinfo.jl
@@ -0,0 +1,296 @@
+"""
+    MixedVarInfo(tvi, uvi, is_uvi_empty)
+
+Pair a typed `VarInfo` (`tvi`) with an untyped one (`uvi`). Per-variable operations go to
+`tvi` when it already has the variable's symbol and to `uvi` otherwise; `is_uvi_empty`
+tracks whether `uvi` currently holds no variables.
+"""
+struct MixedVarInfo{
+    Ttvi <: Union{TypedVarInfo, Nothing},
+    Tuvi <: UntypedVarInfo,
+} <: AbstractVarInfo
+    tvi::Ttvi
+    uvi::Tuvi
+    is_uvi_empty::Base.RefValue{Bool}
+end
+MixedVarInfo(vi::TypedVarInfo) = MixedVarInfo(vi, VarInfo(), Ref(true))
+function MixedVarInfo(vi::UntypedVarInfo)
+    MixedVarInfo(TypedVarInfo(vi), empty!(deepcopy(vi)), Ref(true))
+end
+function VarInfo(model::Model, ctx = DefaultContext())
+    vi = VarInfo()
+    model(vi, SampleFromPrior(), ctx)
+    return MixedVarInfo(TypedVarInfo(vi))
+end
+function VarInfo(model::Model, n::Integer, ctx = DefaultContext())
+    if n == 0
+        vi = VarInfo()
+        # Run under the caller's context, like the two-argument method above.
+        model(vi, SampleFromPrior(), ctx)
+        return vi
+    else
+        tvi = TypedVarInfo(model, n, ctx)
+        return MixedVarInfo(tvi)
+    end
+end
+function VarInfo(old_vi::MixedVarInfo, spl, x::AbstractVector)
+    new_tvi = VarInfo(old_vi.tvi, spl, x)
+    return MixedVarInfo(new_tvi, old_vi.uvi, old_vi.is_uvi_empty)
+end
+function TypedVarInfo(vi::MixedVarInfo)
+    @assert getmode(vi.tvi) === getmode(vi.uvi)
+    mode = getmode(vi.tvi)
+    fixed_support = has_fixed_support(vi.tvi) && has_fixed_support(vi.uvi)
+    synced = issynced(vi.tvi) && issynced(vi.uvi)
+    if vi.is_uvi_empty[]
+        return vi.tvi
+    else
+        return VarInfo(
+            merge(vi.tvi.metadata, TypedVarInfo(vi.uvi).metadata),
+            Ref(getlogp(vi.tvi)),
+            Ref(get_num_produce(vi.tvi)),
+            mode,
+            Ref(fixed_support),
+            Ref(synced),
+        )
+    end
+end
+
+getinferred(vi::MixedVarInfo) = getinferred(vi.tvi)
+
+# Arguments are forwarded in caller order. NOTE(review): the typed `merge` is defined
+# elsewhere; presumably it is order-sensitive like `merge` on NamedTuples - confirm.
+function Base.merge(t1::MixedVarInfo, t2::MixedVarInfo)
+    return MixedVarInfo(merge(TypedVarInfo(t1), TypedVarInfo(t2)), VarInfo(), Ref(true))
+end
+function Base.merge(t1::TypedVarInfo, t2::MixedVarInfo)
+    return MixedVarInfo(merge(t1, TypedVarInfo(t2)), VarInfo(), Ref(true))
+end
+function Base.merge(t1::MixedVarInfo, t2::TypedVarInfo)
+    return MixedVarInfo(merge(TypedVarInfo(t1), t2), VarInfo(), Ref(true))
+end
+function Base.merge(t1::UntypedVarInfo, t2::MixedVarInfo)
+    return MixedVarInfo(merge(TypedVarInfo(t1), TypedVarInfo(t2)), VarInfo(), Ref(true))
+end
+function Base.merge(t1::MixedVarInfo, t2::UntypedVarInfo)
+    return MixedVarInfo(merge(TypedVarInfo(t1), TypedVarInfo(t2)), VarInfo(), Ref(true))
+end
+
+function getvns(vi::MixedVarInfo, s::Selector, ::Val{space}) where {space}
+    if space !== () && all(haskey.(Ref(vi.tvi.metadata), space))
+        return getvns(vi.tvi, s, Val(space))
+    else
+        return getvns(TypedVarInfo(vi), s, Val(space))
+    end
+end
+getmode(vi::MixedVarInfo) = getmode(vi.tvi)
+
+function getmetadata(vi::MixedVarInfo, vn::VarName)
+    if haskey(vi.tvi, vn)
+        return getmetadata(vi.tvi, vn)
+    else
+        return getmetadata(vi.uvi, vn)
+    end
+end
+function Base.show(io::IO, vi::MixedVarInfo)
+    print(io, "Instance of MixedVarInfo")
+end
+
+function fullyinspace(spl::AbstractSampler, vi::TypedVarInfo)
+    space = getspace(spl)
+    return space !== () && all(haskey.(Ref(vi.metadata), space))
+end
+
+acclogp!(vi::MixedVarInfo, logp) = acclogp!(vi.tvi, logp)
+getlogp(vi::MixedVarInfo) = getlogp(vi.tvi)
+resetlogp!(vi::MixedVarInfo) = resetlogp!(vi.tvi)
+setlogp!(vi::MixedVarInfo, logp) = setlogp!(vi.tvi, logp)
+
+get_num_produce(vi::MixedVarInfo) = get_num_produce(vi.tvi)
+increment_num_produce!(vi::MixedVarInfo) = increment_num_produce!(vi.tvi)
+reset_num_produce!(vi::MixedVarInfo) = reset_num_produce!(vi.tvi)
+set_num_produce!(vi::MixedVarInfo, n::Int) = set_num_produce!(vi.tvi, n)
+
+syms(vi::MixedVarInfo) = (syms(vi.tvi)..., syms(vi.uvi)...)
+
+function setgid!(vi::MixedVarInfo, gid::Selector, vn::VarName; overwrite=false)
+    hassymbol(vi.tvi, vn) ? setgid!(vi.tvi, gid, vn; overwrite=overwrite) : setgid!(vi.uvi, gid, vn; overwrite=overwrite)
+    return vi
+end
+function setorder!(vi::MixedVarInfo, vn::VarName, index::Int)
+    hassymbol(vi.tvi, vn) ? setorder!(vi.tvi, vn, index) : setorder!(vi.uvi, vn, index)
+    return vi
+end
+function setval!(vi::MixedVarInfo, val, vn::VarName)
+    hassymbol(vi.tvi, vn) ? setval!(vi.tvi, val, vn) : setval!(vi.uvi, val, vn)
+    return vi
+end
+
+function haskey(vi::MixedVarInfo, vn::VarName)
+    return hassymbol(vi.tvi, vn) ? haskey(vi.tvi, vn) : haskey(vi.uvi, vn)
+end
+
+Bijectors.link(vi::MixedVarInfo) = MixedVarInfo(link(vi.tvi), link(vi.uvi), vi.is_uvi_empty)
+Bijectors.invlink(vi::MixedVarInfo) = MixedVarInfo(invlink(vi.tvi), invlink(vi.uvi), vi.is_uvi_empty)
+initlink(vi::MixedVarInfo) = MixedVarInfo(initlink(vi.tvi), initlink(vi.uvi), vi.is_uvi_empty)
+has_fixed_support(vi::MixedVarInfo) = has_fixed_support(vi.tvi) && has_fixed_support(vi.uvi)
+function set_fixed_support!(vi::MixedVarInfo, b::Bool)
+    # Keep both halves in step: `has_fixed_support` requires the flag on `tvi` and `uvi`.
+    set_fixed_support!(vi.tvi, b)
+    set_fixed_support!(vi.uvi, b)
+    return vi
+end
+
+issynced(vi::MixedVarInfo) = issynced(vi.tvi) && issynced(vi.uvi)
+function setsynced!(vi::MixedVarInfo, b::Bool)
+    setsynced!(vi.tvi, b)
+    setsynced!(vi.uvi, b)
+    return vi
+end
+
+function removedel!(vi::MixedVarInfo)
+    if vi.is_uvi_empty[]
+        return MixedVarInfo(removedel!(vi.tvi), vi.uvi, vi.is_uvi_empty)
+    else
+        removedel!(vi.uvi)
+        if isempty(vi.uvi)
+            return MixedVarInfo(removedel!(vi.tvi), vi.uvi, Ref(true))
+        else
+            return MixedVarInfo(removedel!(vi.tvi), vi.uvi, Ref(false))
+        end
+    end
+end
+
+function link!(vi::MixedVarInfo, spl::AbstractSampler, model)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        link!(vi.tvi, spl, model)
+    else
+        link!(vi.tvi, spl, model)
+        link!(vi.uvi, spl, model)
+    end
+    return vi
+end
+function invlink!(vi::MixedVarInfo, spl::AbstractSampler, model)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        invlink!(vi.tvi, spl, model)
+    else
+        invlink!(vi.tvi, spl, model)
+        invlink!(vi.uvi, spl, model)
+    end
+    return vi
+end
+function islinked(vi::MixedVarInfo, spl::AbstractSampler)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        return islinked(vi.tvi, spl)
+    else
+        return islinked(vi.tvi, spl) || islinked(vi.uvi, spl)
+    end
+end
+
+function getindex(vi::MixedVarInfo, vn::VarName)
+    return hassymbol(vi.tvi, vn) ? getindex(vi.tvi, vn) : getindex(vi.uvi, vn)
+end
+# All the VarNames have the same symbol
+function getindex(vi::MixedVarInfo, vns::Vector{<:VarName{s}}) where {s}
+    return hassymbol(vi.tvi, vns[1]) ? getindex(vi.tvi, vns) : getindex(vi.uvi, vns)
+end
+
+for splT in (:SampleFromPrior, :SampleFromUniform, :AbstractSampler)
+    @eval begin
+        function getindex(vi::MixedVarInfo, spl::$splT)
+            if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+                return vi.tvi[spl]
+            else
+                return vcat(vi.tvi[spl], copy.(vi.uvi[spl]))
+            end
+        end
+
+        function setindex!(vi::MixedVarInfo, val, spl::$splT)
+            if fullyinspace(spl, vi.tvi)
+                setindex!(vi.tvi, val, spl)
+            else
+                # TODO: define length(vi::TypedVarInfo, spl)
+                n = length(vi.tvi[spl])
+                setindex!(vi.tvi, val[1:n], spl)
+                if n < length(val)
+                    setindex!(vi.uvi, val[n+1:end], spl)
+                end
+            end
+            return vi
+        end
+    end
+end
+
+function getall(vi::MixedVarInfo)
+    if vi.is_uvi_empty[]
+        return getall(vi.tvi)
+    else
+        return vcat(getall(vi.tvi), copy.(getall(vi.uvi)))
+    end
+end
+
+function set_retained_vns_del_by_spl!(vi::MixedVarInfo, spl::Sampler)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        set_retained_vns_del_by_spl!(vi.tvi, spl)
+    else
+        set_retained_vns_del_by_spl!(vi.tvi, spl)
+        set_retained_vns_del_by_spl!(vi.uvi, spl)
+    end
+    return vi
+end
+
+isempty(vi::MixedVarInfo) = isempty(vi.tvi) && vi.is_uvi_empty[]
+function empty!(vi::MixedVarInfo)
+    empty!(vi.tvi)
+    vi.is_uvi_empty[] || empty!(vi.uvi)
+    vi.is_uvi_empty[] = true
+    return vi
+end
+
+function push!(
+    vi::MixedVarInfo,
+    vn::VarName,
+    r,
+    dist::Distribution,
+    gidset::Set{Selector}
+)
+    if hassymbol(vi.tvi, vn)
+        push!(vi.tvi, vn, r, dist, gidset)
+    else
+        push!(vi.uvi, vn, r, dist, gidset)
+        vi.is_uvi_empty[] = false
+    end
+    return vi
+end
+
+function unset_flag!(vi::MixedVarInfo, vn::VarName, flag::String)
+    hassymbol(vi.tvi, vn) ? unset_flag!(vi.tvi, vn, flag) : unset_flag!(vi.uvi, vn, flag)
+    return vi
+end
+function is_flagged(vi::MixedVarInfo, vn::VarName, flag::String)
+    if hassymbol(vi.tvi, vn)
+        return is_flagged(vi.tvi, vn, flag)
+    else
+        return is_flagged(vi.uvi, vn, flag)
+    end
+end
+
+function updategid!(vi::MixedVarInfo, spl::AbstractSampler; overwrite=false)
+    if fullyinspace(spl, vi.tvi) || vi.is_uvi_empty[]
+        updategid!(vi.tvi, spl; overwrite=overwrite)
+    else
+        # `tvi` does not cover the sampler's whole space here; update both, as `link!` does.
+        updategid!(vi.tvi, spl; overwrite=overwrite)
+        updategid!(vi.uvi, spl; overwrite=overwrite)
+    end
+    return vi
+end
+
+function tonamedtuple(vi::MixedVarInfo)
+    if vi.is_uvi_empty[]
+        return tonamedtuple(vi.tvi)
+    else
+        return tonamedtuple(TypedVarInfo(vi))
+    end
+end
+set_namedtuple!(vi::MixedVarInfo, nt::NamedTuple) = set_namedtuple!(vi.tvi, nt)
diff --git a/src/prob_macro.jl b/src/prob_macro.jl index b047f711f..228477057 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -43,11 +43,11 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names end if isdefined(ntr.chain.info, :vi) _vi = ntr.chain.info.vi - @assert _vi isa VarInfo + @assert _vi isa AbstractVarInfo vi = TypedVarInfo(_vi) elseif isdefined(ntr, :varinfo) _vi = ntr.varinfo - @assert _vi isa VarInfo + @assert _vi isa AbstractVarInfo vi = TypedVarInfo(_vi) else vi = nothing @@ -62,7 +62,7 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names modelgen = ntr.model if isdefined(ntr, :varinfo) _vi = ntr.varinfo - @assert _vi isa VarInfo + @assert _vi isa AbstractVarInfo vi = TypedVarInfo(_vi) else vi = nothing @@ -115,6 +115,8 @@ end missing_arg_error_msg(arg, ::Missing) = """Variable $arg has a value of `missing`, or is not defined and its default value is `missing`. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has no default value. 
Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" +warn_msg(arg::Symbol) = "Argument $arg is not defined. A value of `nothing` is used." + function logprior( left::NamedTuple, right::NamedTuple, @@ -134,7 +136,7 @@ function logprior( # When all of model args are on the lhs of |, this is also equal to the logjoint. model = make_prior_model(left, right, modelgen) - vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi + vi = _vi === nothing ? TypedVarInfo(deepcopy(model), PriorContext()) : _vi foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end @@ -182,7 +184,7 @@ function Distributions.loglikelihood( _vi::Union{Nothing, VarInfo}, ) model = make_likelihood_model(left, right, modelgen) - vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi + vi = _vi === nothing ? TypedVarInfo(deepcopy(model)) : TypedVarInfo(_vi) if isdefined(right, :chain) # Element-wise likelihood for each value in chain chain = right.chain diff --git a/src/sampler.jl b/src/sampler.jl index e40156707..451c77d98 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,8 +1,21 @@ """ Robust initialization method for model parameters in Hamiltonian samplers. 
""" -struct SampleFromUniform <: AbstractSampler end -struct SampleFromPrior <: AbstractSampler end +struct SampleFromUniform{Tvi <: AbstractVarInfo} <: AbstractSampler + vi::Tvi +end +function SampleFromUniform(model::AbstractModel; specialize_after=1) + return SampleFromUniform(VarInfo(model; specialize_after=specialize_after)) +end +SampleFromUniform() = SampleFromUniform(VarInfo()) + +struct SampleFromPrior{Tvi <: AbstractVarInfo} <: AbstractSampler + vi::Tvi +end +function SampleFromPrior(model::AbstractModel; specialize_after=1) + return SampleFromPrior(VarInfo(model; specialize_after=specialize_after)) +end +SampleFromPrior() = SampleFromPrior(VarInfo()) getspace(::Union{SampleFromPrior, SampleFromUniform}) = () @@ -50,8 +63,7 @@ mutable struct Sampler{T, S<:AbstractSamplerState} <: AbstractSampler state :: S end Sampler(alg) = Sampler(alg, Selector()) -Sampler(alg, model::Model) = Sampler(alg, model, Selector()) -Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s) +Sampler(alg, model::Model; specialize_after=1) = Sampler(alg, model, Selector(); specialize_after=specialize_after) # AbstractMCMC interface for SampleFromUniform and SampleFromPrior @@ -63,7 +75,9 @@ function AbstractMCMC.step!( transition; kwargs... 
) - vi = VarInfo() - model(vi, sampler) - return vi + empty!(sampler.vi) + model(sampler.vi, sampler) + return sampler.vi end +getinferred(spl::Sampler) = getinferred(spl.state.vi) +Base.empty!(spl::Sampler) = empty!(spl.state.vi) \ No newline at end of file diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 996934c53..08a24f0b9 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -45,9 +45,10 @@ reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n) syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) +getinferred(vi::ThreadSafeVarInfo) = getinferred(vi.varinfo) -function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - setgid!(vi.varinfo, gid, vn) +function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName; overwrite=false) + setgid!(vi.varinfo, gid, vn; overwrite=overwrite) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -55,8 +56,8 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) +link!(vi::ThreadSafeVarInfo, spl::AbstractSampler, model) = link!(vi.varinfo, spl, model) +invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler, model) = invlink!(vi.varinfo, spl, model) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) diff --git a/src/varinfo.jl b/src/varinfo.jl deleted file mode 100644 index 077ef1a95..000000000 --- a/src/varinfo.jl +++ /dev/null @@ -1,1141 +0,0 @@ -# Constants for caching -const 
CACHERESET = 0b00 -const CACHEIDCS = 0b10 -const CACHERANGES = 0b01 - -#### -#### Types for typed and untyped VarInfo -#### - - - -#################### -# VarInfo metadata # -#################### - -""" -The `Metadata` struct stores some metadata about the parameters of the model. This helps -query certain information about a variable, such as its distribution, which samplers -sample this variable, its value and whether this value is transformed to real space or -not. - -Let `md` be an instance of `Metadata`: -- `md.vns` is the vector of all `VarName` instances. -- `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`. -- `md.vns[md.idcs[vn]] == vn`. -- `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.gids[md.idcs[vn]]` is the set of algorithms used to sample `vn`. This is used in - the Gibbs sampling process. -- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. -- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. -- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the - value of `flag` corresponding to `vn`. - -To make `md::Metadata` type stable, all the `md.vns` must have the same symbol -and distribution type. However, one can have a Julia variable, say `x`, that is a -matrix or a hierarchical array sampled in partitions, e.g. -`x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)`, and is managed by -a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the -same type. Type unstable `Metadata` will still work but will have inferior performance. 
-When sampling, the first iteration uses a type unstable `Metadata` for all the -variables then a specialized `Metadata` is used for each symbol along with a function -barrier to make the rest of the sampling type stable. -""" -struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}} - # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` - idcs :: TIdcs # Dict{<:VarName,Int} - - # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn` - vns :: TVN # AbstractVector{<:VarName} - - # Vector of index ranges in `vals` corresponding to `vns` - # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals` - ranges :: Vector{UnitRange{Int}} - - # Vector of values of all the univariate, multivariate and matrix variables - # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals :: TVal # AbstractVector{<:Real} - - # Vector of distributions correpsonding to `vns` - dists :: TDists # AbstractVector{<:Distribution} - - # Vector of sampler ids corresponding to `vns` - # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set` - gids :: TGIds # AbstractVector{Set{Selector}} - - # Number of `observe` statements before each random variable is sampled - orders :: Vector{Int} - - # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` - flags :: Dict{String, BitVector} -end - -########### -# VarInfo # -########### - -""" -``` -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo - metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} -end -``` - -A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of -`VarInfo`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used -for all the sybmols. 
`VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If -`vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each -symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows -for the type specialization of `vi` after the first sampling iteration when all the -symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. - -Note: It is the user's responsibility to ensure that each "symbol" is visited at least -once whenever the model is called, regardless of any stochastic branching. Each symbol -refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. -""" -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo - metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} -end -const UntypedVarInfo = VarInfo{<:Metadata} -const TypedVarInfo = VarInfo{<:NamedTuple} - -function VarInfo(model::Model, ctx = DefaultContext()) - vi = VarInfo() - model(vi, SampleFromPrior(), ctx) - return TypedVarInfo(vi) -end - -function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) - new_vi = deepcopy(old_vi) - new_vi[spl] = x - return new_vi -end -function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) - md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))) -end -@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space} - exprs = [] - offset = :(0) - for f in names - mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(length($mdf.vals)) - push!(exprs, :($f = Metadata($mdf.idcs, - $mdf.vns, - $mdf.ranges, - x[($offset + 1):($offset + $len)], - $mdf.dists, - $mdf.gids, - $mdf.orders, - $mdf.flags - ) - ) - ) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end - end - length(exprs) == 0 && return :(NamedTuple()) - return 
:($(exprs...),) -end - -#### -#### Internal functions -#### - -""" - Metadata() - -Construct an empty type unstable instance of `Metadata`. -""" -function Metadata() - vals = Vector{Real}() - flags = Dict{String, BitVector}() - flags["del"] = BitVector() - flags["trans"] = BitVector() - - return Metadata( - Dict{VarName, Int}(), - Vector{VarName}(), - Vector{UnitRange{Int}}(), - vals, - Vector{Distribution}(), - Vector{Set{Selector}}(), - Vector{Int}(), - flags - ) -end - -""" - empty!(meta::Metadata) - -Empty the fields of `meta`. - -This is useful when using a sampling algorithm that assumes an empty `meta`, e.g. `SMC`. -""" -function empty!(meta::Metadata) - empty!(meta.idcs) - empty!(meta.vns) - empty!(meta.ranges) - empty!(meta.vals) - empty!(meta.dists) - empty!(meta.gids) - empty!(meta.orders) - for k in keys(meta.flags) - empty!(meta.flags[k]) - end - - return meta -end - -# Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so this is well-defined. -if VERSION < v"1.1" - _tail(nt::NamedTuple{names}) where names = NamedTuple{Base.tail(names)}(nt) -else - _tail(nt::NamedTuple) = Base.tail(nt) -end - -const VarView = Union{Int, UnitRange, Vector{Int}} - -""" - getval(vi::UntypedVarInfo, vview::Union{Int, UnitRange, Vector{Int}}) - -Return a view `vi.vals[vview]`. -""" -getval(vi::UntypedVarInfo, vview::VarView) = view(vi.metadata.vals, vview) - -""" - setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) - -Set the value of `vi.vals[vview]` to `val`. -""" -setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val -function setval!(vi::UntypedVarInfo, val, vview::Vector{UnitRange}) - if length(vview) > 0 - vi.metadata.vals[[i for arr in vview for i in arr]] = val - end - return val -end - -""" - getmetadata(vi::VarInfo, vn::VarName) - -Return the metadata in `vi` that belongs to `vn`. 
-""" -getmetadata(vi::VarInfo, vn::VarName) = vi.metadata -getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) - -""" - getidx(vi::VarInfo, vn::VarName) - -Return the index of `vn` in the metadata of `vi` corresponding to `vn`. -""" -getidx(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).idcs[vn] - -""" - getrange(vi::VarInfo, vn::VarName) - -Return the index range of `vn` in the metadata of `vi`. -""" -getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] - -""" - getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) - -Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. -""" -function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getrange(vi, vn), vcat, vns, init=Int[]) -end - -""" - getdist(vi::VarInfo, vn::VarName) - -Return the distribution from which `vn` was sampled in `vi`. -""" -getdist(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] - -""" - getval(vi::VarInfo, vn::VarName) - -Return the value(s) of `vn`. - -The values may or may not be transformed to Euclidean space. -""" -getval(vi::VarInfo, vn::VarName) = view(getmetadata(vi, vn).vals, getrange(vi, vn)) - -""" - setval!(vi::VarInfo, val, vn::VarName) - -Set the value(s) of `vn` in the metadata of `vi` to `val`. - -The values may or may not be transformed to Euclidean space. -""" -setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val - -""" - getval(vi::VarInfo, vns::Vector{<:VarName}) - -Return the value(s) of `vns`. - -The values may or may not be transformed to Euclidean space. -""" -function getval(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getval(vi, vn), vcat, vns) -end - -""" - getall(vi::VarInfo) - -Return the values of all the variables in `vi`. - -The values may or may not be transformed to Euclidean space. 
-""" -getall(vi::UntypedVarInfo) = vi.metadata.vals -getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) -@generated function _getall(metadata::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :(metadata.$f.vals)) - end - return :($(exprs...),) -end - -""" - setall!(vi::VarInfo, val) - -Set the values of all the variables in `vi` to `val`. - -The values may or may not be transformed to Euclidean space. -""" -setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val -setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) -@generated function _setall!(metadata::NamedTuple{names}, val, start = 0) where {names} - expr = Expr(:block) - start = :(1) - for f in names - length = :(length(metadata.$f.vals)) - finish = :($start + $length - 1) - push!(expr.args, :(metadata.$f.vals .= val[$start:$finish])) - start = :($start + $length) - end - return expr -end - -""" - getgid(vi::VarInfo, vn::VarName) - -Return the set of sampler selectors associated with `vn` in `vi`. -""" -getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] - -""" - settrans!(vi::VarInfo, trans::Bool, vn::VarName) - -Set the `trans` flag value of `vn` in `vi`. -""" -function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") -end - -""" - syms(vi::VarInfo) - -Returns a tuple of the unique symbols of random variables sampled in `vi`. 
-""" -syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols -syms(vi::TypedVarInfo) = keys(vi.metadata) - -# Get all indices of variables belonging to SampleFromPrior: -# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to -# the SampleFromPrior sampler -@inline function _getidcs(vi::UntypedVarInfo, ::SampleFromPrior) - return filter(i -> isempty(vi.metadata.gids[i]) , 1:length(vi.metadata.gids)) -end -# Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol -@inline function _getidcs(vi::TypedVarInfo, ::SampleFromPrior) - return _getidcs(vi.metadata) -end -@generated function _getidcs(metadata::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = findinds(metadata.$f))) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -# Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::AbstractVarInfo, spl::Sampler) - # NOTE: 0b00 is the sanity flag for - # |\____ getidcs (mask = 0b10) - # \_____ getranges (mask = 0b01) - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - # Checks if cache is valid, i.e. 
no new pushes were made, to return the cached idcs - # Otherwise, it recomputes the idcs and caches it - #if haskey(spl.info, :idcs) && (spl.info[:cache_updated] & CACHEIDCS) > 0 - # spl.info[:idcs] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) - #spl.info[:idcs] = idcs - #end - return idcs -end -@inline _getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) -@inline _getidcs(vi::TypedVarInfo, s::Selector, space) = _getidcs(vi.metadata, s, space) -# Get a NamedTuple for all the indices belonging to a given selector for each symbol -@generated function _getidcs(metadata::NamedTuple{names}, s::Selector, ::Val{space}) where {names, space} - exprs = [] - # Iterate through each varname in metadata. - for f in names - # If the varname is in the sampler space - # or the sample space is empty (all variables) - # then return the indices for that variable. - if inspace(f, space) || length(space) == 0 - push!(exprs, :($f = findinds(metadata.$f, s, Val($space)))) - end - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end -@inline function findinds(f_meta, s, ::Val{space}) where {space} - # Get all the idcs of the vns in `space` and that belong to the selector `s` - return filter((i) -> - (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && - (isempty(space) || inspace(f_meta.vns[i], space)), 1:length(f_meta.gids)) -end -@inline function findinds(f_meta) - # Get all the idcs of the vns - return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) -end - -# Get all vns of variables belonging to spl -_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -_getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) = _getvns(vi, Selector(), Val(())) -_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space)) -function _getvns(vi::TypedVarInfo, 
s::Selector, space) - return _getvns(vi.metadata, _getidcs(vi, s, space)) -end -# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol -@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = metadata.$f.vns[idcs.$f])) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -# Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::AbstractVarInfo, spl::Sampler) - ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 - # spl.info[:ranges] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, Val(getspace(spl))) - #spl.info[:ranges] = ranges - return ranges - #end -end -# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::AbstractVarInfo, s::Selector, space) - return _getranges(vi, _getidcs(vi, s, space)) -end -@inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) - mapreduce(i -> vi.metadata.ranges[i], vcat, idcs, init=Int[]) -end -@inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) - -@generated function _getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = findranges(metadata.$f.ranges, idcs.$f))) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end -@inline function findranges(f_ranges, f_idcs) - return mapreduce(i -> f_ranges[i], vcat, f_idcs, init=Int[]) -end - -""" - set_flag!(vi::VarInfo, vn::VarName, flag::String) - -Set `vn`'s value for `flag` to `true` in `vi`. 
-""" -function set_flag!(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true -end - -#### -#### APIs for typed and untyped VarInfo -#### - - -# VarInfo - -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) - -""" - TypedVarInfo(vi::UntypedVarInfo) - -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. -""" -function TypedVarInfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New gids, can make a resizeable FillArray - sym_gids = getindex.((meta.gids,), inds) - @assert length(sym_gids) <= 1 || - all(x -> x == sym_gids[1], @view sym_gids[2:end]) - # New orders - sym_orders = getindex.((meta.orders,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = start + 1 : start + length(_vals[i]) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) - - push!(new_metas, Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, - 
sym_dists, sym_gids, sym_orders, sym_flags)) - end - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end -TypedVarInfo(vi::TypedVarInfo) = vi - -""" - empty!(vi::VarInfo) - -Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. - -This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. -""" -function empty!(vi::VarInfo) - _empty!(vi.metadata) - resetlogp!(vi) - reset_num_produce!(vi) - return vi -end -@inline _empty!(metadata::Metadata) = empty!(metadata) -@generated function _empty!(metadata::NamedTuple{names}) where {names} - expr = Expr(:block) - for f in names - push!(expr.args, :(empty!(metadata.$f))) - end - return expr -end - -# Functions defined only for UntypedVarInfo -""" - keys(vi::UntypedVarInfo) - -Return an iterator over all `vns` in `vi`. -""" -keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) - -""" - setgid!(vi::VarInfo, gid::Selector, vn::VarName) - -Add `gid` to the set of sampler selectors associated with `vn` in `vi`. -""" -setgid!(vi::VarInfo, gid::Selector, vn::VarName) = push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) - -""" - istrans(vi::VarInfo, vn::VarName) - -Return true if `vn`'s values in `vi` are transformed to Euclidean space, and false if -they are in the support of `vn`'s distribution. -""" -istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") - -""" - getlogp(vi::VarInfo) - -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. -""" -getlogp(vi::AbstractVarInfo) = vi.logp[] - -""" - setlogp!(vi::VarInfo, logp) - -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`. 
-""" -function setlogp!(vi::VarInfo, logp) - vi.logp[] = logp - return vi -end - -""" - acclogp!(vi::VarInfo, logp) - -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`. -""" -function acclogp!(vi::VarInfo, logp) - vi.logp[] += logp - return vi -end - -""" - resetlogp!(vi::AbstractVarInfo) - -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0. -""" -resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi))) - -""" - get_num_produce(vi::VarInfo) - -Return the `num_produce` of `vi`. -""" -get_num_produce(vi::VarInfo) = vi.num_produce[] - -""" - set_num_produce!(vi::VarInfo, n::Int) - -Set the `num_produce` field of `vi` to `n`. -""" -set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n - -""" - increment_num_produce!(vi::VarInfo) - -Add 1 to `num_produce` in `vi`. -""" -increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 - -""" - reset_num_produce!(vi::AbstractVarInfo) - -Reset the value of `num_produce` the log of the joint probability of the observed data -and parameters sampled in `vi` to 0. -""" -reset_num_produce!(vi::AbstractVarInfo) = set_num_produce!(vi, 0) - -""" - isempty(vi::VarInfo) - -Return true if `vi` is empty and false otherwise. -""" -isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) -isempty(vi::TypedVarInfo) = _isempty(vi.metadata) -@generated function _isempty(metadata::NamedTuple{names}) where {names} - expr = Expr(:&&, :true) - for f in names - push!(expr.args, :(isempty(metadata.$f.idcs))) - end - return expr -end - -# X -> R for all variables associated with given sampler -""" - link!(vi::VarInfo, spl::Sampler) - -Transform the values of the random variables sampled by `spl` in `vi` from the support -of their distributions to the Euclidean space and set their corresponding `"trans"` -flag values to `true`. 
-""" -function link!(vi::UntypedVarInfo, spl::Sampler) - # TODO: Change to a lazy iterator over `vns` - vns = _getvns(vi, spl) - if ~istrans(vi, vns[1]) - for vn in vns - dist = getdist(vi, vn) - # TODO: Use inplace versions to avoid allocations - setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") - end -end -function link!(vi::TypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, Val(getspace(spl))) -end -@generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} - expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!(expr.args, quote - f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") - end - end) - end - end - return expr -end - -# R -> X for all variables associated with given sampler -""" - invlink!(vi::VarInfo, spl::AbstractSampler) - -Transform the values of the random variables sampled by `spl` in `vi` from the -Euclidean space back to the support of their distributions and sets their corresponding -`"trans"` flag values to `false`. 
-""" -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) - if istrans(vi, vns[1]) - for vn in vns - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end -end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, Val(getspace(spl))) -end -@generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} - expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!(expr.args, quote - f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end) - end - end - return expr -end - - -""" - islinked(vi::VarInfo, spl::Sampler) - -Check whether `vi` is in the transformed space for a particular sampler `spl`. - -Turing's Hamiltonian samplers use the `link` and `invlink` functions from -[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of -real numbers. `islinked` checks if the number is in the constrained space or the real space. -""" -function islinked(vi::UntypedVarInfo, spl::Sampler) - vns = _getvns(vi, spl) - return istrans(vi, vns[1]) -end -function islinked(vi::TypedVarInfo, spl::Sampler) - vns = _getvns(vi, spl) - return _islinked(vi, vns) -end -@generated function _islinked(vi, vns::NamedTuple{names}) where {names} - out = [] - for f in names - push!(out, :(length(vns.$f) == 0 ? 
false : istrans(vi, vns.$f[1]))) - end - return Expr(:||, false, out...) -end - -# The default getindex & setindex!() for get & set values -# NOTE: vi[vn] will always transform the variable to its original space and Julia type -""" - getindex(vi::VarInfo, vn::VarName) - getindex(vi::VarInfo, vns::Vector{<:VarName}) - -Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) -distribution(s). - -If the value(s) is (are) transformed to the Euclidean space, it is -(they are) transformed back. -""" -function getindex(vi::AbstractVarInfo, vn::VarName) - @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vn) - return istrans(vi, vn) ? - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) : - reconstruct(dist, getval(vi, vn)) -end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vns[1]) - return istrans(vi, vns[1]) ? - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) : - reconstruct(dist, getval(vi, vns), length(vns)) -end - -""" - getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) - -Return the current value(s) of the random variables sampled by `spl` in `vi`. - -The value(s) may or may not be transformed to Euclidean space. -""" -getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi)) -getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi)) -getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) -function getindex(vi::TypedVarInfo, spl::Sampler) - # Gets the ranges as a NamedTuple - ranges = _getranges(vi, spl) - # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` - return vcat(_getindex(vi.metadata, ranges)...) 
-end -# Recursively builds a tuple of the `vals` of all the symbols -@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) - end - return expr -end - -""" - setindex!(vi::VarInfo, val, vn::VarName) - -Set the current value(s) of the random variable `vn` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. -""" -setindex!(vi::AbstractVarInfo, val, vn::VarName) = setval!(vi, val, vn) - -""" - setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) - -Set the current value(s) of the random variables sampled by `spl` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. -""" -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) -setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) -function setindex!(vi::TypedVarInfo, val, spl::Sampler) - # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` - ranges = _getranges(vi, spl) - _setindex!(vi.metadata, val, ranges) - return val -end -# Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. -@generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} - expr = Expr(:block) - offset = :(0) - for f in names - f_vals = :(metadata.$f.vals) - f_range = :(ranges.$f) - start = :($offset + 1) - len = :(length($f_range)) - finish = :($offset + $len) - push!(expr.args, :(@views $f_vals[$f_range] .= val[$start:$finish])) - offset = :($offset + $len) - end - return expr -end - -""" - tonamedtuple(vi::VarInfo) - -Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and -indexing string of the variable. 
- -For example, a model that had a vector of vector-valued -variables `x` would return - -```julia -(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) -``` -""" -function tonamedtuple(vi::VarInfo) - return tonamedtuple(vi.metadata, vi) -end -@generated function tonamedtuple(metadata::NamedTuple{names}, vi::VarInfo) where {names} - length(names) === 0 && return :(NamedTuple()) - expr = Expr(:tuple) - map(names) do f - push!(expr.args, Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns)))) - end - return expr -end - -@inline function findvns(vi, f_vns) - if length(f_vns) == 0 - throw("Unidentified error, please report this error in an issue.") - end - return map(vn -> vi[vn], f_vns) -end - -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler, SampleFromPrior}) - return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi), typeof(spl)})) -end - -""" - haskey(vi::VarInfo, vn::VarName) - -Check whether `vn` has been sampled in `vi`. -""" -haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn).idcs, vn) -function haskey(vi::TypedVarInfo, vn::VarName) - metadata = vi.metadata - Tmeta = typeof(metadata) - return getsym(vn) in fieldnames(Tmeta) && haskey(getmetadata(vi, vn).idcs, vn) -end - -function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - vi_str = """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - | Varnames : $(string(vi.metadata.vns)) - | Range : $(vi.metadata.ranges) - | Vals : $(vi.metadata.vals) - | GIDs : $(vi.metadata.gids) - | Orders : $(vi.metadata.orders) - | Logp : $(getlogp(vi)) - | #produce : $(get_num_produce(vi)) - | flags : $(vi.metadata.flags) - \\======================================================================= - """ - print(io, vi_str) -end - - -const _MAX_VARS_SHOWN = 4 - -function _show_varnames(io::IO, vi) - md = vi.metadata - vns = md.vns - - groups = 
Dict{Symbol, Vector{VarName}}() - for vn in vns - group = get!(() -> Vector{VarName}(), groups, getsym(vn)) - push!(group, vn) - end - - print(io, length(groups), length(groups) == 1 ? " variable " : " variables ", "(") - join(io, Iterators.take(keys(groups), _MAX_VARS_SHOWN), ", ") - length(groups) > _MAX_VARS_SHOWN && print(io, ", ...") - print(io, "), dimension ", sum(prod(size(md.vals[md.ranges[md.idcs[vn]]])) for vn in vns)) -end - -function Base.show(io::IO, vi::UntypedVarInfo) - print(io, "VarInfo (") - _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi), digits=3)) - print(io, ")") -end - - -""" - push!(vi::VarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`. -""" -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return push!(vi, vn, r, dist, Set{Selector}([])) -end - -""" - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`. - -The sampler is passed here to invalidate its cache where defined. -""" -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) - spl.info[:cache_updated] = CACHERESET - return push!(vi, vn, r, dist, spl.selector) -end -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) - return push!(vi, vn, r, dist) -end - -""" - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler of -selector `gid` from a distribution `dist` to `VarInfo` `vi`. 
-""" -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - return push!(vi, vn, r, dist, Set([gid])) -end -function push!( - vi::VarInfo, - vn::VarName, - r, - dist::Distribution, - gidset::Set{Selector} - ) - - if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" - elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" - end - - val = vectorize(dist, r) - - meta = getmetadata(vi, vn) - meta.idcs[vn] = length(meta.idcs) + 1 - push!(meta.vns, vn) - l = length(meta.vals); n = length(val) - push!(meta.ranges, l+1:l+n) - append!(meta.vals, val) - push!(meta.dists, dist) - push!(meta.gids, gidset) - push!(meta.orders, get_num_produce(vi)) - push!(meta.flags["del"], false) - push!(meta.flags["trans"], false) - - return vi -end - -""" - setorder!(vi::VarInfo, vn::VarName, index::Int) - -Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe -statements run before sampling `vn`. -""" -function setorder!(vi::VarInfo, vn::VarName, index::Int) - metadata = getmetadata(vi, vn) - if metadata.orders[metadata.idcs[vn]] != index - metadata.orders[metadata.idcs[vn]] = index - end - return vi -end - -####################################### -# Rand & replaying method for VarInfo # -####################################### - -""" - is_flagged(vi::VarInfo, vn::VarName, flag::String) - -Check whether `vn` has a true value for `flag` in `vi`. -""" -function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] -end - -""" - unset_flag!(vi::VarInfo, vn::VarName, flag::String) - -Set `vn`'s value for `flag` to `false` in `vi`. 
-""" -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false -end - -""" - set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) - -Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. -""" -function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a vector - gidcs = _getidcs(vi, spl) - if get_num_produce(vi) == 0 - for i = length(gidcs):-1:1 - vi.metadata.flags["del"][gidcs[i]] = true - end - else - for i in 1:length(vi.orders) - if i in gidcs && vi.orders[i] > get_num_produce(vi) - vi.metadata.flags["del"][i] = true - end - end - end - return nothing -end -function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol - gidcs = _getidcs(vi, spl) - return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) -end -@generated function _set_retained_vns_del_by_spl!(metadata, gidcs::NamedTuple{names}, num_produce) where {names} - expr = Expr(:block) - for f in names - f_gidcs = :(gidcs.$f) - f_orders = :(metadata.$f.orders) - f_flags = :(metadata.$f.flags) - push!(expr.args, quote - # Set the flag for variables with symbol `f` - if num_produce == 0 - for i = length($f_gidcs):-1:1 - $f_flags["del"][$f_gidcs[i]] = true - end - else - for i in 1:length($f_orders) - if i in $f_gidcs && $f_orders[i] > num_produce - $f_flags["del"][i] = true - end - end - end - end) - end - return expr -end - -""" - updategid!(vi::VarInfo, vn::VarName, spl::Sampler) - -Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked -and `vn`'s symbol is in the space of `spl`. 
-""" -function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) - if inspace(vn, getspace(spl)) - setgid!(vi, spl.selector, vn) - end -end diff --git a/src/varinfo/ad.jl b/src/varinfo/ad.jl new file mode 100644 index 000000000..e3939bdf3 --- /dev/null +++ b/src/varinfo/ad.jl @@ -0,0 +1,19 @@ +function __init__() + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin + value(x::ForwardDiff.Dual) = ForwardDiff.value(x) + value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) + value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) + value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) + end + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + value(x::Tracker.TrackedReal) = Tracker.data(x) + value(x::Tracker.TrackedArray) = Tracker.data(x) + value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) + end +end +ZygoteRules.@adjoint function zygote_setval!(vi, val, vn) + return zygote_setval!(vi, val, vn), _ -> nothing +end diff --git a/src/varinfo/indexing.jl b/src/varinfo/indexing.jl new file mode 100644 index 000000000..03227e820 --- /dev/null +++ b/src/varinfo/indexing.jl @@ -0,0 +1,394 @@ +## Vectorized value getters and setters ## + +const VarView = Union{Int, UnitRange, Vector{Int}} + +""" + getval(vi::UntypedVarInfo, vview::Union{Int, UnitRange, Vector{Int}}) + +Return a view `vi.vals[vview]`. +""" +function getval(vi::UntypedVarInfo, vview::VarView) + vals = getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals + return view(vals, vview) +end + +""" + setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) + +Set the value of `vi.vals[vview]` to `val`. +""" +function setval!(vi::UntypedVarInfo, val, vview::VarView) + vals = getmode(vi) isa LinkMode ? 
vi.metadata.trans_vals : vi.metadata.vals + return vals[vview] = val +end +function setval!(vi::UntypedVarInfo, val, vview::Vector{UnitRange}) + vals = getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals + if length(vview) > 0 + vals[[i for arr in vview for i in arr]] = val + end + return val +end + +""" + getval(vi::AbstractVarInfo, vn::VarName) + +Return the value(s) of `vn`. + +The values may or may not be transformed to Euclidean space. +""" +function getval(vi::AbstractVarInfo, vn::VarName) + metadata = getmetadata(vi, vn) + vals = getmode(vi) isa LinkMode ? metadata.trans_vals : metadata.vals + return view(vals, getrange(vi, vn)) +end + +""" + setval!(vi::AbstractVarInfo, val, vn::VarName) + +Set the value(s) of `vn` in the metadata of `vi` to `val`. + +The values may or may not be transformed to Euclidean space. +""" +function setval!(vi::AbstractVarInfo, val, vn::VarName) + metadata = getmetadata(vi, vn) + vals = getmode(vi) isa LinkMode ? metadata.trans_vals : metadata.vals + return vals[getrange(vi, vn)] = val +end + +""" + getval(vi::VarInfo, vns::Vector{<:VarName}) + +Return the value(s) of `vns`. + +The values may or may not be transformed to Euclidean space. +""" +function getval(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return mapreduce(vn -> getval(vi, vn), vcat, vns) +end + +""" + getall(vi::VarInfo) + +Return the values of all the variables in `vi`. + +The values may or may not be transformed to Euclidean space. +""" +function getall(vi::UntypedVarInfo) + return getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals +end +function getall(vi::TypedVarInfo) + return vcat(_getall(vi.metadata, Val(getmode(vi) isa LinkMode))...) 
+end +@generated function _getall(metadata::NamedTuple{names}, ::Val{linked}) where {names, linked} + exprs = [] + for f in names + if linked + push!(exprs, :(metadata.$f.trans_vals)) + else + push!(exprs, :(metadata.$f.vals)) + end + end + return :($(exprs...),) +end + +""" + setall!(vi::VarInfo, val) + +Set the values of all the variables in `vi` to `val`. + +The values may or may not be transformed to Euclidean space. +""" +function setall!(vi::UntypedVarInfo, val) + vals = getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals + return vals .= val +end +setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val, Val(getmode(vi) isa LinkMode)) +@generated function _setall!(metadata::NamedTuple{names}, val, ::Val{true}, start = 0) where {names} + expr = Expr(:block) + start = :(1) + for f in names + length = :(length(metadata.$f.trans_vals)) + finish = :($start + $length - 1) + push!(expr.args, :(metadata.$f.trans_vals .= val[$start:$finish])) + start = :($start + $length) + end + return expr +end +@generated function _setall!(metadata::NamedTuple{names}, val, ::Val{false}, start = 0) where {names} + expr = Expr(:block) + start = :(1) + for f in names + length = :(length(metadata.$f.vals)) + finish = :($start + $length - 1) + push!(expr.args, :(metadata.$f.vals .= val[$start:$finish])) + start = :($start + $length) + end + return expr +end + +## VarName getindex and setindex! ## + +function zygote_setval!(vi, val, vn) + return setval!(vi, val, vn) +end + +""" + getindex(vi::VarInfo, vn::VarName, dist::Distribution) + getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) + +Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) +distribution(s) `dist`. + +If the value(s) is (are) transformed to the Euclidean space, it is +(they are) transformed back. 
+""" +function Base.getindex( + vi::AbstractVarInfo, + vn::VarName, + dist::Distribution, +) + @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" + trans = istrans(vi, vn) + if has_fixed_support(vi) + set_fixed_support!(vi, bijector(dist) == bijector(getinitdist(vi, vn))) + end + if getmode(vi) isa LinkMode && trans + trans_val = reconstruct(dist, getval(vi, vn)) + val = Bijectors.invlink(dist, trans_val) + zygote_setval!(invlink(vi), value(vectorize(dist, val)), vn) + elseif getmode(vi) isa InitLinkMode && trans + val = reconstruct(dist, getval(vi, vn)) + trans_val = Bijectors.link(dist, val) + zygote_setval!(link(vi), vectorize(dist, trans_val), vn) + else + val = reconstruct(dist, getval(invlink(vi), vn)) + end + return val +end +function Base.getindex( + vi::AbstractVarInfo, + vn::VarName, +) + @assert getmode(vi) isa StandardMode + @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" + return reconstruct(getinitdist(vi, vn), getval(vi, vn)) +end +function Base.getindex( + vi::AbstractVarInfo, + vns::AbstractVector{<:VarName}, + dist::MultivariateDistribution, +) + return mapreduce(hcat, vns) do vn + vi[vn, dist] + end +end +function Base.getindex( + vi::AbstractVarInfo, + vns::AbstractArray{<:VarName}, + dists::Union{Distribution, AbstractArray{<:Distribution}}, +) + return broadcast(vns, dists) do vn, dist + vi[vn, dist] + end +end +function Base.getindex( + vi::AbstractVarInfo, + vns::Vector{<:VarName}, +) + return map(vns) do vn + vi[vn] + end +end + +""" + setindex!(vi::VarInfo, val, vn::VarName) + +Set the current value(s) of the random variable `vn` in `vi` to `val`. + +The value(s) may or may not be transformed to Euclidean space. +""" +function setindex!(vi::AbstractVarInfo, val, vn::VarName, dist::Distribution) + @assert haskey(vi, vn) "[DynamicPPL] variable not found in VarInfo." 
+ trans = istrans(vi, vn) + if getmode(vi) isa LinkMode && trans + trans_val = Bijectors.link(dist, val) + setval!(vi, vectorize(dist, trans_val), vn) + setval!(invlink(vi), vectorize(dist, val), vn) + elseif getmode(vi) isa InitLinkMode && trans + trans_val = Bijectors.link(dist, val) + setval!(vi, vectorize(dist, val), vn) + setval!(link(vi), vectorize(dist, trans_val), vn) + else + setval!(invlink(vi), vectorize(dist, val), vn) + end + return vi +end +function setindex!(vi::AbstractVarInfo, val, vn::VarName) + @assert getmode(vi) isa StandardMode + @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" + setval!(vi, vectorize(getinitdist(vi, vn), val), vn) + return vi +end + +## Sampler getindex and setindex! ## + +""" + getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) + +Return the current value(s) of the random variables sampled by `spl` in `vi`. + +The value(s) may or may not be transformed to Euclidean space. +""" +function getindex(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) + return copy.(getall(vi)) +end +function getindex(vi::UntypedVarInfo, spl::Sampler) + return copy.(getval(vi, getranges(vi, spl))) +end +function getindex(vi::TypedVarInfo, spl::Sampler) + # Gets the ranges as a NamedTuple + ranges = getranges(vi, spl) + # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` + return vcat(_getindex(vi.metadata, ranges, Val(getmode(vi) isa LinkMode))...) 
+end +# Recursively builds a tuple of the `vals` of all the symbols +@generated function _getindex( + metadata, + ranges::NamedTuple{names}, + ::Val{false}, +) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :(metadata.$f.vals[ranges.$f])) + end + return expr +end +@generated function _getindex( + metadata, + ranges::NamedTuple{names}, + ::Val{true}, +) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :(metadata.$f.trans_vals[ranges.$f])) + end + return expr +end + +""" + setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) + +Set the current value(s) of the random variables sampled by `spl` in `vi` to `val`. + +The value(s) may or may not be transformed to Euclidean space. +""" +function setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) + setall!(vi, val) + setsynced!(vi, false) + return vi +end +function setindex!(vi::UntypedVarInfo, val, spl::Sampler) + setval!(vi, val, getranges(vi, spl)) + setsynced!(vi, false) + return vi +end +function setindex!(vi::TypedVarInfo, val, spl::Sampler) + # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` + ranges = getranges(vi, spl) + _setindex!(vi.metadata, val, ranges, Val(getmode(vi) isa LinkMode)) + setsynced!(vi, false) + return vi +end +# Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. +@generated function _setindex!( + metadata, + val, + ranges::NamedTuple{names}, + ::Val{linked}, +) where {names, linked} + expr = Expr(:block) + offset = :(0) + for f in names + f_vals = linked ? 
:(metadata.$f.trans_vals) : :(metadata.$f.vals) + f_range = :(ranges.$f) + start = :($offset + 1) + len = :(length($f_range)) + finish = :($offset + $len) + push!(expr.args, :(@views $f_vals[$f_range] .= val[$start:$finish])) + offset = :($offset + $len) + end + return expr +end + +""" + push!(vi::VarInfo, vn::VarName, r, dist::Distribution) + +Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to +the `VarInfo` `vi`. +""" +function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return push!(vi, vn, r, dist, Set{Selector}([])) +end + +""" + push!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + +Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` +from a distribution `dist` to `VarInfo` `vi`. + +The sampler is passed here to invalidate its cache where defined. +""" +function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) + spl.info[:cache_updated] = CACHERESET + return push!(vi, vn, r, dist, spl.selector) +end +function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + return push!(vi, vn, r, dist) +end + +""" + push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + +Push a new random variable `vn` with a sampled value `r` sampled with a sampler of +selector `gid` from a distribution `dist` to `VarInfo` `vi`. +""" +function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + return push!(vi, vn, r, dist, Set([gid])) +end +function push!( + vi::VarInfo, + vn::VarName, + val, + dist::Distribution, + gidset::Set{Selector}, +) + if vi isa UntypedVarInfo + @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + elseif vi isa TypedVarInfo + @assert ~(haskey(vi, vn)) "[push!] 
attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + end + + meta = getmetadata(vi, vn) + meta.idcs[vn] = length(meta.idcs) + 1 + push!(meta.vns, vn) + + vectorized_val = vectorize(dist, val) + l = length(meta.vals); n = length(vectorized_val) + push!(meta.ranges, l+1:l+n) + if getmode(vi) isa LinkMode || getmode(vi) isa InitLinkMode + append!(meta.vals, vectorized_val) + trans_val = Bijectors.link(dist, val) + append!(meta.trans_vals, vectorize(dist, trans_val)) + else + append!(meta.vals, vectorized_val) + append!(meta.trans_vals, vectorized_val) + setsynced!(vi, false) + end + push!(meta.dists, dist) + push!(meta.gids, gidset) + push!(meta.orders, get_num_produce(vi)) + push!(meta.flags["del"], false) + push!(meta.flags["trans"], false) + + return vi +end diff --git a/src/varinfo/linking.jl b/src/varinfo/linking.jl new file mode 100644 index 000000000..73f7b4b1a --- /dev/null +++ b/src/varinfo/linking.jl @@ -0,0 +1,207 @@ +""" + islinked(vi::VarInfo, spl::Sampler) + +Check whether `vi` is in the transformed space for a particular sampler `spl`. + +Turing's Hamiltonian samplers use the `link` and `invlink` functions from +[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable +(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of +real numbers. `islinked` checks if the number is in the constrained space or the real space. +""" +function islinked(vi::UntypedVarInfo, spl::Sampler) + vns = getvns(vi, spl) + return islinked(vi) && !isempty(vns) && istrans(vi, vns[1]) +end +function islinked(vi::TypedVarInfo, spl::Sampler) + vns = getvns(vi, spl) + return islinked(vi) && _islinked(vi, vns) +end +@generated function _islinked(vi, vns::NamedTuple{names}) where {names} + out = [] + for f in names + push!(out, :(length(vns.$f) == 0 ? false : istrans(vi, vns.$f[1]))) + end + return Expr(:||, false, out...)
+end +function islinked_and_trans(vi::AbstractVarInfo, vn::VarName) + return islinked(vi) && istrans(vi, vn) +end + +function Bijectors.link(vi::VarInfo) + return VarInfo( + vi.metadata, + vi.logp, + vi.num_produce, + LinkMode(), + vi.fixed_support, + vi.synced, + ) +end +function initlink(vi::VarInfo) + return VarInfo( + vi.metadata, + vi.logp, + vi.num_produce, + InitLinkMode(), + vi.fixed_support, + vi.synced, + ) +end +function Bijectors.invlink(vi::VarInfo) + return VarInfo( + vi.metadata, + vi.logp, + vi.num_produce, + StandardMode(), + vi.fixed_support, + vi.synced, + ) +end +islinked(vi::AbstractVarInfo) = getmode(vi) isa LinkMode || getmode(vi) isa InitLinkMode + +# X -> R for all variables associated with given sampler +""" + init_dist_link!(vi::VarInfo, spl::Sampler) +Transform the values of the random variables sampled by `spl` in `vi` from the support +of their distributions to the Euclidean space and set their corresponding `"trans"` +flag values to `true`. +""" +function init_dist_link!(vi::UntypedVarInfo, spl::Sampler) + # TODO: Change to a lazy iterator over `vns` + vns = getvns(vi, spl) + for vn in vns + dist = getinitdist(vi, vn) + initlink(vi)[vn, dist] + end + return vi +end +function init_dist_link!(vi::TypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + _init_dist_link!(vi.metadata, vi, vns, Val(getspace(spl))) + return vi +end +@generated function _init_dist_link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} + expr = Expr(:block) + for f in names + if inspace(f, space) || length(space) == 0 + push!(expr.args, quote + f_vns = vi.metadata.$f.vns + # Iterate over all `f_vns` and transform + for vn in f_vns + dist = getinitdist(vi, vn) + initlink(vi)[vn, dist] + end + end) + end + end + return expr +end + +function invlink!(vi::AbstractVarInfo, spl::AbstractSampler, model) + settrans!(vi, spl) + if !issynced(vi) + if has_fixed_support(vi) + init_dist_invlink!(vi, spl) + else + model(link(vi), spl) + end + 
setsynced!(vi, true) + end + return vi +end +function link!(vi::AbstractVarInfo, spl::AbstractSampler, model) + settrans!(vi, spl) + if !issynced(vi) + if has_fixed_support(vi) + init_dist_link!(vi, spl) + else + model(initlink(vi), spl) + end + setsynced!(vi, true) + end + return vi +end + +# R -> X for all variables associated with given sampler +""" + init_dist_invlink!(vi::VarInfo, spl::AbstractSampler) +Transform the values of the random variables sampled by `spl` in `vi` from the +Euclidean space back to the support of their distributions and sets their corresponding +`"trans"` flag values to `false`. +""" +function init_dist_invlink!(vi::UntypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + for vn in vns + dist = getinitdist(vi, vn) + link(vi)[vn, dist] + end + return vi +end +function init_dist_invlink!(vi::TypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + _init_dist_invlink!(vi.metadata, vi, vns, Val(getspace(spl))) + return vi +end +@generated function _init_dist_invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} + expr = Expr(:block) + for f in names + if inspace(f, space) || length(space) == 0 + push!(expr.args, quote + f_vns = vi.metadata.$f.vns + # Iterate over all `f_vns` and transform + for vn in f_vns + dist = getinitdist(vi, vn) + link(vi)[vn, dist] + end + end) + end + end + return expr +end + +# X -> R for all variables associated with given sampler +""" + settrans!(vi::VarInfo, spl::Sampler) + +Set the `"trans"` flag to `true` for all the variables in the space of `spl`.
+""" +function settrans!(vi::UntypedVarInfo, spl::Sampler) + # TODO: Change to a lazy iterator over `vns` + vns = getvns(vi, spl) + if !isempty(vns) && ~istrans(vi, vns[1]) + for vn in vns + settrans!(vi, true, vn) + end + end + return vi +end +function settrans!(vi::TypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + _settrans!(vi, vns, Val(getspace(spl))) + return vi +end +@generated function _settrans!(vi, ::NamedTuple{names}, ::Val{space}) where {names, space} + expr = Expr(:block) + for f in names + if inspace(f, space) || length(space) == 0 + push!(expr.args, quote + f_vns = vi.metadata.$f.vns + if length(f_vns) > 0 && ~istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + settrans!(vi, true, vn) + end + end + end) + end + end + return expr +end + +""" + settrans!(vi::VarInfo, trans::Bool, vn::VarName) +Set the `trans` flag value of `vn` in `vi`. +""" +function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) + trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") +end diff --git a/src/varinfo/types.jl b/src/varinfo/types.jl new file mode 100644 index 000000000..a32ac304d --- /dev/null +++ b/src/varinfo/types.jl @@ -0,0 +1,349 @@ +#################### +# VarInfo metadata # +#################### + +""" +The `Metadata` struct stores some metadata about the parameters of the model. This helps +query certain information about a variable, such as its distribution, which samplers +sample this variable, its value and whether this value is transformed to real space or +not. + +Let `md` be an instance of `Metadata`: +- `md.vns` is the vector of all `VarName` instances. +- `md.idcs` is the dictionary that maps each `VarName` instance to its index in + `md.vns`, `md.ranges`, `md.dists`, `md.orders` and `md.flags`. +- `md.vns[md.idcs[vn]] == vn`. +- `md.dists[md.idcs[vn]]` is the distribution of `vn`. +- `md.gids[md.idcs[vn]]` is the set of algorithms used to sample `vn`. This is used in + the Gibbs sampling process.
+- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. +- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. +- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values corresponding to `vn`. +- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the + value of `flag` corresponding to `vn`. + +To make `md::Metadata` type stable, all the `md.vns` must have the same symbol +and distribution type. However, one can have a Julia variable, say `x`, that is a +matrix or a hierarchical array sampled in partitions, e.g. +`x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)`, and is managed by +a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the +same type. A type-unstable `Metadata` will still work but will have inferior performance. +When sampling, the first iteration uses a type-unstable `Metadata` for all the +variables, then a specialized `Metadata` is used for each symbol along with a function +barrier to make the rest of the sampling type stable.
+""" +struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TTransVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}} + # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` + idcs :: TIdcs # Dict{<:VarName,Int} + + # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn` + vns :: TVN # AbstractVector{<:VarName} + + # Vector of index ranges in `vals` corresponding to `vns` + # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals` + ranges :: Vector{UnitRange{Int}} + + # Vector of values of all the univariate, multivariate and matrix variables + # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` + vals :: TVal # AbstractVector{<:Real} + + # Vector of the transformed values of all the univariate, multivariate and matrix + # variables. The value(s) of `vn` is/are `trans_vals[ranges[idcs[vn]]]` + trans_vals :: TTransVal # AbstractVector{<:Real} + + # Vector of distributions corresponding to `vns` + dists :: TDists # AbstractVector{<:Distribution} + + # Vector of sampler ids corresponding to `vns` + # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set` + gids :: TGIds # AbstractVector{Set{Selector}} + + # Number of `observe` statements before each random variable is sampled + orders :: Vector{Int} + + # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresponding to `vns[i]` + flags :: Dict{String, BitVector} +end + +""" + Metadata() + +Construct an empty type unstable instance of `Metadata`.
+""" +function Metadata() + vals = Vector{Real}() + trans_vals = Vector{Real}() + flags = Dict{String, BitVector}() + flags["del"] = BitVector() + flags["trans"] = BitVector() + + return Metadata( + Dict{VarName, Int}(), + Vector{VarName}(), + Vector{UnitRange{Int}}(), + vals, + trans_vals, + Vector{Distribution}(), + Vector{Set{Selector}}(), + Vector{Int}(), + flags + ) +end + +########### +# VarInfo # +########### + +abstract type VarInfoMode end + +""" + LinkMode + +For any random variable whose `"trans"` flag is set to `true`: +1. The transformed values are used in `getindex` and `setindex!`. +2. The untransformed values are computed and cached, and +3. The `logpdf_with_trans` is computed with `trans` set as `true`. + +For random variables whose `"trans"` flag is set to `false`, this is equivalent to +the `StandardMode`. This mode can be used when running HMC or MAP in the +unconstrained space. +""" +struct LinkMode <: VarInfoMode end + +""" + InitLinkMode + +For any random variable whose `"trans"` flag is set to `true`: +1. The untransformed values are used in `getindex` and `setindex!`. +2. The transformed values are computed and cached, and +3. The `logpdf_with_trans` is computed with `trans` set as `true`. + +For random variables whose `"trans"` flag is set to `false`, this is equivalent to +the `StandardMode`. This mode can be used to initialize a `VarInfo` for HMC or MAP. +""" +struct InitLinkMode <: VarInfoMode end + +""" + StandardMode + +For all random variables: +1. The untransformed values are used in `getindex` and `setindex!`. +2. The `logpdf` is computed, i.e. `logpdf_with_trans` with `trans` as `false`. + +This mode can be used when running non-HMC samplers or when doing MAP on the +constrained support directly.
+""" +struct StandardMode <: VarInfoMode end + +""" +``` +struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo + metadata::Tmeta + logp::Base.RefValue{Tlogp} + num_produce::Base.RefValue{Int} +end +``` + +A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of +`VarInfo`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used +for all the symbols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If +`vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each +symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows +for the type specialization of `vi` after the first sampling iteration when all the +symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. + +Note: It is the user's responsibility to ensure that each "symbol" is visited at least +once whenever the model is called, regardless of any stochastic branching. Each symbol +refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`.
+""" +struct VarInfo{Tmeta <: Union{Metadata, NamedTuple}, Tlogp, Tmode <: VarInfoMode} <: AbstractVarInfo + metadata::Tmeta + logp::Base.RefValue{Tlogp} + num_produce::Base.RefValue{Int} + mode::Tmode + fixed_support::Base.RefValue{Bool} + synced::Base.RefValue{Bool} +end +const UntypedVarInfo = VarInfo{<:Metadata} +const TypedVarInfo = VarInfo{<:NamedTuple} + +function TypedVarInfo(model::Model, ctx = DefaultContext()) + vi = VarInfo() + model(vi, SampleFromPrior(), ctx) + return TypedVarInfo(vi) +end +function TypedVarInfo(model::Model, n::Integer, ctx = DefaultContext()) + return mapreduce(merge, 1:n) do _ + vi = VarInfo() + model(vi, SampleFromPrior(), ctx) + TypedVarInfo(vi) + end +end +function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) + new_vi = deepcopy(old_vi) + new_vi[spl] = x + return new_vi +end +function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) + md = newmetadata(old_vi.metadata, Val(getspace(spl)), x, Val(getmode(old_vi) isa LinkMode)) + return VarInfo( + md, + Base.RefValue{eltype(x)}(getlogp(old_vi)), + Ref(get_num_produce(old_vi)), + old_vi.mode, + old_vi.fixed_support, + Ref(false), + ) +end +@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x, ::Val{islinked}) where {names, space, islinked} + exprs = [] + offset = :(0) + for f in names + mdf = :(metadata.$f) + if inspace(f, space) || length(space) == 0 + len = :(length($mdf.vals)) + if islinked + push!(exprs, :($f = Metadata($mdf.idcs, + $mdf.vns, + $mdf.ranges, + $mdf.vals, + x[($offset + 1):($offset + $len)], + $mdf.dists, + $mdf.gids, + $mdf.orders, + $mdf.flags + ) + ) + ) + else + push!(exprs, :($f = Metadata($mdf.idcs, + $mdf.vns, + $mdf.ranges, + x[($offset + 1):($offset + $len)], + $mdf.trans_vals, + $mdf.dists, + $mdf.gids, + $mdf.orders, + $mdf.flags + ) + ) + ) + end + offset = :($offset + $len) + else + push!(exprs, :($f = $mdf)) + end + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end + 
+VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Real}(0.0), Ref(0), StandardMode(), Ref(true), Ref(false)) + +""" + TypedVarInfo(vi::UntypedVarInfo) + +This function finds all the unique `sym`s from the instances of `VarName{sym}` found in +`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the +global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as +a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each +symbol. +""" +function TypedVarInfo(vi::UntypedVarInfo) + meta = vi.metadata + new_metas = Metadata[] + # Symbols of all instances of `VarName{sym}` in `vi.vns` + syms_tuple = Tuple(syms(vi)) + for s in syms_tuple + # Find all indices in `vns` with symbol `s` + inds = findall(vn -> getsym(vn) === s, meta.vns) + n = length(inds) + # New `vns` + sym_vns = getindex.((meta.vns,), inds) + # New idcs + sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) + # New dists + sym_dists = getindex.((meta.dists,), inds) + # New gids, can make a resizeable FillArray + sym_gids = getindex.((meta.gids,), inds) + # New orders + sym_orders = getindex.((meta.orders,), inds) + # New flags + sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + + # Extract new ranges and vals + _ranges = getindex.((meta.ranges,), inds) + # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 + _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] + _trans_vals = [copy.(meta.trans_vals[_ranges[i]]) for i in 1:n] + sym_ranges = Vector{eltype(_ranges)}(undef, n) + start = 0 + for i in 1:n + sym_ranges[i] = start + 1 : start + length(_vals[i]) + start += length(_vals[i]) + end + sym_vals = foldl(vcat, _vals) + sym_trans_vals = foldl(vcat, _trans_vals) + + push!( + new_metas, + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_trans_vals, + sym_dists, sym_gids, sym_orders, sym_flags + ) + ) + end + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = 
NamedTuple{syms_tuple}(Tuple(new_metas)) + return VarInfo(nt, Ref(logp), Ref(num_produce), vi.mode, vi.fixed_support, vi.synced) +end +TypedVarInfo(vi::TypedVarInfo) = vi + + +#### +#### Printing +#### + +function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) + vi_str = """ + /======================================================================= + | VarInfo + |----------------------------------------------------------------------- + | Varnames : $(string(vi.metadata.vns)) + | Range : $(vi.metadata.ranges) + | Vals : $(vi.metadata.vals) + | GIDs : $(vi.metadata.gids) + | Orders : $(vi.metadata.orders) + | Logp : $(getlogp(vi)) + | #produce : $(get_num_produce(vi)) + | flags : $(vi.metadata.flags) + \\======================================================================= + """ + print(io, vi_str) +end + +const _MAX_VARS_SHOWN = 4 + +function _show_varnames(io::IO, vi) + md = vi.metadata + vns = md.vns + + groups = Dict{Symbol, Vector{VarName}}() + for vn in vns + group = get!(() -> Vector{VarName}(), groups, getsym(vn)) + push!(group, vn) + end + + print(io, length(groups), length(groups) == 1 ? 
" variable " : " variables ", "(") + join(io, Iterators.take(keys(groups), _MAX_VARS_SHOWN), ", ") + length(groups) > _MAX_VARS_SHOWN && print(io, ", ...") + print(io, "), dimension ", sum(prod(size(md.vals[md.ranges[md.idcs[vn]]])) for vn in vns)) +end + +function Base.show(io::IO, vi::UntypedVarInfo) + print(io, "VarInfo (") + _show_varnames(io, vi) + print(io, "; logp: ", round(getlogp(vi), digits=3)) + print(io, ")") +end diff --git a/src/varinfo/utils.jl b/src/varinfo/utils.jl new file mode 100644 index 000000000..dbc35a261 --- /dev/null +++ b/src/varinfo/utils.jl @@ -0,0 +1,640 @@ +has_fixed_support(vi::VarInfo) = vi.fixed_support[] +function set_fixed_support!(vi::VarInfo, b::Bool) + return vi.fixed_support[] = vi.fixed_support[] && b +end + +getmode(vi::VarInfo) = vi.mode +issynced(vi::VarInfo) = vi.synced[] +setsynced!(vi::VarInfo, b::Bool) = vi.synced[] = b +value(x) = x +getinferred(vi::TypedVarInfo) = keys(vi.metadata) +getinferred(::UntypedVarInfo) = () + +Base.merge(t::AbstractVarInfo) = t +function Base.merge( + t1::AbstractVarInfo, + t2::AbstractVarInfo, + ts::AbstractVarInfo..., +) + return merge(merge(t1, t2), ts...) +end +function Base.merge(t1::TypedVarInfo, t2::TypedVarInfo) + return VarInfo( + merge(t1.metadata, t2.metadata), + Ref(getlogp(t1) + getlogp(t2)), + Ref(0), + getmode(t1), + Ref(has_fixed_support(t1) && has_fixed_support(t2)), + Ref(issynced(t1) && issynced(t2)), + ) +end +Base.merge(t1::UntypedVarInfo, t2::TypedVarInfo) = merge(TypedVarInfo(t1), t2) +Base.merge(t1::TypedVarInfo, t2::UntypedVarInfo) = merge(t2, t1) +Base.merge(t1::UntypedVarInfo, ::UntypedVarInfo) = t1 + +""" + empty!(meta::Metadata) + +Empty the fields of `meta`. + +This is useful when using a sampling algorithm that assumes an empty `meta`, e.g. `SMC`. 
+""" +function empty!(meta::Metadata) + empty!(meta.idcs) + empty!(meta.vns) + empty!(meta.ranges) + empty!(meta.vals) + empty!(meta.trans_vals) + empty!(meta.dists) + empty!(meta.gids) + empty!(meta.orders) + for k in keys(meta.flags) + empty!(meta.flags[k]) + end + + return meta +end + +""" + empty!(vi::VarInfo) + +Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to +zeros. + +This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. +""" +function empty!(vi::VarInfo) + _empty!(vi.metadata) + resetlogp!(vi) + reset_num_produce!(vi) + setsynced!(vi, false) + return vi +end +@inline _empty!(metadata::Metadata) = empty!(metadata) +@generated function _empty!(metadata::NamedTuple{names}) where {names} + expr = Expr(:block) + for f in names + push!(expr.args, :(empty!(metadata.$f))) + end + return expr +end + +# Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so this is well-defined. +if VERSION < v"1.1" + _tail(nt::NamedTuple{names}) where names = NamedTuple{Base.tail(names)}(nt) +else + _tail(nt::NamedTuple) = Base.tail(nt) +end + +""" + getmetadata(vi::VarInfo, vn::VarName) + +Return the metadata in `vi` that belongs to `vn`. +""" +getmetadata(vi::VarInfo, vn::VarName) = vi.metadata +getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) + +""" + getidx(vi::AbstractVarInfo, vn::VarName) + +Return the index of `vn` in the metadata of `vi` corresponding to `vn`. +""" +getidx(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).idcs[vn] + +""" + getrange(vi::AbstractVarInfo, vn::VarName) + +Return the index range of `vn` in the metadata of `vi`. +""" +getrange(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] + +""" + getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) + +Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. 
+""" +function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return mapreduce(vn -> getrange(vi, vn), vcat, vns, init=Int[]) +end + +""" + getinitdist(vi::AbstractVarInfo, vn::VarName) + +Return the distribution from which `vn` was sampled in `vi`. +""" +getinitdist(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] + +""" + getgid(vi::AbstractVarInfo, vn::VarName) + +Return the set of sampler selectors associated with `vn` in `vi`. +""" +getgid(vi::AbstractVarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] + +""" + syms(vi::VarInfo) + +Returns a tuple of the unique symbols of random variables sampled in `vi`. +""" +syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols +syms(vi::TypedVarInfo) = keys(vi.metadata) + +# Get all indices of variables belonging to SampleFromPrior: +# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to +# the SampleFromPrior sampler +@inline function getidcs(vi::UntypedVarInfo, ::SampleFromPrior) + return filter(i -> isempty(vi.metadata.gids[i]) , 1:length(vi.metadata.gids)) +end +# Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol +@inline function getidcs(vi::TypedVarInfo, ::SampleFromPrior) + return getidcs(vi.metadata) +end +@generated function getidcs(metadata::NamedTuple{names}) where {names} + exprs = [] + for f in names + push!(exprs, :($f = findinds(metadata.$f))) + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end + +# Get all indices of variables belonging to a given sampler +@inline function getidcs(vi::AbstractVarInfo, spl::Sampler) + return getidcs(vi, spl.selector, Val(getspace(spl))) +end +@inline getidcs(vi::UntypedVarInfo, s::Selector, space::Val) = findinds(vi.metadata, s, space) +@inline getidcs(vi::TypedVarInfo, s::Selector, space::Val) = getidcs(vi.metadata, s, space) +# Get a NamedTuple for all the indices belonging to a given 
selector for each symbol +@generated function getidcs(metadata::NamedTuple{names}, s::Selector, ::Val{space}) where {names, space} + exprs = [] + # Iterate through each varname in metadata. + for f in names + # If the varname is in the sampler space + # or the sample space is empty (all variables) + # then return the indices for that variable. + if inspace(f, space) || length(space) == 0 + push!(exprs, :($f = findinds(metadata.$f, s, Val($space)))) + end + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end +@inline function findinds(f_meta, s, ::Val{space}) where {space} + # Get all the idcs of the vns in `space` and that belong to the selector `s` + return filter((i) -> + (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && + (isempty(space) || inspace(f_meta.vns[i], space)), 1:length(f_meta.gids)) +end +@inline function findinds(f_meta) + # Get all the idcs of the vns + return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) +end + +# Get all vns of variables belonging to spl +getvns(vi::AbstractVarInfo, spl::Sampler) = getvns(vi, spl.selector, Val(getspace(spl))) +getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) = getvns(vi, Selector(), Val(())) +function getvns(vi::UntypedVarInfo, s::Selector, space::Val) + idcs = getidcs(vi, s, space) + return vi.metadata.vns[idcs][.!(vi.metadata.flags["del"][idcs])] +end +function getvns(vi::TypedVarInfo, s::Selector, space::Val) + return getvns(vi.metadata, getidcs(vi, s, space)) +end +# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol +@generated function getvns(metadata, idcs::NamedTuple{names}) where {names} + exprs = [] + for f in names + push!(exprs, :($f = metadata.$f.vns[idcs.$f][.!(metadata.$f.flags["del"][idcs.$f])])) + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end + +# Get the index (in vals) ranges of all the vns of variables belonging to spl +@inline function 
getranges(vi::AbstractVarInfo, spl::Sampler) + ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} + #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end + #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 + # spl.info[:ranges] + #else + #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES + ranges = getranges(vi, spl.selector, Val(getspace(spl))) + #spl.info[:ranges] = ranges + return ranges + #end +end +# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` +@inline function getranges(vi::AbstractVarInfo, s::Selector, space) + return getranges(vi, getidcs(vi, s, space)) +end +@inline function getranges(vi::UntypedVarInfo, idcs::Vector{Int}) + mapreduce(i -> vi.metadata.ranges[i], vcat, idcs, init=Int[]) +end +@inline getranges(vi::TypedVarInfo, idcs::NamedTuple) = getranges(vi.metadata, idcs) + +@generated function getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} + exprs = [] + for f in names + push!(exprs, :($f = findranges(metadata.$f.ranges, idcs.$f))) + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end +@inline function findranges(f_ranges, f_idcs) + return mapreduce(i -> f_ranges[i], vcat, f_idcs, init=Int[]) +end + +""" + set_flag!(vi::VarInfo, vn::VarName, flag::String) + +Set `vn`'s value for `flag` to `true` in `vi`. +""" +function set_flag!(vi::AbstractVarInfo, vn::VarName, flag::String) + return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true +end + +# Functions defined only for UntypedVarInfo +""" + keys(vi::UntypedVarInfo) + +Return an iterator over all `vns` in `vi`. +""" +keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) + +""" + setgid!(vi::VarInfo, gid::Selector, vn::VarName) + +Add `gid` to the set of sampler selectors associated with `vn` in `vi`. 
+""" +function setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName; overwrite=false) + gids = getmetadata(vi, vn).gids[getidx(vi, vn)] + overwrite && empty!(gids) + push!(gids, gid) + return vi +end + +""" + istrans(vi::VarInfo, vn::VarName) + +Return true if `vn`'s values in `vi` are transformed to Euclidean space, and false if +they are in the support of `vn`'s distribution. +""" +function istrans(vi::AbstractVarInfo, vn::VarName) + return is_flagged(vi, vn, "trans") +end + +""" + getlogp(vi::VarInfo) + +Return the log of the joint probability of the observed data and parameters sampled in +`vi`. +""" +getlogp(vi::AbstractVarInfo) = vi.logp[] + +""" + setlogp!(vi::VarInfo, logp) + +Set the log of the joint probability of the observed data and parameters sampled in +`vi` to `logp`. +""" +function setlogp!(vi::VarInfo, logp) + vi.logp[] = logp + return vi +end + +""" + acclogp!(vi::VarInfo, logp) + +Add `logp` to the value of the log of the joint probability of the observed data and +parameters sampled in `vi`. +""" +function acclogp!(vi::VarInfo, logp) + vi.logp[] += logp + return vi +end + +""" + resetlogp!(vi::AbstractVarInfo) + +Reset the value of the log of the joint probability of the observed data and parameters +sampled in `vi` to 0. +""" +resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi))) + +""" + get_num_produce(vi::VarInfo) + +Return the `num_produce` of `vi`. +""" +get_num_produce(vi::VarInfo) = vi.num_produce[] + +""" + set_num_produce!(vi::VarInfo, n::Int) + +Set the `num_produce` field of `vi` to `n`. +""" +set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n + +""" + increment_num_produce!(vi::VarInfo) + +Add 1 to `num_produce` in `vi`. +""" +increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 + +""" + reset_num_produce!(vi::AbstractVarInfo) + +Reset the value of `num_produce` in `vi` to 0. This does not modify the log of the joint +probability stored in `vi`.
+""" +reset_num_produce!(vi::AbstractVarInfo) = set_num_produce!(vi, 0) + +""" + isempty(vi::VarInfo) + +Return true if `vi` is empty and false otherwise. +""" +isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) +isempty(vi::TypedVarInfo) = _isempty(vi.metadata) +@generated function _isempty(metadata::NamedTuple{names}) where {names} + expr = Expr(:&&, :true) + for f in names + push!(expr.args, :(isempty(metadata.$f.idcs))) + end + return expr +end + +""" + tonamedtuple(vi::VarInfo) + +Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and +indexing string of the variable. + +For example, a model that had a vector of vector-valued +variables `x` would return + +```julia +(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) +``` +""" +function tonamedtuple(vi::UntypedVarInfo) + return tonamedtuple(TypedVarInfo(vi)) +end +function tonamedtuple(vi::TypedVarInfo) + return tonamedtuple(vi.metadata, vi) +end +@generated function tonamedtuple(metadata::NamedTuple{names}, vi::VarInfo) where {names} + length(names) === 0 && return :(NamedTuple()) + expr = Expr(:tuple) + map(names) do f + push!(expr.args, Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns)))) + end + return expr +end + +function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler, SampleFromPrior}) + T = eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi), typeof(spl)})) + if T === Union{} + # To throw a meaningful error + return eltype(vi[spl]) + else + return T + end +end + +""" + haskey(vi::VarInfo, vn::VarName) + +Check whether `vn` has been sampled in `vi`. +""" +haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn).idcs, vn) +function haskey(vi::TypedVarInfo, vn::VarName) + metadata = vi.metadata + Tmeta = typeof(metadata) + return getsym(vn) in fieldnames(Tmeta) && haskey(getmetadata(vi, vn).idcs, vn) +end + +""" + hassymbol(vi::VarInfo, vn::VarName) + +Check whether the symbol of `vn` has been sampled in `vi`. 
+"""
+hassymbol(vi::UntypedVarInfo, vn::VarName) = any(keys(vi)) do vn2
+    getsym(vn) == getsym(vn2)
+end
+hassymbol(vi::TypedVarInfo, vn::VarName) = haskey(vi.metadata, getsym(vn))
+
+"""
+    setorder!(vi::VarInfo, vn::VarName, index::Int)
+
+Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe`
+statements run before sampling `vn`. Return `vi`.
+"""
+function setorder!(vi::VarInfo, vn::VarName, index::Int)
+    metadata = getmetadata(vi, vn)
+    # Look the position of `vn` up once instead of once per access.
+    idx = metadata.idcs[vn]
+    if metadata.orders[idx] != index
+        metadata.orders[idx] = index
+    end
+    return vi
+end
+
+#######################################
+# Rand & replaying method for VarInfo #
+#######################################
+
+"""
+    is_flagged(vi::VarInfo, vn::VarName, flag::String)
+
+Check whether `vn` has a true value for `flag` in `vi`.
+"""
+function is_flagged(vi::VarInfo, vn::VarName, flag::String)
+    return getmetadata(vi, vn).flags[flag][getidx(vi, vn)]
+end
+
+"""
+    unset_flag!(vi::VarInfo, vn::VarName, flag::String)
+
+Set `vn`'s value for `flag` to `false` in `vi`.
+"""
+function unset_flag!(vi::VarInfo, vn::VarName, flag::String)
+    return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false
+end
+
+"""
+    set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
+
+Set the `"del"` flag to `true` for the variables of `vi` that belong to `spl` and have
+`order > vi.num_produce[]`. If `num_produce` is 0, every variable belonging to `spl` is
+flagged.
+"""
+function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler)
+    # Get the indices of `vns` that belong to `spl` as a vector
+    gidcs = getidcs(vi, spl)
+    if get_num_produce(vi) == 0
+        for i = length(gidcs):-1:1
+            vi.metadata.flags["del"][gidcs[i]] = true
+        end
+    else
+        # `orders` is stored in the metadata (like `flags`), not on `vi` itself.
+        orders = vi.metadata.orders
+        for i in eachindex(orders)
+            if i in gidcs && orders[i] > get_num_produce(vi)
+                vi.metadata.flags["del"][i] = true
+            end
+        end
+    end
+    return nothing
+end
+function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler)
+    # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol
+    gidcs = getidcs(vi, spl)
+    return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi))
+end
+# Generated so that the per-symbol work is unrolled over the names of `gidcs`: one
+# copy of the flagging logic is emitted for every symbol `f`.
+@generated function _set_retained_vns_del_by_spl!(metadata, gidcs::NamedTuple{names}, num_produce) where {names}
+    expr = Expr(:block)
+    for f in names
+        f_gidcs = :(gidcs.$f)
+        f_orders = :(metadata.$f.orders)
+        f_flags = :(metadata.$f.flags)
+        push!(expr.args, quote
+            # Set the flag for variables with symbol `f`
+            if num_produce == 0
+                for i = length($f_gidcs):-1:1
+                    $f_flags["del"][$f_gidcs[i]] = true
+                end
+            else
+                for i in 1:length($f_orders)
+                    if i in $f_gidcs && $f_orders[i] > num_produce
+                        $f_flags["del"][i] = true
+                    end
+                end
+            end
+        end)
+    end
+    return expr
+end
+
+"""
+    updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler; overwrite=false)
+
+Add `spl.selector` to the `gids` of `vn` in `vi` if `vn` is in the space of `spl`;
+otherwise do nothing. With `overwrite=true`, the selectors already stored for `vn` are
+discarded before `spl.selector` is added.
+"""
+function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler; overwrite=false)
+    if inspace(vn, getspace(spl))
+        setgid!(vi, spl.selector, vn; overwrite=overwrite)
+    end
+end
+
+"""
+    set_namedtuple!(vi::VarInfo, nt::NamedTuple)
+
+Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
+
+`vi` must not be linked. For every key `n` of `nt`, each variable of `vi` whose symbol
+is `n` is assigned `_getindex(nt[n], vn.indexing)`.
+"""
+function set_namedtuple!(vi::VarInfo, nt::NamedTuple)
+    @assert !islinked(vi)
+    for (n, vals) in pairs(nt)
+        for vn in _vns_of_symbol(vi, n)
+            vi[vn] = _getindex(vals, vn.indexing)
+        end
+    end
+end
+
+# Variable names of `vi` whose symbol is `sym`. A typed `VarInfo` keeps them per symbol;
+# an untyped one keeps all symbols in a single vector, which therefore has to be
+# filtered so that values of `sym` are not written into variables of other symbols.
+_vns_of_symbol(vi::TypedVarInfo, sym::Symbol) = vi.metadata[sym].vns
+function _vns_of_symbol(vi::UntypedVarInfo, sym::Symbol)
+    return filter(vn -> getsym(vn) == sym, vi.metadata.vns)
+end
+
+"""
+    updategid!(vi::AbstractVarInfo, spls::Tuple{Vararg{AbstractSampler}}; overwrite=false)
+    updategid!(vi::VarInfo, spl::AbstractSampler; overwrite=false)
+
+Apply `updategid!(vi, vn, spl; overwrite=overwrite)` to every variable `vn` of `vi`, for
+the sampler `spl` or for each sampler in `spls`. Return `vi`.
+"""
+function updategid!(vi::AbstractVarInfo, spls::Tuple{Vararg{AbstractSampler}}; overwrite=false)
+    foreach(spls) do spl
+        updategid!(vi, spl; overwrite=overwrite)
+    end
+    return vi
+end
+function updategid!(vi::UntypedVarInfo, spl::AbstractSampler; overwrite=false)
+    # The per-variable method performs the `inspace` check itself, so no pre-check on
+    # the first variable is needed: that pre-check failed on an empty `vi` and skipped
+    # every variable whenever the first one happened to be outside the space of `spl`.
+    for vn in vi.metadata.vns
+        updategid!(vi, vn, spl; overwrite=overwrite)
+    end
+    return vi
+end
+function updategid!(vi::TypedVarInfo, spl::AbstractSampler; overwrite=false)
+    foreach(keys(vi.metadata)) do k
+        # Same as above: membership in the space is decided per variable.
+        for vn in vi.metadata[k].vns
+            updategid!(vi, vn, spl; overwrite=overwrite)
+        end
+    end
+    return vi
+end
+
+"""
+    removedel!(vi::VarInfo)
+    removedel!(md::Metadata)
+
+Remove, in place, every variable whose `"del"` flag is set, compacting the value
+vectors and recomputing the ranges and the `VarName => index` lookup of the remaining
+variables. Return the mutated argument.
+"""
+function removedel!(vi::VarInfo)
+    removedel!(vi.metadata)
+    return vi
+end
+removedel!(md::NamedTuple{<:Any, <:Tuple{Vararg{Metadata}}}) = map(removedel!, md)
+function removedel!(md::Metadata)
+    # Old positions of the variables to keep, and their positions after compaction.
+    inds_to_keep = Int[]
+    new_idcs = empty(md.idcs)
+    next_pos = 1
+    for vn in md.vns
+        idx = md.idcs[vn]
+        if !(md.flags["del"][idx])
+            push!(inds_to_keep, idx)
+            new_idcs[vn] = next_pos
+            next_pos += 1
+        end
+    end
+    new_vns = md.vns[inds_to_keep]
+    new_dists = md.dists[inds_to_keep]
+    new_gids = md.gids[inds_to_keep]
+    new_orders = md.orders[inds_to_keep]
+    new_flags = Dict(k => md.flags[k][inds_to_keep] for k in keys(md.flags))
+
+    # Re-pack the values of the kept variables contiguously.
+    nvals = isempty(inds_to_keep) ? 0 : sum(length, md.ranges[inds_to_keep])
+    new_vals = similar(md.vals, nvals)
+    new_trans_vals = similar(new_vals)
+    new_ranges = similar(md.ranges, length(inds_to_keep))
+    last_ind = 0
+    for (new_i, old_i) in enumerate(inds_to_keep)
+        first_ind = last_ind + 1
+        last_ind = last_ind + length(md.ranges[old_i])
+        new_vals[first_ind:last_ind] = md.vals[md.ranges[old_i]]
+        new_trans_vals[first_ind:last_ind] = md.trans_vals[md.ranges[old_i]]
+        new_ranges[new_i] = first_ind:last_ind
+    end
+
+    # Update the lookup in place through the public API rather than by overwriting the
+    # private fields of `Dict`, whose layout is an implementation detail of Base.
+    empty!(md.idcs)
+    merge!(md.idcs, new_idcs)
+
+    # The new containers are never longer than the old ones, so copy first, then shrink.
+    copyto!(md.vns, new_vns)
+    resize!(md.vns, length(new_vns))
+    copyto!(md.dists, new_dists)
+    resize!(md.dists, length(new_dists))
+    copyto!(md.gids, new_gids)
+    resize!(md.gids, length(new_gids))
+    copyto!(md.orders, new_orders)
+    resize!(md.orders, length(new_orders))
+    copyto!(md.ranges, new_ranges)
+    resize!(md.ranges, length(new_ranges))
+    for k in keys(md.flags)
+        copyto!(md.flags[k], new_flags[k])
+        resize!(md.flags[k], length(new_flags[k]))
+    end
+    copyto!(md.vals, new_vals)
+    resize!(md.vals, length(new_vals))
+    copyto!(md.trans_vals, new_trans_vals)
+    resize!(md.trans_vals, length(new_trans_vals))
+
+    return md
+end
diff --git a/src/varinfo/varinfo.jl b/src/varinfo/varinfo.jl
new file mode 100644
index 000000000..83c870435
--- /dev/null
+++ b/src/varinfo/varinfo.jl
@@ -0,0 +1,10 @@
+# Constants for caching
+const CACHERESET = 0b00
+const CACHEIDCS = 0b10
+const CACHERANGES = 0b01
+
+include("types.jl")
+include("utils.jl")
+include("indexing.jl")
+include("linking.jl")
+include("ad.jl")
diff --git a/test/Turing/contrib/inference/dynamichmc.jl b/test/Turing/contrib/inference/dynamichmc.jl
index 
17d1221d9..0cdac493b 100644 --- a/test/Turing/contrib/inference/dynamichmc.jl +++ b/test/Turing/contrib/inference/dynamichmc.jl @@ -41,7 +41,7 @@ function DynamicNUTS{AD}(space::Symbol...) where AD DynamicNUTS{AD, space}() end -mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState +mutable struct DynamicNUTSState{V<:AbstractVarInfo, D} <: AbstractSamplerState vi::V draws::Vector{D} end @@ -53,29 +53,16 @@ function AbstractMCMC.sample_init!( model::Model, spl::Sampler{<:DynamicNUTS}, N::Integer; - kwargs... + kwargs..., ) # Set up lp function. function _lp(x) - gradient_logp(x, spl.state.vi, model, spl) + gradient_logp(x, link(spl.state.vi), model, spl) end # Set the parameters to a starting value. initialize_parameters!(spl; kwargs...) - - model(spl.state.vi, SampleFromUniform()) - link!(spl.state.vi, spl) - l, dl = _lp(spl.state.vi[spl]) - while !isfinite(l) || !isfinite(dl) - model(spl.state.vi, SampleFromUniform()) - link!(spl.state.vi, spl) - l, dl = _lp(spl.state.vi[spl]) - end - - if spl.selector.tag == :default && !islinked(spl.state.vi, spl) - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - end + link!(spl.state.vi, spl, model) results = mcmc_with_warmup( rng, @@ -99,7 +86,8 @@ function AbstractMCMC.step!( ) # Pop the next draw off the vector. draw = popfirst!(spl.state.draws) - spl.state.vi[spl] = draw + link(spl.state.vi)[spl] = draw + invlink!(spl.state.vi, spl, model) return Transition(spl) end @@ -109,7 +97,7 @@ function Sampler( s::Selector=Selector() ) # Construct a state, using a default function. - state = DynamicNUTSState(VarInfo(model), []) + state = DynamicNUTSState(DynamicPPL.TypedVarInfo(model), []) # Return a new sampler. return Sampler(alg, Dict{Symbol,Any}(), s, state) @@ -118,7 +106,7 @@ end # Disable the progress logging for DynamicHMC, since it has its own progress meter. 
function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, + model::DynamicPPL.AbstractModel, alg::DynamicNUTS, N::Integer; chain_type=MCMCChains.Chains, @@ -127,7 +115,7 @@ end kwargs... ) if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end if resume_from === nothing return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; @@ -139,7 +127,7 @@ end function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, + model::DynamicPPL.AbstractModel, alg::DynamicNUTS, parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, @@ -149,7 +137,7 @@ function AbstractMCMC.sample( kwargs... ) if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; chain_type=chain_type, progress=false, kwargs...) diff --git a/test/Turing/contrib/inference/sghmc.jl b/test/Turing/contrib/inference/sghmc.jl index 83c488613..b1f9dda05 100644 --- a/test/Turing/contrib/inference/sghmc.jl +++ b/test/Turing/contrib/inference/sghmc.jl @@ -61,13 +61,9 @@ function step( is_first::Val{true}; kwargs... ) - spl.selector.tag != :default && link!(vi, spl) - # Initialize velocity v = zeros(Float64, size(vi[spl])) spl.info[:v] = v - - spl.selector.tag != :default && invlink!(vi, spl) return vi, true end @@ -84,13 +80,12 @@ function step( Turing.DEBUG && @debug "X-> R..." if spl.selector.tag != :default - link!(vi, spl) - model(vi, spl) + model(initlink(vi), spl) end Turing.DEBUG && @debug "recording old variables..." 
θ, v = vi[spl], spl.info[:v] - _, grad = gradient_logp(θ, vi, model, spl) + _, grad = gradient_logp(θ, link(vi), model, spl) verifygrad(grad) # Implements the update equations from (15) of Chen et al. (2014). @@ -197,7 +192,7 @@ function step( Turing.DEBUG && @debug "X-> R..." if spl.selector.tag != :default - link!(vi, spl) + link!(vi, spl, model) model(vi, spl) end diff --git a/test/Turing/core/ad.jl b/test/Turing/core/ad.jl index b896fccdd..7d86f76af 100644 --- a/test/Turing/core/ad.jl +++ b/test/Turing/core/ad.jl @@ -60,7 +60,7 @@ getADbackend(spl::Sampler) = getADbackend(spl.alg) """ gradient_logp( θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler=SampleFromPrior(), ) @@ -71,7 +71,7 @@ tool is currently active. """ function gradient_logp( θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::Sampler ) @@ -82,7 +82,7 @@ end gradient_logp( backend::ADBackend, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) @@ -93,7 +93,7 @@ specified by `(vi, sampler, model)` using `backend` for AD, e.g. 
`ForwardDiffAD{ function gradient_logp( ::ForwardDiffAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler=SampleFromPrior(), ) @@ -120,7 +120,7 @@ end function gradient_logp( ::TrackerAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/test/Turing/core/compat/reversediff.jl b/test/Turing/core/compat/reversediff.jl index 5ccfbbebd..f3822d35e 100644 --- a/test/Turing/core/compat/reversediff.jl +++ b/test/Turing/core/compat/reversediff.jl @@ -17,7 +17,7 @@ end function gradient_logp( backend::ReverseDiffAD{false}, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) @@ -54,7 +54,7 @@ end function gradient_logp( backend::ReverseDiffAD{true}, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/test/Turing/core/compat/zygote.jl b/test/Turing/core/compat/zygote.jl index 3c56a1922..dc18fa0f8 100644 --- a/test/Turing/core/compat/zygote.jl +++ b/test/Turing/core/compat/zygote.jl @@ -7,7 +7,7 @@ end function gradient_logp( backend::ZygoteAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/test/Turing/inference/AdvancedSMC.jl b/test/Turing/inference/AdvancedSMC.jl index 77a4c7090..ce857fd12 100644 --- a/test/Turing/inference/AdvancedSMC.jl +++ b/test/Turing/inference/AdvancedSMC.jl @@ -22,6 +22,27 @@ struct ParticleTransition{T, F<:AbstractFloat} weight::F end +function Base.promote_type( + ::Type{ParticleTransition{T1, F1}}, + ::Type{ParticleTransition{T2, F2}}, +) where {T1, F1, T2, F2} + return ParticleTransition{ + Union{T1, T2}, + promote_type(F1, F2), + } +end +function Base.convert( + ::Type{ParticleTransition{T, F}}, + t::ParticleTransition, +) where {T, F} + return 
ParticleTransition{T, F}( + convert(T, t.θ), + convert(F, t.lp), + convert(F, t.le), + convert(F, t.weight), + ) +end + function additional_parameters(::Type{<:ParticleTransition}) return [:lp,:le, :weight] end @@ -69,7 +90,7 @@ SMC(threshold::Real, space::Tuple = ()) = SMC(resample_systematic, threshold, sp SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(Turing.Core.ResampleWithESSThreshold(), space) -mutable struct SMCState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct SMCState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V # The logevidence after aggregating all samples together. average_logevidence :: F @@ -203,7 +224,7 @@ function PG(nparticles::Int, space::Tuple) return PG(nparticles, Turing.Core.ResampleWithESSThreshold(), space) end -mutable struct PGState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct PGState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V # The logevidence after aggregating all samples together. 
average_logevidence :: F @@ -214,6 +235,10 @@ function PGState(model::Model) return PGState(vi, 0.0) end +function replace_varinfo(s::PGState, vi::AbstractVarInfo) + return PGState(vi, s.average_logevidence) +end + const CSMC = PG # type alias of PG as Conditional SMC """ @@ -319,23 +344,21 @@ function DynamicPPL.assume( r = rand(dist) push!(vi, vn, r, dist, spl) elseif is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + DynamicPPL.removedel!(vi) r = rand(dist) - vi[vn] = vectorize(dist, r) - setgid!(vi, spl.selector, vn) - setorder!(vi, vn, get_num_produce(vi)) + push!(vi, vn, r, dist, spl) else updategid!(vi, vn, spl) - r = vi[vn] + r = vi[vn, dist] end else # vn belongs to other sampler <=> conditionning on vn if haskey(vi, vn) - r = vi[vn] + r = vi[vn, dist] else r = rand(dist) push!(vi, vn, r, dist, Selector(:invalid)) end - lp = logpdf_with_trans(dist, r, istrans(vi, vn)) + lp = logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) acclogp!(vi, lp) end return r, 0 diff --git a/test/Turing/inference/Inference.jl b/test/Turing/inference/Inference.jl index b3db2982d..c3fcd1fad 100644 --- a/test/Turing/inference/Inference.jl +++ b/test/Turing/inference/Inference.jl @@ -3,9 +3,9 @@ module Inference using ..Core using ..Core: logZ using ..Utilities -using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, - islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, - settrans!, _getvns, getdist, CACHERESET, AbstractSampler, +using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, set_namedtuple!, + islinked_and_trans, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, + settrans!, getvns, getinitdist, CACHERESET, AbstractSampler, Model, Sampler, SampleFromPrior, SampleFromUniform, Selector, AbstractSamplerState, DefaultContext, PriorContext, LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist, @@ -20,13 +20,14 @@ using DynamicPPL using AbstractMCMC: AbstractModel, AbstractSampler using Bijectors: _debug 
using DocStringExtensions: TYPEDEF, TYPEDFIELDS +import BangBang import AbstractMCMC import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import ..Core: getchunksize, getADbackend import DynamicPPL: get_matching_type, - VarName, _getranges, _getindex, getval, _getvns + VarName, getval, getvns import EllipticalSliceSampling import Random import MCMCChains @@ -74,6 +75,12 @@ getADbackend(::Hamiltonian{AD}) where AD = AD() # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end +getindex(vi::MixedVarInfo, spl::Sampler{<:Hamiltonian}) = vi.tvi[spl] +function setindex!(vi::MixedVarInfo, val, spl::Sampler{<:Hamiltonian}) + vi.tvi[spl] = val + return vi +end + """ mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real) @@ -101,6 +108,22 @@ struct Transition{T, F<:AbstractFloat} lp :: F end +function Base.promote_type( + ::Type{Transition{T1, F1}}, + ::Type{Transition{T2, F2}}, +) where {T1, F1, T2, F2} + return Transition{ + Union{T1, T2}, + promote_type(F1, F2), + } +end +function Base.convert( + ::Type{Transition{T, F}}, + t::Transition, +) where {T, F} + return Transition{T, F}(convert(T, t.θ), convert(F, t.lp)) +end + function Transition(spl::Sampler, nt::NamedTuple=NamedTuple()) theta = merge(tonamedtuple(spl.state.vi), nt) lp = getlogp(spl.state.vi) @@ -272,7 +295,6 @@ function initialize_parameters!( verbose::Bool=false, kwargs... 
) - islinked(spl.state.vi, spl) && invlink!(spl.state.vi, spl) # Get `init_theta` if init_theta !== nothing verbose && @info "Using passed-in initial variable values" init_theta @@ -329,7 +351,8 @@ function flatten_namedtuple(nt::NamedTuple) if length(v) == 1 return [(string(k), v)] else - return mapreduce(vcat, zip(v[1], v[2])) do (vnval, vn) + init = Tuple{String, eltype(v[1])}[] + return mapreduce(vcat, zip(v[1], v[2]); init = init) do (vnval, vn) return collect(FlattenIterator(vn, vnval)) end end @@ -520,9 +543,12 @@ end """ A blank `AbstractSamplerState` that contains only `VarInfo` information. """ -mutable struct SamplerState{VIType<:VarInfo} <: AbstractSamplerState +mutable struct SamplerState{VIType<:AbstractVarInfo} <: AbstractSamplerState vi :: VIType end +function replace_varinfo(::SamplerState, vi::AbstractVarInfo) + return SamplerState(vi) +end ####################################### # Concrete algorithm implementations. # @@ -548,6 +574,7 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) end floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T)) +floatof(::Type{Real}) = Real floatof(::Type) = Real # fallback if type inference failed function get_matching_type( @@ -566,9 +593,9 @@ function get_matching_type( end function get_matching_type( spl::AbstractSampler, - vi, + vi, ::Type{<:AbstractFloat}, -) +) where {T} return floatof(eltype(vi, spl)) end function get_matching_type( diff --git a/test/Turing/inference/ess.jl b/test/Turing/inference/ess.jl index 27a3b3f54..a790c51cd 100644 --- a/test/Turing/inference/ess.jl +++ b/test/Turing/inference/ess.jl @@ -25,19 +25,22 @@ struct ESS{space} <: InferenceAlgorithm end ESS() = ESS{()}() ESS(space::Symbol) = ESS{(space,)}() -mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState +mutable struct ESSState{V<:AbstractVarInfo} <: AbstractSamplerState vi::V end +function replace_varinfo(::ESSState, vi::AbstractVarInfo) + return ESSState(vi) +end function Sampler(alg::ESS, model::Model, s::Selector) # 
sanity check vi = VarInfo(model) space = getspace(alg) - vns = _getvns(vi, s, Val(space)) + vns = getvns(vi, s, Val(space)) length(vns) == 1 || error("[ESS] does only support one variable ($(length(vns)) variables specified)") for vn in vns[1] - dist = getdist(vi, vn) + dist = getinitdist(vi, vn) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("[ESS] only supports Gaussian prior distributions") end @@ -102,9 +105,9 @@ end function ESSModel(model::Model, spl::Sampler{<:ESS}) vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) μ = mapreduce(vcat, vns[1]) do vn - dist = getdist(vi, vn) + dist = getinitdist(vi, vn) vectorize(dist, mean(dist)) end @@ -115,7 +118,7 @@ end function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel) spl = model.spl vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) set_flag!(vi, vns[1][1], "del") model.model(vi, spl) return vi[spl] diff --git a/test/Turing/inference/gibbs.jl b/test/Turing/inference/gibbs.jl index 4b82cb934..d327e099b 100644 --- a/test/Turing/inference/gibbs.jl +++ b/test/Turing/inference/gibbs.jl @@ -42,12 +42,12 @@ function Gibbs(algs::GibbsComponent...) end """ - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} + GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that the `Gibbs` sampler iterates through for each `step!`. 
""" -mutable struct GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState +mutable struct GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState vi::V samplers::S end @@ -56,6 +56,10 @@ function GibbsState(model::Model, samplers::Tuple{Vararg{Sampler}}) return GibbsState(VarInfo(model), samplers) end +function replace_varinfo(s::GibbsState, vi::AbstractVarInfo) + return GibbsState(vi, s.samplers) +end + function Sampler(alg::Gibbs, model::Model, s::Selector) # sanity check for space space = getspace(alg) @@ -72,28 +76,24 @@ function Sampler(alg::Gibbs, model::Model, s::Selector) selector = Selector(Symbol(typeof(_alg)), rerun) Sampler(_alg, model, selector) end + varinfo = merge(ntuple(i -> samplers[i].state.vi, Val(length(samplers)))...) + samplers = map(samplers) do sampler + Sampler( + sampler.alg, + sampler.info, + sampler.selector, + replace_varinfo(sampler.state, varinfo), + ) + end # create a state variable - state = GibbsState(model, samplers) + state = GibbsState(varinfo, samplers) # create the sampler info = Dict{Symbol, Any}() spl = Sampler(alg, info, s, state) # add Gibbs to gids for all variables - vi = spl.state.vi - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns - - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) - - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end - end - end + DynamicPPL.updategid!(varinfo, (spl, samplers...)) return spl end @@ -112,6 +112,27 @@ struct GibbsTransition{T,F,S<:AbstractVector} transitions::S end +function Base.promote_type( + ::Type{GibbsTransition{T1, F1, S1}}, + ::Type{GibbsTransition{T2, F2, S2}}, +) where {T1, F1, S1, T2, F2, S2} + return GibbsTransition{ + Union{T1, T2}, + promote_type(F1, F2), + promote_type(S1, S2), + } +end +function Base.convert( + ::Type{GibbsTransition{T, F, S}}, + t::GibbsTransition, +) where {T, F, 
S} + return GibbsTransition{T, F, S}( + convert(T, t.θ), + convert(F, t.lp), + convert(S, t.transitions), + ) +end + function GibbsTransition(spl::Sampler{<:Gibbs}, transitions::AbstractVector) theta = tonamedtuple(spl.state.vi) lp = getlogp(spl.state.vi) @@ -188,25 +209,26 @@ function AbstractMCMC.step!( end # Do not store transitions of subsamplers -function AbstractMCMC.transitions_init( +function AbstractMCMC.transitions( transition::GibbsTransition, ::Model, ::Sampler{<:Gibbs}, N::Integer; kwargs... ) - return Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, N) + ts = Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, 0) + sizehint!(ts, N) + return ts end -function AbstractMCMC.transitions_save!( +function AbstractMCMC.save!!( transitions::Vector{<:Transition}, - iteration::Integer, transition::GibbsTransition, + iteration::Integer, ::Model, ::Sampler{<:Gibbs}, ::Integer; kwargs... ) - transitions[iteration] = Transition(transition.θ, transition.lp) - return + return BangBang.push!!(transitions, Transition(transition.θ, transition.lp)) end diff --git a/test/Turing/inference/hmc.jl b/test/Turing/inference/hmc.jl index 3b8b95c95..9324f7f80 100644 --- a/test/Turing/inference/hmc.jl +++ b/test/Turing/inference/hmc.jl @@ -3,7 +3,7 @@ ### mutable struct HMCState{ - TV <: TypedVarInfo, + TV <: AbstractVarInfo, TTraj<:AHMC.AbstractTrajectory, TAdapt<:AHMC.Adaptation.AbstractAdaptor, PhType <: AHMC.PhasePoint @@ -17,6 +17,18 @@ mutable struct HMCState{ z :: PhType end +function replace_varinfo(s::HMCState, vi::AbstractVarInfo) + return HMCState( + vi, + s.eval_num, + s.i, + s.traj, + s.h, + s.adaptor, + s.z, + ) +end + ########################## # Hamiltonian Transition # ########################## @@ -27,6 +39,27 @@ struct HamiltonianTransition{T, NT<:NamedTuple, F<:AbstractFloat} stat :: NT end +function Base.promote_type( + ::Type{HamiltonianTransition{T1, NT1, F1}}, + ::Type{HamiltonianTransition{T2, NT2, F2}}, +) where {T1, 
NT1, F1, T2, NT2, F2} + return HamiltonianTransition{ + Union{T1, T2}, + promote_type(NT1, NT2), + promote_type(F1, F2), + } +end +function Base.convert( + ::Type{HamiltonianTransition{T, NT, F}}, + t::HamiltonianTransition, +) where {T, NT, F} + return HamiltonianTransition{T, NT, F}( + convert(T, t.θ), + convert(F, t.lp), + convert(NT, t.stat) + ) +end + function HamiltonianTransition(spl::Sampler{<:Hamiltonian}, t::AHMC.Transition) theta = tonamedtuple(spl.state.vi) lp = getlogp(spl.state.vi) @@ -102,8 +135,8 @@ end function update_hamiltonian!(spl, model, n) metric = gen_metric(n, spl) - ℓπ = gen_logπ(spl.state.vi, spl, model) - ∂ℓπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model) + ℓπ = gen_logπ(link(spl.state.vi), spl, model) + ∂ℓπ∂θ = gen_∂logπ∂θ(link(spl.state.vi), spl, model) spl.state.h = AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) return spl end @@ -125,24 +158,24 @@ function AbstractMCMC.sample_init!( initialize_parameters!(spl; verbose=verbose, kwargs...) if init_theta !== nothing # Doesn't support dynamic models - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) - else + elseif resume_from === nothing # Samples new values and sets trans to true, then computes the logp model(empty!(spl.state.vi), SampleFromUniform()) - link!(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. 
spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) while !isfinite(spl.state.z.ℓπ.value) || !isfinite(spl.state.z.ℓπ.gradient) model(empty!(spl.state.vi), SampleFromUniform()) - link!(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) @@ -166,19 +199,19 @@ function AbstractMCMC.sample_init!( spl.alg.n_adapts = 0 end end - + # Convert to transformed space if we're using # non-Gibbs sampling. - if !islinked(spl.state.vi, spl) && spl.selector.tag == :default - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - elseif islinked(spl.state.vi, spl) && spl.selector.tag != :default - invlink!(spl.state.vi, spl) + if spl.selector.tag == :default + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) + else + invlink!(spl.state.vi, spl, model) model(spl.state.vi, spl) end end -function AbstractMCMC.transitions_init( +function AbstractMCMC.transitions( transition, ::AbstractModel, sampler::Sampler{<:Hamiltonian}, @@ -191,28 +224,26 @@ function AbstractMCMC.transitions_init( else n = N end - return Vector{typeof(transition)}(undef, n) + ts = Vector{typeof(transition)}(undef, 0) + sizehint!(ts, n) + return ts end -function AbstractMCMC.transitions_save!( +function AbstractMCMC.save!!( transitions::AbstractVector, - iteration::Integer, transition, + iteration::Integer, ::AbstractModel, sampler::Sampler{<:Hamiltonian}, ::Integer; discard_adapt = true, kwargs... 
) - if discard_adapt && isdefined(sampler.alg, :n_adapts) - if iteration > sampler.alg.n_adapts - transitions[iteration - sampler.alg.n_adapts] = transition - end - return + if discard_adapt && isdefined(sampler.alg, :n_adapts) && iteration <= sampler.alg.n_adapts + return transitions + else + return BangBang.push!!(transitions, transition) end - - transitions[iteration] = transition - return end """ @@ -375,7 +406,8 @@ function Sampler( ) info = Dict{Symbol, Any}() # Create an empty sampler state that just holds a typed VarInfo. - initial_state = SamplerState(VarInfo(model)) + varinfo = getspace(alg) === () ? TypedVarInfo(model) : VarInfo(model) + initial_state = SamplerState(varinfo) # Create an initial sampler, to get all the initialization out of the way. initial_spl = Sampler(alg, info, s, initial_state) @@ -411,20 +443,22 @@ function AbstractMCMC.step!( spl.state.eval_num = 0 Turing.DEBUG && @debug "current ϵ: $ϵ" + updategid!(spl.state.vi, spl) # When a Gibbs component if spl.selector.tag != :default # Transform the space Turing.DEBUG && @debug "X-> R..." 
- link!(spl.state.vi, spl) - model(spl.state.vi, spl) + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) end # Get position and log density before transition - θ_old, log_density_old = spl.state.vi[spl], getlogp(spl.state.vi) + θ_old, θ_old_trans = spl.state.vi[spl], link(spl.state.vi)[spl] + log_density_old = getlogp(spl.state.vi) if spl.selector.tag != :default - update_hamiltonian!(spl, model, length(θ_old)) + update_hamiltonian!(spl, model, length(θ_old_trans)) resize!(spl.state.z.θ, length(θ_old)) - spl.state.z.θ .= θ_old + spl.state.z.θ .= θ_old_trans end # Transition @@ -443,17 +477,19 @@ function AbstractMCMC.step!( # Update `vi` based on acceptance if t.stat.is_accept - spl.state.vi[spl] = t.z.θ + link(spl.state.vi)[spl] = t.z.θ + invlink!(spl.state.vi, spl, model) setlogp!(spl.state.vi, t.stat.log_density) else spl.state.vi[spl] = θ_old + link(spl.state.vi)[spl] = θ_old_trans setlogp!(spl.state.vi, log_density_old) + DynamicPPL.setsynced!(spl.state.vi, true) end # Gibbs component specified cares # Transform the space back Turing.DEBUG && @debug "R -> X..." 
- spl.selector.tag != :default && invlink!(spl.state.vi, spl) return HamiltonianTransition(spl, t) end @@ -514,13 +550,13 @@ function DynamicPPL.assume( ) Turing.DEBUG && _debug("assuming...") updategid!(vi, vn, spl) - r = vi[vn] - # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn))) + r = vi[vn, dist] + # acclogp!(vi, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))) # r Turing.DEBUG && _debug("dist = $dist") Turing.DEBUG && _debug("vn = $vn") Turing.DEBUG && _debug("r = $r, typeof(r)=$(typeof(r))") - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function DynamicPPL.dot_assume( @@ -532,9 +568,9 @@ function DynamicPPL.dot_assume( ) @assert length(dist) == size(var, 1) updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] + r = vi[vns, dist] var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.dot_assume( spl::Sampler{<:Hamiltonian}, @@ -544,9 +580,9 @@ function DynamicPPL.dot_assume( vi, ) updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) + r = vi[vns, dists] var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.observe( @@ -606,11 +642,11 @@ function HMCState( vi = spl.state.vi # Link everything if needed. - !islinked(vi, spl) && link!(vi, spl) + link!(vi, spl, model) # Get the initial log pdf and gradient functions. - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) - logπ = gen_logπ(vi, spl, model) + ∂logπ∂θ = gen_∂logπ∂θ(link(vi), spl, model) + logπ = gen_logπ(link(vi), spl, model) # Get the metric type. metricT = getmetricT(spl.alg) @@ -635,7 +671,7 @@ function HMCState( h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. # Unlink everything. 
- invlink!(vi, spl) + invlink!(vi, spl, model) return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) end diff --git a/test/Turing/inference/is.jl b/test/Turing/inference/is.jl index a7f515e36..dcd4cad5b 100644 --- a/test/Turing/inference/is.jl +++ b/test/Turing/inference/is.jl @@ -31,7 +31,7 @@ struct IS{space} <: InferenceAlgorithm end IS() = IS{()}() -mutable struct ISState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct ISState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V final_logevidence :: F end diff --git a/test/Turing/inference/mh.jl b/test/Turing/inference/mh.jl index 764c8959b..02905b1f5 100644 --- a/test/Turing/inference/mh.jl +++ b/test/Turing/inference/mh.jl @@ -54,49 +54,12 @@ alg_str(::Sampler{<:MH}) = "MH" # Utility functions # ##################### -""" - set_namedtuple!(vi::VarInfo, nt::NamedTuple) - -Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. -""" -function set_namedtuple!(vi::VarInfo, nt::NamedTuple) - for (n, vals) in pairs(nt) - vns = vi.metadata[n].vns - - n_vns = length(vns) - n_vals = length(vals) - v_isarr = vals isa AbstractArray - - if v_isarr && n_vals == 1 && n_vns > 1 - for (vn, val) in zip(vns, vals[1]) - vi[vn] = val isa AbstractArray ? val : [val] - end - elseif v_isarr && n_vals > 1 && n_vns == 1 - vi[vns[1]] = vals - elseif v_isarr && n_vals == n_vns > 1 - for (vn, val) in zip(vns, vals) - vi[vn] = [val] - end - elseif v_isarr && n_vals == 1 && n_vns == 1 - if vals[1] isa AbstractArray - vi[vns[1]] = vals[1] - else - vi[vns[1]] = [vals[1]] - end - elseif !(v_isarr) - vi[vns[1]] = [vals] - else - error("Cannot assign `NamedTuple` to `VarInfo`") - end - end -end - """ MHLogDensityFunction A log density function for the MH sampler. -This variant uses the `set_namedtuple!` function to update the `VarInfo`. +This variant uses the `set_namedtuple!` function to update the variables. 
""" struct MHLogDensityFunction{M<:Model,S<:Sampler{<:MH}} <: Function # Relax AMH.DensityModel? model::M @@ -136,20 +99,20 @@ The second `NamedTuple` has model symbols as keys and their stored values as val """ function dist_val_tuple(spl::Sampler{<:MH}) vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) dt = _dist_tuple(spl.alg.proposals, vi, vns) vt = _val_tuple(vi, vns) return dt, vt end @generated function _val_tuple( - vi::VarInfo, + vi, vns::NamedTuple{names} ) where {names} isempty(names) === 0 && return :(NamedTuple()) expr = Expr(:tuple) expr.args = Any[ - :($name = reconstruct(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)), + :($name = reconstruct(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)), DynamicPPL.getval(vi, vns.$name))) for name in names] return expr @@ -157,7 +120,7 @@ end @generated function _dist_tuple( props::NamedTuple{propnames}, - vi::VarInfo, + vi, vns::NamedTuple{names} ) where {names,propnames} isempty(names) === 0 && return :(NamedTuple()) @@ -168,7 +131,7 @@ end :($name = props.$name) else # Otherwise, use the default proposal. 
- :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)))) + :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)))) end for name in names] return expr end @@ -229,8 +192,8 @@ function DynamicPPL.assume( vi, ) updategid!(vi, vn, spl) - r = vi[vn] - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + r = vi[vn, dist] + return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function DynamicPPL.dot_assume( @@ -244,9 +207,9 @@ function DynamicPPL.dot_assume( getvn = i -> VarName(vn, vn.indexing * "[:,$i]") vns = getvn.(1:size(var, 2)) updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] + r = vi[vns, dist] var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.dot_assume( spl::Sampler{<:MH}, @@ -258,9 +221,9 @@ function DynamicPPL.dot_assume( getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]") vns = getvn.(CartesianIndices(var)) updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) + r = vi[vns, dists] var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.observe( diff --git a/test/Turing/utilities/robustinit.jl b/test/Turing/utilities/robustinit.jl deleted file mode 100644 index f78394b61..000000000 --- a/test/Turing/utilities/robustinit.jl +++ /dev/null @@ -1,33 +0,0 @@ -# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni() = Real(2rand()) -randrealuni(args...) 
= map(Real, 2rand(args...)) - -const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution} - - -################################# -# Single-sample initialisations # -################################# - -init(dist::Transformable) = inittrans(dist) -init(dist::Distribution) = rand(dist) - -inittrans(dist::UnivariateDistribution) = invlink(dist, randrealuni()) -inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist)[1])) -inittrans(dist::MatrixDistribution) = invlink(dist, randrealuni(size(dist)...)) - - -################################ -# Multi-sample initialisations # -################################ - -init(dist::Transformable, n::Int) = inittrans(dist, n) -init(dist::Distribution, n::Int) = rand(dist, n) - -inittrans(dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(n)) -function inittrans(dist::MultivariateDistribution, n::Int) - return invlink(dist, randrealuni(size(dist)[1], n)) -end -function inittrans(dist::MatrixDistribution, n::Int) - return invlink(dist, [randrealuni(size(dist)...) 
for _ in 1:n]) -end diff --git a/test/Turing/variational/advi.jl b/test/Turing/variational/advi.jl index c8e58e6c5..26c861a09 100644 --- a/test/Turing/variational/advi.jl +++ b/test/Turing/variational/advi.jl @@ -62,7 +62,7 @@ Creates a mean-field approximation with multivariate normal as underlying distri meanfield(model::Model) = meanfield(GLOBAL_RNG, model) function meanfield(rng::AbstractRNG, model::Model) # setup - varinfo = Turing.VarInfo(model) + varinfo = DynamicPPL.TypedVarInfo(model) num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym ∈ keys(varinfo.metadata)]) diff --git a/test/compiler.jl b/test/compiler.jl index da180baea..12b45e852 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -581,12 +581,12 @@ end vi1 = VarInfo(f1()) vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) - @test haskey(vi1.metadata, :y) - @test vi1.metadata.y.vns[1] == VarName(:y) - @test haskey(vi2.metadata, :y) - @test vi2.metadata.y.vns[1] == VarName(:y, ((2,), (Colon(), 1))) - @test haskey(vi3.metadata, :y) - @test vi3.metadata.y.vns[1] == VarName(:y, ((1,),)) + @test haskey(vi1.tvi.metadata, :y) + @test vi1.tvi.metadata.y.vns[1] == VarName(:y) + @test haskey(vi2.tvi.metadata, :y) + @test vi2.tvi.metadata.y.vns[1] == VarName(:y, ((2,), (Colon(), 1))) + @test haskey(vi3.tvi.metadata, :y) + @test vi3.tvi.metadata.y.vns[1] == VarName(:y, ((1,),)) end @testset "custom tilde" begin @model demo() = begin diff --git a/test/test_util.jl b/test/test_util.jl index d8a926656..05127ee2c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -1,5 +1,7 @@ +using DynamicPPL: TypedVarInfo + function test_model_ad(model, logp_manual) - vi = VarInfo(model) + vi = TypedVarInfo(model) model(vi, SampleFromPrior()) x = DynamicPPL.getall(vi) diff --git a/test/varinfo.jl b/test/varinfo.jl index 5ec939807..debf8449c 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -2,10 +2,10 @@ using .Turing, Random using AbstractMCMC: step! 
using DynamicPPL: Selector, reconstruct, invlink, CACHERESET, SampleFromPrior, Sampler, SampleFromUniform, - _getidcs, set_retained_vns_del_by_spl!, is_flagged, + getidcs, set_retained_vns_del_by_spl!, is_flagged, set_flag!, unset_flag!, VarInfo, TypedVarInfo, getlogp, setlogp!, resetlogp!, acclogp!, vectorize, - setorder!, updategid! + setorder!, updategid!, islinked_and_trans, link using DynamicPPL, LinearAlgebra using Distributions using ForwardDiff: Dual @@ -82,7 +82,7 @@ include(dir*"/test/test_utils/AllUtils.jl") @test vi[vn] == r @test vi[SampleFromPrior()][1] == r - vi[vn] = [2*r] + vi[vn] = 2*r @test vi[vn] == 2*r @test vi[SampleFromPrior()][1] == 2*r vi[SampleFromPrior()] = [3*r] @@ -168,30 +168,48 @@ include(dir*"/test/test_utils/AllUtils.jl") model(vi, SampleFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.vns) + @test all(x -> !islinked_and_trans(link(vi), x), meta.vns) alg = HMC(0.1, 5) spl = Sampler(alg, model) v = copy(meta.vals) - link!(vi, spl) + link!(vi, spl, model) @test all(x -> istrans(vi, x), meta.vns) - invlink!(vi, spl) - @test all(x -> !istrans(vi, x), meta.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.vns) + invlink!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.vns) @test meta.vals == v vi = TypedVarInfo(vi) meta = vi.metadata alg = HMC(0.1, 5) spl = Sampler(alg, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.s.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.m.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.m.vns) v_s = 
copy(meta.s.vals) v_m = copy(meta.m.vals) - link!(vi, spl) + link!(vi, spl, model) @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.s.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.s.vns) @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl) - @test all(x -> ~istrans(vi, x), meta.s.vns) - @test all(x -> ~istrans(vi, x), meta.m.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.m.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.m.vns) + invlink!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.s.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + @test all(x -> !islinked_and_trans(vi, x), meta.m.vns) + @test all(x -> islinked_and_trans(link(vi), x), meta.m.vns) @test meta.s.vals == v_s @test meta.m.vals == v_m end @@ -225,12 +243,12 @@ include(dir*"/test/test_utils/AllUtils.jl") elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = rand(dist) - vi[vn] = vectorize(dist, r) + vi[vn, dist] = r setorder!(vi, vn, get_num_produce(vi)) r else updategid!(vi, vn, spl) - vi[vn] + vi[vn, dist] end end @@ -416,7 +434,7 @@ include(dir*"/test/test_utils/AllUtils.jl") @test sum(val - r) <= 1e-9 end - idcs = _getidcs(vi, spl1) + idcs = getidcs(vi, spl1) if idcs isa NamedTuple @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 else @@ -424,7 +442,7 @@ include(dir*"/test/test_utils/AllUtils.jl") end @test length(vi[spl1]) == 7 - idcs = _getidcs(vi, spl2) + idcs = getidcs(vi, spl2) if idcs isa NamedTuple @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 else @@ -435,7 +453,7 @@ include(dir*"/test/test_utils/AllUtils.jl") vn_u = @varname u randr(vi, vn_u, dists[1], spl2, true) - idcs = _getidcs(vi, spl2) + idcs = getidcs(vi, spl2) if idcs isa NamedTuple @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) 
== 2 else @@ -468,25 +486,30 @@ include(dir*"/test/test_utils/AllUtils.jl") g_demo_f(vi, SampleFromPrior()) step!(Random.GLOBAL_RNG, g_demo_f, pg, 1) vi1 = pg.state.vi - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == - [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()] + vi1.tvi.metadata.x.gids[1] + vi1.tvi.metadata.y.gids[1] + vi1.tvi.metadata.z.gids[1] + vi1.tvi.metadata.w.gids[1] + vi1.tvi.metadata.u.gids[1] + @test mapreduce(x -> x.gids, vcat, vi1.tvi.metadata) == + [Set([g.selector, pg.selector]), Set([g.selector, pg.selector]), Set([g.selector, pg.selector]), Set([g.selector, hmc.selector]), Set([g.selector, hmc.selector])] @inferred g_demo_f(vi1, hmc) - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == - [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])] + @test mapreduce(x -> x.gids, vcat, vi1.tvi.metadata) == + [Set([g.selector, pg.selector]), Set([g.selector, pg.selector]), Set([g.selector, pg.selector]), Set([g.selector, hmc.selector]), Set([g.selector, hmc.selector])] g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f) pg, hmc = g.state.samplers - vi = empty!(TypedVarInfo(vi)) + vi = empty!(MixedVarInfo(vi)) @inferred g_demo_f(vi, SampleFromPrior()) pg.state.vi = vi step!(Random.GLOBAL_RNG, g_demo_f, pg, 1) vi = pg.state.vi @inferred g_demo_f(vi, hmc) - @test vi.metadata.x.gids[1] == Set([pg.selector]) - @test vi.metadata.y.gids[1] == Set([pg.selector]) - @test vi.metadata.z.gids[1] == Set([pg.selector]) - @test vi.metadata.w.gids[1] == Set([hmc.selector]) - @test vi.metadata.u.gids[1] == Set([hmc.selector]) + @test vi.tvi.metadata.x.gids[1] == Set([pg.selector]) + @test vi.tvi.metadata.y.gids[1] == Set([pg.selector]) + @test vi.tvi.metadata.z.gids[1] == Set([pg.selector]) + @test vi.tvi.metadata.w.gids[1] == Set([hmc.selector]) + @test vi.tvi.metadata.u.gids[1] == Set([hmc.selector]) end end