From 04633f130219a41f52d62c3f0acbf65b44c5ff07 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Wed, 13 May 2020 09:38:21 +1000 Subject: [PATCH 01/15] fix for distributions with stochastic support --- src/contrib/inference/dynamichmc.jl | 30 ++++-------- src/contrib/inference/sghmc.jl | 11 ++--- src/inference/AdvancedSMC.jl | 8 +-- src/inference/Inference.jl | 7 ++- src/inference/ess.jl | 10 ++-- src/inference/hmc.jl | 73 +++++++++++++++------------- src/inference/mh.jl | 55 ++++----------------- test/inference/Inference.jl | 6 +-- test/inference/gibbs.jl | 1 - test/inference/hmc.jl | 66 +++++++++++++++++++++++++ test/runtests.jl | 45 ++++++++--------- test/test_utils/testing_functions.jl | 2 +- 12 files changed, 164 insertions(+), 150 deletions(-) diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 17d1221d9e..66829b01b5 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -53,29 +53,16 @@ function AbstractMCMC.sample_init!( model::Model, spl::Sampler{<:DynamicNUTS}, N::Integer; - kwargs... + kwargs..., ) # Set up lp function. function _lp(x) - gradient_logp(x, spl.state.vi, model, spl) + gradient_logp(x, link(spl.state.vi), model, spl) end # Set the parameters to a starting value. initialize_parameters!(spl; kwargs...) - - model(spl.state.vi, SampleFromUniform()) - link!(spl.state.vi, spl) - l, dl = _lp(spl.state.vi[spl]) - while !isfinite(l) || !isfinite(dl) - model(spl.state.vi, SampleFromUniform()) - link!(spl.state.vi, spl) - l, dl = _lp(spl.state.vi[spl]) - end - - if spl.selector.tag == :default && !islinked(spl.state.vi, spl) - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - end + link!(spl.state.vi, spl, model) results = mcmc_with_warmup( rng, @@ -99,7 +86,8 @@ function AbstractMCMC.step!( ) # Pop the next draw off the vector. draw = popfirst!(spl.state.draws) - spl.state.vi[spl] = draw + link(spl.state.vi)[spl] = draw + invlink!(spl.state.vi, spl, model) return Transition(spl) end @@ -118,7 +106,7 @@ end # Disable the progress logging for DynamicHMC, since it has its own progress meter. function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, + model::DynamicPPL.AbstractModel, alg::DynamicNUTS, N::Integer; chain_type=MCMCChains.Chains, @@ -127,7 +115,7 @@ end kwargs... ) if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end if resume_from === nothing return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; @@ -139,7 +127,7 @@ end function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, + model::DynamicPPL.AbstractModel, alg::DynamicNUTS, parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, @@ -149,7 +137,7 @@ function AbstractMCMC.sample( kwargs... ) if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; chain_type=chain_type, progress=false, kwargs...) diff --git a/src/contrib/inference/sghmc.jl b/src/contrib/inference/sghmc.jl index 83c4886132..b1f9dda050 100644 --- a/src/contrib/inference/sghmc.jl +++ b/src/contrib/inference/sghmc.jl @@ -61,13 +61,9 @@ function step( is_first::Val{true}; kwargs... ) - spl.selector.tag != :default && link!(vi, spl) - # Initialize velocity v = zeros(Float64, size(vi[spl])) spl.info[:v] = v - - spl.selector.tag != :default && invlink!(vi, spl) return vi, true end @@ -84,13 +80,12 @@ function step( Turing.DEBUG && @debug "X-> R..." if spl.selector.tag != :default - link!(vi, spl) - model(vi, spl) + model(initlink(vi), spl) end Turing.DEBUG && @debug "recording old variables..." θ, v = vi[spl], spl.info[:v] - _, grad = gradient_logp(θ, vi, model, spl) + _, grad = gradient_logp(θ, link(vi), model, spl) verifygrad(grad) # Implements the update equations from (15) of Chen et al. (2014). @@ -197,7 +192,7 @@ function step( Turing.DEBUG && @debug "X-> R..." if spl.selector.tag != :default - link!(vi, spl) + link!(vi, spl, model) model(vi, spl) end diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index 77a4c70909..5acfeb5dd6 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -321,21 +321,21 @@ function DynamicPPL.assume( elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = rand(dist) - vi[vn] = vectorize(dist, r) + vi[vn, dist] = r setgid!(vi, spl.selector, vn) setorder!(vi, vn, get_num_produce(vi)) else updategid!(vi, vn, spl) - r = vi[vn] + r = vi[vn, dist] end else # vn belongs to other sampler <=> conditionning on vn if haskey(vi, vn) - r = vi[vn] + r = vi[vn, dist] else r = rand(dist) push!(vi, vn, r, dist, Selector(:invalid)) end - lp = logpdf_with_trans(dist, r, istrans(vi, vn)) + lp = logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) acclogp!(vi, lp) end return r, 0 diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 9510d96857..c4c8e86c05 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -3,9 +3,9 @@ module Inference using ..Core using ..Core: logZ using ..Utilities -using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, +using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, set_namedtuple!, islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, - settrans!, _getvns, getdist, CACHERESET, AbstractSampler, + settrans!, getvns, getinitdist, CACHERESET, AbstractSampler, Model, Sampler, SampleFromPrior, SampleFromUniform, Selector, AbstractSamplerState, DefaultContext, PriorContext, LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist, @@ -26,7 +26,7 @@ import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import ..Core: getchunksize, getADbackend import DynamicPPL: get_matching_type, - VarName, _getranges, _getindex, getval, _getvns + VarName, getval, getvns import EllipticalSliceSampling import Random import MCMCChains @@ -272,7 +272,6 @@ function initialize_parameters!( verbose::Bool=false, kwargs... ) - islinked(spl.state.vi, spl) && invlink!(spl.state.vi, spl) # Get `init_theta` if init_theta !== nothing verbose && @info "Using passed-in initial variable values" init_theta diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 27a3b3f541..c649123e61 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -33,11 +33,11 @@ function Sampler(alg::ESS, model::Model, s::Selector) # sanity check vi = VarInfo(model) space = getspace(alg) - vns = _getvns(vi, s, Val(space)) + vns = getvns(vi, s, Val(space)) length(vns) == 1 || error("[ESS] does only support one variable ($(length(vns)) variables specified)") for vn in vns[1] - dist = getdist(vi, vn) + dist = getinitdist(vi, vn) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("[ESS] only supports Gaussian prior distributions") end @@ -102,9 +102,9 @@ end function ESSModel(model::Model, spl::Sampler{<:ESS}) vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) μ = mapreduce(vcat, vns[1]) do vn - dist = getdist(vi, vn) + dist = getinitdist(vi, vn) vectorize(dist, mean(dist)) end @@ -115,7 +115,7 @@ end function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel) spl = model.spl vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) set_flag!(vi, vns[1][1], "del") model.model(vi, spl) return vi[spl] diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 3b8b95c95e..aae222808f 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -102,8 +102,8 @@ end function update_hamiltonian!(spl, model, n) metric = gen_metric(n, spl) - ℓπ = gen_logπ(spl.state.vi, spl, model) - ∂ℓπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model) + ℓπ = gen_logπ(link(spl.state.vi), spl, model) + ∂ℓπ∂θ = gen_∂logπ∂θ(link(spl.state.vi), spl, model) spl.state.h = AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) return spl end @@ -125,24 +125,24 @@ function AbstractMCMC.sample_init!( initialize_parameters!(spl; verbose=verbose, kwargs...) if init_theta !== nothing # Doesn't support dynamic models - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) - else + elseif resume_from === nothing # Samples new values and sets trans to true, then computes the logp model(empty!(spl.state.vi), SampleFromUniform()) - link!(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) while !isfinite(spl.state.z.ℓπ.value) || !isfinite(spl.state.z.ℓπ.gradient) model(empty!(spl.state.vi), SampleFromUniform()) - link!(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) @@ -166,14 +166,14 @@ function AbstractMCMC.sample_init!( spl.alg.n_adapts = 0 end end - + # Convert to transformed space if we're using # non-Gibbs sampling. - if !islinked(spl.state.vi, spl) && spl.selector.tag == :default - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - elseif islinked(spl.state.vi, spl) && spl.selector.tag != :default - invlink!(spl.state.vi, spl) + if spl.selector.tag == :default + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) + else + invlink!(spl.state.vi, spl, model) model(spl.state.vi, spl) end end @@ -416,15 +416,17 @@ function AbstractMCMC.step!( if spl.selector.tag != :default # Transform the space Turing.DEBUG && @debug "X-> R..." - link!(spl.state.vi, spl) - model(spl.state.vi, spl) + updategid!(spl.state.vi, spl) + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) end # Get position and log density before transition - θ_old, log_density_old = spl.state.vi[spl], getlogp(spl.state.vi) + θ_old, θ_old_trans = spl.state.vi[spl], link(spl.state.vi)[spl] + log_density_old = getlogp(spl.state.vi) if spl.selector.tag != :default - update_hamiltonian!(spl, model, length(θ_old)) + update_hamiltonian!(spl, model, length(θ_old_trans)) resize!(spl.state.z.θ, length(θ_old)) - spl.state.z.θ .= θ_old + spl.state.z.θ .= θ_old_trans end # Transition @@ -443,17 +445,19 @@ function AbstractMCMC.step!( # Update `vi` based on acceptance if t.stat.is_accept - spl.state.vi[spl] = t.z.θ + link(spl.state.vi)[spl] = t.z.θ + invlink!(spl.state.vi, spl, model) setlogp!(spl.state.vi, t.stat.log_density) else spl.state.vi[spl] = θ_old + link(spl.state.vi)[spl] = θ_old_trans setlogp!(spl.state.vi, log_density_old) + DynamicPPL.setsynced!(spl.state.vi, true) end # Gibbs component specified cares # Transform the space back Turing.DEBUG && @debug "R -> X..." - spl.selector.tag != :default && invlink!(spl.state.vi, spl) return HamiltonianTransition(spl, t) end @@ -513,14 +517,13 @@ function DynamicPPL.assume( vi, ) Turing.DEBUG && _debug("assuming...") - updategid!(vi, vn, spl) - r = vi[vn] - # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn))) + r = vi[vn, dist] + # acclogp!(vi, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))) # r Turing.DEBUG && _debug("dist = $dist") Turing.DEBUG && _debug("vn = $vn") Turing.DEBUG && _debug("r = $r, typeof(r)=$(typeof(r))") - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function DynamicPPL.dot_assume( @@ -532,9 +535,9 @@ function DynamicPPL.dot_assume( ) @assert length(dist) == size(var, 1) updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] + r = vi[vns, dist] var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.dot_assume( spl::Sampler{<:Hamiltonian}, @@ -544,9 +547,9 @@ function DynamicPPL.dot_assume( vi, ) updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) + r = vi[vns, dists] var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.observe( @@ -606,11 +609,11 @@ function HMCState( vi = spl.state.vi # Link everything if needed. - !islinked(vi, spl) && link!(vi, spl) + link!(vi, spl, model) # Get the initial log pdf and gradient functions. - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) - logπ = gen_logπ(vi, spl, model) + ∂logπ∂θ = gen_∂logπ∂θ(link(vi), spl, model) + logπ = gen_logπ(link(vi), spl, model) # Get the metric type. metricT = getmetricT(spl.alg) @@ -635,7 +638,7 @@ function HMCState( h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. # Unlink everything. - invlink!(vi, spl) + invlink!(vi, spl, model) return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) end diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 764c8959b0..393816678f 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -54,43 +54,6 @@ alg_str(::Sampler{<:MH}) = "MH" # Utility functions # ##################### -""" - set_namedtuple!(vi::VarInfo, nt::NamedTuple) - -Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. -""" -function set_namedtuple!(vi::VarInfo, nt::NamedTuple) - for (n, vals) in pairs(nt) - vns = vi.metadata[n].vns - - n_vns = length(vns) - n_vals = length(vals) - v_isarr = vals isa AbstractArray - - if v_isarr && n_vals == 1 && n_vns > 1 - for (vn, val) in zip(vns, vals[1]) - vi[vn] = val isa AbstractArray ? val : [val] - end - elseif v_isarr && n_vals > 1 && n_vns == 1 - vi[vns[1]] = vals - elseif v_isarr && n_vals == n_vns > 1 - for (vn, val) in zip(vns, vals) - vi[vn] = [val] - end - elseif v_isarr && n_vals == 1 && n_vns == 1 - if vals[1] isa AbstractArray - vi[vns[1]] = vals[1] - else - vi[vns[1]] = [vals[1]] - end - elseif !(v_isarr) - vi[vns[1]] = [vals] - else - error("Cannot assign `NamedTuple` to `VarInfo`") - end - end -end - """ MHLogDensityFunction @@ -136,7 +99,7 @@ The second `NamedTuple` has model symbols as keys and their stored values as val """ function dist_val_tuple(spl::Sampler{<:MH}) vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) dt = _dist_tuple(spl.alg.proposals, vi, vns) vt = _val_tuple(vi, vns) return dt, vt @@ -149,7 +112,7 @@ end isempty(names) === 0 && return :(NamedTuple()) expr = Expr(:tuple) expr.args = Any[ - :($name = reconstruct(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)), + :($name = reconstruct(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)), DynamicPPL.getval(vi, vns.$name))) for name in names] return expr @@ -168,7 +131,7 @@ end :($name = props.$name) else # Otherwise, use the default proposal. - :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)))) + :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)))) end for name in names] return expr end @@ -229,8 +192,8 @@ function DynamicPPL.assume( vi, ) updategid!(vi, vn, spl) - r = vi[vn] - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + r = vi[vn, dist] + return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function DynamicPPL.dot_assume( @@ -244,9 +207,9 @@ function DynamicPPL.dot_assume( getvn = i -> VarName(vn, vn.indexing * "[:,$i]") vns = getvn.(1:size(var, 2)) updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] + r = vi[vns, dist] var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.dot_assume( spl::Sampler{<:MH}, @@ -258,9 +221,9 @@ function DynamicPPL.dot_assume( getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]") vns = getvn.(CartesianIndices(var)) updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) + r = vi[vns, dists] var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.observe( diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index 7320f703db..a91a7d929d 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -26,7 +26,7 @@ include(dir*"/test/test_utils/AllUtils.jl") # run sampler: progress logging should be disabled and # it should return a Chains object - sampler = Sampler(HMC(0.1, 7), gdemo_default) + sampler = Turing.Sampler(HMC(0.1, 7), gdemo_default) chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4) @test chains isa MCMCChains.Chains end @@ -56,10 +56,10 @@ include(dir*"/test/test_utils/AllUtils.jl") chn2_contd = sample(gdemo_default, alg2, 1000; resume_from=chn2) check_gdemo(chn2_contd) - chn3 = sample(gdemo_default, alg3, 1000; save_state=true) + chn3 = sample(gdemo_default, alg3, 5000; save_state=true) check_gdemo(chn3) - chn3_contd = sample(gdemo_default, alg3, 1000; resume_from=chn3) + chn3_contd = sample(gdemo_default, alg3, 5000; resume_from=chn3) check_gdemo(chn3_contd) end @testset "Contexts" begin diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl index 295d12c3d9..18abcd0efa 100644 --- a/test/inference/gibbs.jl +++ b/test/inference/gibbs.jl @@ -134,7 +134,6 @@ include(dir*"/test/test_utils/AllUtils.jl") end end model = imm(randn(100), 1.0); - sample(model, Gibbs(MH(10, :z), HMC(0.01, 4, :m)), 100); sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m)), 100); end end diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index 6273e2ed47..563eb42bff 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -193,4 +193,70 @@ include(dir*"/test/test_utils/AllUtils.jl") @test sample(mwe(), HMC(0.2, 4), 1_000) isa Chains end + + @turing_testset "Stochastic support" begin + n = 10 + m = 10 + k = 4 + theta = randn(n) + b = zeros(k,m) + for i in 1:m + b[1,i] = randn() + for j in 2:k + dd = truncated(Normal(), b[j-1,i], Inf) + b[j,i] = rand(dd) + end + end + + logit = x -> log(x / (1 - x)) + invlogit = x -> exp(x)/(1 + exp(x)) + y = zeros(m,n) + probs = zeros(k,m,n) + for p in 1:n + for i in 1:m + probs[1,i,p] = 1.0 + for j in 1:(k-1) + Q = invlogit(theta[p] - b[j,i]) + probs[j,i,p] -= Q + probs[j+1,i,p] = Q + end + y[i,p] = rand(Categorical(probs[:,i,p])) + end + end + + # Graded Response Model + @model function grm(y, n, m, k, ::Type{TC}=Array{Float64,3}, ::Type{TM}=Array{Float64,2}, ::Type{TV}=Vector{Float64}) where {TC, TM, TV} + b = TM(undef, k, m) + for i in 1:m + b[1,i] ~ Normal(0,1) + for j in 2:k + b[j,i] ~ truncated(Normal(0,1), b[j-1,i], Inf) + end + end + probs = TC(undef, k, m, n) + theta = TV(undef, n) + for p in 1:n + theta[p] ~ Normal(0,1) + for i in 1:m + probs[1,i,p] = 1.0 + for j in 1:(k-1) + Q = invlogit(theta[p] - b[j,i]) + probs[j,i,p] -= Q + probs[j+1,i,p] = Q + end + probs[:,i,p] ./= sum(probs[:,i,p]) + y[i,p] ~ Categorical(probs[:,i,p], check_args=false) + end + end + return theta, b + end; + chn = sample(grm(y, n, m, k), HMC(0.05, 1), 100) + for c in 1:100 + for i in 1:m + for j in 2:k + @test chn["b[$j,$i]"].value[c] > chn["b[$(j-1),$i]"].value[c] + end + end + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 5776651bb6..3df25335b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,30 +14,31 @@ include("test_utils/AllUtils.jl") include("core/ad.jl") include("core/container.jl") end + @testset "inference" begin + @testset "samplers" begin + include("inference/gibbs.jl") + include("inference/is.jl") + include("inference/mh.jl") + include("inference/ess.jl") + include("inference/AdvancedSMC.jl") + include("inference/Inference.jl") - test_adbackends = if VERSION >= v"1.2" - [:forwarddiff, :tracker, :reversediff] - else - [:forwarddiff, :tracker] - end - Turing.setrdcache(false) - for adbackend in test_adbackends - Turing.setadbackend(adbackend) - @testset "inference: $adbackend" begin - @testset "samplers" begin - include("inference/gibbs.jl") - include("inference/hmc.jl") - include("inference/is.jl") - include("inference/mh.jl") - include("inference/ess.jl") - include("inference/AdvancedSMC.jl") - include("inference/Inference.jl") - include("contrib/inference/dynamichmc.jl") + test_adbackends = if VERSION >= v"1.2" + [:forwarddiff, :tracker, :reversediff] + else + [:forwarddiff, :tracker] + end + Turing.setrdcache(false) + for adbackend in test_adbackends + @testset "hmc: $adbackend" begin + Turing.setadbackend(adbackend) + include("inference/hmc.jl") + include("contrib/inference/dynamichmc.jl") + end + @testset "variational algorithms : $adbackend" begin + include("variational/advi.jl") + end end - end - - @testset "variational algorithms : $adbackend" begin - include("variational/advi.jl") end end @testset "variational optimisers" begin diff --git a/test/test_utils/testing_functions.jl b/test/test_utils/testing_functions.jl index a7b22eaf57..5678364900 100644 --- a/test/test_utils/testing_functions.jl +++ b/test/test_utils/testing_functions.jl @@ -19,7 +19,7 @@ function randr(vi::Turing.VarInfo, else if count Turing.checkindex(vn, vi, spl) end Turing.updategid!(vi, vn, spl) - return vi[vn] + return vi[vn, dist] end end From 8a36ed1986df34f423f82ffa497d0b2660e85d37 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Wed, 13 May 2020 12:28:37 +1000 Subject: [PATCH 02/15] test fixes --- src/inference/hmc.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index aae222808f..d3b35d6a43 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -411,12 +411,12 @@ function AbstractMCMC.step!( spl.state.eval_num = 0 Turing.DEBUG && @debug "current ϵ: $ϵ" + updategid!(spl.state.vi, spl) # When a Gibbs component if spl.selector.tag != :default # Transform the space Turing.DEBUG && @debug "X-> R..." - updategid!(spl.state.vi, spl) link!(spl.state.vi, spl, model) model(link(spl.state.vi), spl) end @@ -517,6 +517,7 @@ function DynamicPPL.assume( vi, ) Turing.DEBUG && _debug("assuming...") + updategid!(vi, vn, spl) r = vi[vn, dist] # acclogp!(vi, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))) # r From 5f5420037c2501f016f13da0cbedd00b155d4c2c Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 15 May 2020 22:58:11 +1000 Subject: [PATCH 03/15] use forwarddiff in remaining tests --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3df25335b0..340e87ec6d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,11 +41,13 @@ include("test_utils/AllUtils.jl") end end end + + Turing.setadbackend(:forwarddiff) + @testset "variational optimisers" begin include("variational/optimisers.jl") end - Turing.setadbackend(:forwarddiff) @testset "stdlib" begin include("stdlib/distributions.jl") include("stdlib/RandomMeasures.jl") From 5b4936fd3ae7913ee88a0bf4611d5fadd7eacf26 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 15 May 2020 22:58:40 +1000 Subject: [PATCH 04/15] add remove_del! --- src/inference/AdvancedSMC.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index 5acfeb5dd6..1554f29300 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -131,6 +131,7 @@ function AbstractMCMC.step!( ) # check that we received a real iteration number @assert iteration >= 1 "step! needs to be called with an 'iteration' keyword argument." + remove_del!(spl.state.vi, spl) # grab the weight pc = spl.state.particles @@ -237,6 +238,7 @@ function AbstractMCMC.step!( ) # obtain or create reference particle vi = spl.state.vi + remove_del!(vi, spl) ref_particle = isempty(vi) ? nothing : forkr(Trace(model, spl, vi)) # reset the VarInfo before new sweep From 05f1cf6304402be791a1f4a2a559611fe27571bc Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 17 May 2020 13:54:03 +1000 Subject: [PATCH 05/15] fix dynamic models - single var, different lengths --- src/inference/AdvancedSMC.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index 1554f29300..26338ba03e 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -131,7 +131,6 @@ function AbstractMCMC.step!( ) # check that we received a real iteration number @assert iteration >= 1 "step! needs to be called with an 'iteration' keyword argument." - remove_del!(spl.state.vi, spl) # grab the weight pc = spl.state.particles @@ -238,7 +237,6 @@ function AbstractMCMC.step!( ) # obtain or create reference particle vi = spl.state.vi - remove_del!(vi, spl) ref_particle = isempty(vi) ? nothing : forkr(Trace(model, spl, vi)) # reset the VarInfo before new sweep @@ -321,11 +319,9 @@ function DynamicPPL.assume( r = rand(dist) push!(vi, vn, r, dist, spl) elseif is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + DynamicPPL.removedel!(vi) r = rand(dist) - vi[vn, dist] = r - setgid!(vi, spl.selector, vn) - setorder!(vi, vn, get_num_produce(vi)) + push!(vi, vn, r, dist, spl) else updategid!(vi, vn, spl) r = vi[vn, dist] From f02316c24418045d39bdf14ead272f77b24e9017 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 17 May 2020 14:06:22 +1000 Subject: [PATCH 06/15] squashed commit - MixedVarInfo --- src/contrib/inference/dynamichmc.jl | 2 +- src/core/ad.jl | 10 +++++----- src/core/compat/reversediff.jl | 4 ++-- src/core/compat/zygote.jl | 2 +- src/inference/AdvancedSMC.jl | 4 ++-- src/inference/Inference.jl | 8 +++++++- src/inference/ess.jl | 2 +- src/inference/gibbs.jl | 18 +++--------------- src/inference/hmc.jl | 2 +- src/inference/is.jl | 2 +- src/inference/mh.jl | 6 +++--- test/core/ad.jl | 2 +- 12 files changed, 28 insertions(+), 34 deletions(-) diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 66829b01b5..68484a4f66 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -41,7 +41,7 @@ function DynamicNUTS{AD}(space::Symbol...) where AD DynamicNUTS{AD, space}() end -mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState +mutable struct DynamicNUTSState{V<:AbstractVarInfo, D} <: AbstractSamplerState vi::V draws::Vector{D} end diff --git a/src/core/ad.jl b/src/core/ad.jl index b896fccdda..7d86f76afa 100644 --- a/src/core/ad.jl +++ b/src/core/ad.jl @@ -60,7 +60,7 @@ getADbackend(spl::Sampler) = getADbackend(spl.alg) """ gradient_logp( θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler=SampleFromPrior(), ) @@ -71,7 +71,7 @@ tool is currently active. """ function gradient_logp( θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::Sampler ) @@ -82,7 +82,7 @@ end gradient_logp( backend::ADBackend, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) @@ -93,7 +93,7 @@ specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{ function gradient_logp( ::ForwardDiffAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler=SampleFromPrior(), ) @@ -120,7 +120,7 @@ end function gradient_logp( ::TrackerAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/src/core/compat/reversediff.jl b/src/core/compat/reversediff.jl index 5ccfbbebd5..f3822d35ee 100644 --- a/src/core/compat/reversediff.jl +++ b/src/core/compat/reversediff.jl @@ -17,7 +17,7 @@ end function gradient_logp( backend::ReverseDiffAD{false}, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) @@ -54,7 +54,7 @@ end function gradient_logp( backend::ReverseDiffAD{true}, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/src/core/compat/zygote.jl b/src/core/compat/zygote.jl index 3c56a1922c..dc18fa0f8f 100644 --- a/src/core/compat/zygote.jl +++ b/src/core/compat/zygote.jl @@ -7,7 +7,7 @@ end function gradient_logp( backend::ZygoteAD, θ::AbstractVector{<:Real}, - vi::VarInfo, + vi::AbstractVarInfo, model::Model, sampler::AbstractSampler = SampleFromPrior(), ) diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index 26338ba03e..35729d9b0b 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -69,7 +69,7 @@ SMC(threshold::Real, space::Tuple = ()) = SMC(resample_systematic, threshold, sp SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(Turing.Core.ResampleWithESSThreshold(), space) -mutable struct SMCState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct SMCState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V # The logevidence after aggregating all samples together. average_logevidence :: F @@ -203,7 +203,7 @@ function PG(nparticles::Int, space::Tuple) return PG(nparticles, Turing.Core.ResampleWithESSThreshold(), space) end -mutable struct PGState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct PGState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V # The logevidence after aggregating all samples together. average_logevidence :: F diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index c4c8e86c05..bb8b221d7c 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -74,6 +74,12 @@ getADbackend(::Hamiltonian{AD}) where AD = AD() # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end +getindex(vi::MixedVarInfo, spl::Sampler{<:Hamiltonian}) = vi.tvi[spl] +function setindex!(vi::MixedVarInfo, val, spl::Sampler{<:Hamiltonian}) + vi.tvi[spl] = val + return vi +end + """ mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real) @@ -519,7 +525,7 @@ end """ A blank `AbstractSamplerState` that contains only `VarInfo` information. """ -mutable struct SamplerState{VIType<:VarInfo} <: AbstractSamplerState +mutable struct SamplerState{VIType<:AbstractVarInfo} <: AbstractSamplerState vi :: VIType end diff --git a/src/inference/ess.jl b/src/inference/ess.jl index c649123e61..755a962485 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -25,7 +25,7 @@ struct ESS{space} <: InferenceAlgorithm end ESS() = ESS{()}() ESS(space::Symbol) = ESS{(space,)}() -mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState +mutable struct ESSState{V<:AbstractVarInfo} <: AbstractSamplerState vi::V end diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index 4b82cb934a..9482ced94b 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -42,12 +42,12 @@ function Gibbs(algs::GibbsComponent...) end """ - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} + GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that the `Gibbs` sampler iterates through for each `step!`. """ -mutable struct GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState +mutable struct GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} <: AbstractSamplerState vi::V samplers::S end @@ -81,19 +81,7 @@ function Sampler(alg::Gibbs, model::Model, s::Selector) # add Gibbs to gids for all variables vi = spl.state.vi - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns - - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) - - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end - end - end + DynamicPPL.updategid!(vi, (spl, samplers...)) return spl end diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index d3b35d6a43..b82c9a722f 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -3,7 +3,7 @@ ### mutable struct HMCState{ - TV <: TypedVarInfo, + TV <: AbstractVarInfo, TTraj<:AHMC.AbstractTrajectory, TAdapt<:AHMC.Adaptation.AbstractAdaptor, PhType <: AHMC.PhasePoint diff --git a/src/inference/is.jl b/src/inference/is.jl index a7f515e364..dcd4cad5b6 100644 --- a/src/inference/is.jl +++ b/src/inference/is.jl @@ -31,7 +31,7 @@ struct IS{space} <: InferenceAlgorithm end IS() = IS{()}() -mutable struct ISState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState +mutable struct ISState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerState vi :: V final_logevidence :: F end diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 393816678f..02905b1f59 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -59,7 +59,7 @@ alg_str(::Sampler{<:MH}) = "MH" A log density function for the MH sampler. -This variant uses the `set_namedtuple!` function to update the `VarInfo`. +This variant uses the `set_namedtuple!` function to update the variables. """ struct MHLogDensityFunction{M<:Model,S<:Sampler{<:MH}} <: Function # Relax AMH.DensityModel? model::M @@ -106,7 +106,7 @@ function dist_val_tuple(spl::Sampler{<:MH}) end @generated function _val_tuple( - vi::VarInfo, + vi, vns::NamedTuple{names} ) where {names} isempty(names) === 0 && return :(NamedTuple()) @@ -120,7 +120,7 @@ end @generated function _dist_tuple( props::NamedTuple{propnames}, - vi::VarInfo, + vi, vns::NamedTuple{names} ) where {names,propnames} isempty(names) === 0 && return :(NamedTuple()) diff --git a/test/core/ad.jl b/test/core/ad.jl index 77cb06e0c6..bbe8e46019 100644 --- a/test/core/ad.jl +++ b/test/core/ad.jl @@ -1,5 +1,5 @@ using ForwardDiff, Distributions, FiniteDifferences, Tracker, Random, LinearAlgebra -using PDMats, Zygote +using PDMats, Zygote, ReverseDiff using Turing: Turing, invlink, link, SampleFromPrior, TrackerAD, ZygoteAD using DynamicPPL: getval From b8bb10c4fcc7642c11408806effbeaa53e9ec202 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 21 May 2020 02:22:59 +1000 Subject: [PATCH 07/15] many fixes --- Project.toml | 1 + src/contrib/inference/dynamichmc.jl | 2 +- src/inference/AdvancedSMC.jl | 25 +++++++++++++ src/inference/Inference.jl | 25 ++++++++++++- src/inference/ess.jl | 3 ++ src/inference/gibbs.jl | 52 +++++++++++++++++++++----- src/inference/hmc.jl | 58 ++++++++++++++++++++++------- src/variational/advi.jl | 2 +- test/core/ad.jl | 4 +- test/inference/Inference.jl | 37 ++++++++++++++++++ test/inference/gibbs.jl | 15 +++++--- test/runtests.jl | 10 ++++- test/test_utils/ad_utils.jl | 2 +- test/test_utils/numerical_tests.jl | 13 +++++-- 14 files changed, 211 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index 716609900c..eca8770db9 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.12.0" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 68484a4f66..0cdac493b5 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -97,7 +97,7 @@ function Sampler( s::Selector=Selector() ) # Construct a state, using a default function. - state = DynamicNUTSState(VarInfo(model), []) + state = DynamicNUTSState(DynamicPPL.TypedVarInfo(model), []) # Return a new sampler. return Sampler(alg, Dict{Symbol,Any}(), s, state) diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index 35729d9b0b..ce857fd121 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -22,6 +22,27 @@ struct ParticleTransition{T, F<:AbstractFloat} weight::F end +function Base.promote_type( + ::Type{ParticleTransition{T1, F1}}, + ::Type{ParticleTransition{T2, F2}}, +) where {T1, F1, T2, F2} + return ParticleTransition{ + Union{T1, T2}, + promote_type(F1, F2), + } +end +function Base.convert( + ::Type{ParticleTransition{T, F}}, + t::ParticleTransition, +) where {T, F} + return ParticleTransition{T, F}( + convert(T, t.θ), + convert(F, t.lp), + convert(F, t.le), + convert(F, t.weight), + ) +end + function additional_parameters(::Type{<:ParticleTransition}) return [:lp,:le, :weight] end @@ -214,6 +235,10 @@ function PGState(model::Model) return PGState(vi, 0.0) end +function replace_varinfo(s::PGState, vi::AbstractVarInfo) + return PGState(vi, s.average_logevidence) +end + const CSMC = PG # type alias of PG as Conditional SMC """ diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index bb8b221d7c..189151a369 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -4,7 +4,7 @@ using ..Core using ..Core: logZ using ..Utilities using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, set_namedtuple!, - islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, + islinked_and_trans, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, settrans!, getvns, getinitdist, CACHERESET, AbstractSampler, Model, Sampler, SampleFromPrior, SampleFromUniform, Selector, AbstractSamplerState, DefaultContext, PriorContext, @@ -20,6 +20,7 @@ using DynamicPPL using AbstractMCMC: AbstractModel, AbstractSampler using Bijectors: _debug using DocStringExtensions: TYPEDEF, TYPEDFIELDS +import BangBang import AbstractMCMC import AdvancedHMC; const AHMC = AdvancedHMC @@ -107,6 +108,22 @@ struct Transition{T, F<:AbstractFloat} lp :: F end +function Base.promote_type( + ::Type{Transition{T1, F1}}, + ::Type{Transition{T2, F2}}, +) where {T1, F1, T2, F2} + return Transition{ + Union{T1, T2}, + promote_type(F1, F2), + } +end +function Base.convert( + ::Type{Transition{T, F}}, + t::Transition, +) where {T, F} + return Transition{T, F}(convert(T, t.θ), convert(F, t.lp)) +end + function Transition(spl::Sampler, nt::NamedTuple=NamedTuple()) theta = merge(tonamedtuple(spl.state.vi), nt) lp = getlogp(spl.state.vi) @@ -334,7 +351,8 @@ function flatten_namedtuple(nt::NamedTuple) if length(v) == 1 return [(string(k), v)] else - return mapreduce(vcat, zip(v[1], v[2])) do (vnval, vn) + init = Tuple{String, eltype(v[1])}[] + return mapreduce(vcat, zip(v[1], v[2]); init = init) do (vnval, vn) return collect(FlattenIterator(vn, vnval)) end end @@ -528,6 +546,9 @@ A blank `AbstractSamplerState` that contains only `VarInfo` information. mutable struct SamplerState{VIType<:AbstractVarInfo} <: AbstractSamplerState vi :: VIType end +function replace_varinfo(::SamplerState, vi::AbstractVarInfo) + return SamplerState(vi) +end ####################################### # Concrete algorithm implementations. # diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 755a962485..a790c51cdd 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -28,6 +28,9 @@ ESS(space::Symbol) = ESS{(space,)}() mutable struct ESSState{V<:AbstractVarInfo} <: AbstractSamplerState vi::V end +function replace_varinfo(::ESSState, vi::AbstractVarInfo) + return ESSState(vi) +end function Sampler(alg::ESS, model::Model, s::Selector) # sanity check diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index 9482ced94b..d327e099b1 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -56,6 +56,10 @@ function GibbsState(model::Model, samplers::Tuple{Vararg{Sampler}}) return GibbsState(VarInfo(model), samplers) end +function replace_varinfo(s::GibbsState, vi::AbstractVarInfo) + return GibbsState(vi, s.samplers) +end + function Sampler(alg::Gibbs, model::Model, s::Selector) # sanity check for space space = getspace(alg) @@ -72,16 +76,24 @@ function Sampler(alg::Gibbs, model::Model, s::Selector) selector = Selector(Symbol(typeof(_alg)), rerun) Sampler(_alg, model, selector) end + varinfo = merge(ntuple(i -> samplers[i].state.vi, Val(length(samplers)))...) + samplers = map(samplers) do sampler + Sampler( + sampler.alg, + sampler.info, + sampler.selector, + replace_varinfo(sampler.state, varinfo), + ) + end # create a state variable - state = GibbsState(model, samplers) + state = GibbsState(varinfo, samplers) # create the sampler info = Dict{Symbol, Any}() spl = Sampler(alg, info, s, state) # add Gibbs to gids for all variables - vi = spl.state.vi - DynamicPPL.updategid!(vi, (spl, samplers...)) + DynamicPPL.updategid!(varinfo, (spl, samplers...)) return spl end @@ -100,6 +112,27 @@ struct GibbsTransition{T,F,S<:AbstractVector} transitions::S end +function Base.promote_type( + ::Type{GibbsTransition{T1, F1, S1}}, + ::Type{GibbsTransition{T2, F2, S2}}, +) where {T1, F1, S1, T2, F2, S2} + return GibbsTransition{ + Union{T1, T2}, + promote_type(F1, F2), + promote_type(S1, S2), + } +end +function Base.convert( + ::Type{GibbsTransition{T, F, S}}, + t::GibbsTransition, +) where {T, F, S} + return GibbsTransition{T, F, S}( + convert(T, t.θ), + convert(F, t.lp), + convert(S, t.transitions), + ) +end + function GibbsTransition(spl::Sampler{<:Gibbs}, transitions::AbstractVector) theta = tonamedtuple(spl.state.vi) lp = getlogp(spl.state.vi) @@ -176,25 +209,26 @@ function AbstractMCMC.step!( end # Do not store transitions of subsamplers -function AbstractMCMC.transitions_init( +function AbstractMCMC.transitions( transition::GibbsTransition, ::Model, ::Sampler{<:Gibbs}, N::Integer; kwargs... ) - return Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, N) + ts = Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, 0) + sizehint!(ts, N) + return ts end -function AbstractMCMC.transitions_save!( +function AbstractMCMC.save!!( transitions::Vector{<:Transition}, - iteration::Integer, transition::GibbsTransition, + iteration::Integer, ::Model, ::Sampler{<:Gibbs}, ::Integer; kwargs... ) - transitions[iteration] = Transition(transition.θ, transition.lp) - return + return BangBang.push!!(transitions, Transition(transition.θ, transition.lp)) end diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index b82c9a722f..9324f7f800 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -17,6 +17,18 @@ mutable struct HMCState{ z :: PhType end +function replace_varinfo(s::HMCState, vi::AbstractVarInfo) + return HMCState( + vi, + s.eval_num, + s.i, + s.traj, + s.h, + s.adaptor, + s.z, + ) +end + ########################## # Hamiltonian Transition # ########################## @@ -27,6 +39,27 @@ struct HamiltonianTransition{T, NT<:NamedTuple, F<:AbstractFloat} stat :: NT end +function Base.promote_type( + ::Type{HamiltonianTransition{T1, NT1, F1}}, + ::Type{HamiltonianTransition{T2, NT2, F2}}, +) where {T1, NT1, F1, T2, NT2, F2} + return HamiltonianTransition{ + Union{T1, T2}, + promote_type(NT1, NT2), + promote_type(F1, F2), + } +end +function Base.convert( + ::Type{HamiltonianTransition{T, NT, F}}, + t::HamiltonianTransition, +) where {T, NT, F} + return HamiltonianTransition{T, NT, F}( + convert(T, t.θ), + convert(F, t.lp), + convert(NT, t.stat) + ) +end + function HamiltonianTransition(spl::Sampler{<:Hamiltonian}, t::AHMC.Transition) theta = tonamedtuple(spl.state.vi) lp = getlogp(spl.state.vi) @@ -178,7 +211,7 @@ function AbstractMCMC.sample_init!( end end -function AbstractMCMC.transitions_init( +function AbstractMCMC.transitions( transition, ::AbstractModel, sampler::Sampler{<:Hamiltonian}, @@ -191,28 +224,26 @@ function AbstractMCMC.transitions_init( else n = N end - return Vector{typeof(transition)}(undef, n) + ts = Vector{typeof(transition)}(undef, 0) + sizehint!(ts, n) + return ts end -function AbstractMCMC.transitions_save!( +function AbstractMCMC.save!!( transitions::AbstractVector, - iteration::Integer, transition, + iteration::Integer, ::AbstractModel, sampler::Sampler{<:Hamiltonian}, ::Integer; discard_adapt = true, kwargs... ) - if discard_adapt && isdefined(sampler.alg, :n_adapts) - if iteration > sampler.alg.n_adapts - transitions[iteration - sampler.alg.n_adapts] = transition - end - return + if discard_adapt && isdefined(sampler.alg, :n_adapts) && iteration <= sampler.alg.n_adapts + return transitions + else + return BangBang.push!!(transitions, transition) end - - transitions[iteration] = transition - return end """ @@ -375,7 +406,8 @@ function Sampler( ) info = Dict{Symbol, Any}() # Create an empty sampler state that just holds a typed VarInfo. - initial_state = SamplerState(VarInfo(model)) + varinfo = getspace(alg) === () ? TypedVarInfo(model) : VarInfo(model) + initial_state = SamplerState(varinfo) # Create an initial sampler, to get all the initialization out of the way. initial_spl = Sampler(alg, info, s, initial_state) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index c8e58e6c50..26c861a090 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -62,7 +62,7 @@ Creates a mean-field approximation with multivariate normal as underlying distri meanfield(model::Model) = meanfield(GLOBAL_RNG, model) function meanfield(rng::AbstractRNG, model::Model) # setup - varinfo = Turing.VarInfo(model) + varinfo = DynamicPPL.TypedVarInfo(model) num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym ∈ keys(varinfo.metadata)]) diff --git a/test/core/ad.jl b/test/core/ad.jl index bbe8e46019..21074b777c 100644 --- a/test/core/ad.jl +++ b/test/core/ad.jl @@ -18,8 +18,8 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...) ad_test_f = gdemo_default vi = Turing.VarInfo(ad_test_f) ad_test_f(vi, SampleFromPrior()) - svn = vi.metadata.s.vns[1] - mvn = vi.metadata.m.vns[1] + svn = vi.tvi.metadata.s.vns[1] + mvn = vi.tvi.metadata.m.vns[1] _s = getval(vi, svn)[1] _m = getval(vi, mvn)[1] diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index a91a7d929d..806f77816d 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -7,7 +7,21 @@ using Random dir = splitdir(splitdir(pathof(Turing))[1])[1] include(dir*"/test/test_utils/AllUtils.jl") +struct DynamicDist <: DiscreteMultivariateDistribution end +function Distributions.logpdf(::DynamicDist, dsl_numeric::AbstractVector{Int}) + return sum([log(0.5) * 0.5^i for i in 1:length(dsl_numeric)]) +end +function Random.rand(rng::Random.AbstractRNG, ::DynamicDist) + fst = rand(rng, [0, 1]) + dsl_numeric = [fst] + while rand() < 0.5 + push!(dsl_numeric, rand(rng, [0, 1])) + end + return dsl_numeric +end + @testset "io.jl" begin + #= # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin @@ -114,4 +128,27 @@ include(dir*"/test/test_utils/AllUtils.jl") @test mean(x[:s][1] for x in chains) ≈ 3 atol=0.1 @test mean(x[:m][1] for x in chains) ≈ 0 atol=0.1 end + =# + @testset "stochastic control flow" begin + @model demo(p) = begin + x ~ Categorical(p) + if x == 1 + y ~ Normal() + elseif x == 2 + z ~ Normal() + else + k ~ Normal() + end + end + chain = sample(demo(fill(1/3, 3)), PG(4), 7000) + check_numerical(chain, [:x, :y, :z, :k], [2, 0, 0, 0], atol=0.05, skip_missing=true) + + chain = sample(demo(fill(1/3, 3)), Gibbs(PG(4, :x, :y), PG(4, :z, :k)), 7000) + check_numerical(chain, [:x, :y, :z, :k], [2, 0, 0, 0], atol=0.05, skip_missing=true) + + @model function mwe() + dsl ~ DynamicDist() + end + chain = sample(mwe(), PG(10), 500) + end end diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl index 18abcd0efa..391fc63b81 100644 --- a/test/inference/gibbs.jl +++ b/test/inference/gibbs.jl @@ -40,7 +40,8 @@ include(dir*"/test/test_utils/AllUtils.jl") Random.seed!(100) alg = Gibbs( CSMC(10, :s), - HMC(0.2, 4, :m)) + HMC(0.2, 4, :m), + ) chain = sample(gdemo(1.5, 2.0), alg, 3000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) @@ -48,13 +49,15 @@ include(dir*"/test/test_utils/AllUtils.jl") alg = Gibbs( MH(:s), - HMC(0.2, 4, :m)) + HMC(0.2, 4, :m), + ) chain = sample(gdemo(1.5, 2.0), alg, 5000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) alg = Gibbs( CSMC(15, :s), - ESS(:m)) + ESS(:m), + ) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) @@ -67,7 +70,8 @@ include(dir*"/test/test_utils/AllUtils.jl") Random.seed!(200) gibbs = Gibbs( PG(10, :z1, :z2, :z3, :z4), - HMC(0.15, 3, :mu1, :mu2)) + HMC(0.15, 3, :mu1, :mu2), + ) chain = sample(MoGtest_default, gibbs, 1500) check_MoGtest_default(chain, atol = 0.15) @@ -76,7 +80,8 @@ include(dir*"/test/test_utils/AllUtils.jl") Random.seed!(200) gibbs = Gibbs( PG(10, :z1, :z2, :z3, :z4), - ESS(:mu1), ESS(:mu2)) + ESS(:mu1), ESS(:mu2), + ) chain = sample(MoGtest_default, gibbs, 1500) check_MoGtest_default(chain, atol = 0.15) end diff --git a/test/runtests.jl b/test/runtests.jl index 340e87ec6d..6c0c0b0ae5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,19 +10,24 @@ include("test_utils/AllUtils.jl") # Begin testing. @testset "Turing" begin + #= @testset "core" begin include("core/ad.jl") include("core/container.jl") end + =# + Turing.setadbackend(:forwarddiff) @testset "inference" begin @testset "samplers" begin include("inference/gibbs.jl") + #= include("inference/is.jl") include("inference/mh.jl") include("inference/ess.jl") include("inference/AdvancedSMC.jl") + =# include("inference/Inference.jl") - + #= test_adbackends = if VERSION >= v"1.2" [:forwarddiff, :tracker, :reversediff] else @@ -39,9 +44,11 @@ include("test_utils/AllUtils.jl") include("variational/advi.jl") end end + =# end end + #= Turing.setadbackend(:forwarddiff) @testset "variational optimisers" begin @@ -56,4 +63,5 @@ include("test_utils/AllUtils.jl") @testset "utilities" begin # include("utilities/stan-interface.jl") end + =# end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index bd684918fb..0582082304 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -78,7 +78,7 @@ function test_model_ad(model, f, syms::Vector{Symbol}) vnvals = Vector{Float64}() for i in 1:length(syms) s = syms[i] - vnms[i] = getfield(vi.metadata, s).vns[1] + vnms[i] = getfield(vi.tvi.metadata, s).vns[1] vals = getval(vi, vnms[i]) for i in eachindex(vals) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 88e2b4e9e7..0765a0a27a 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -41,12 +41,19 @@ end function check_numerical(chain, symbols::Vector, exact_vals::Vector; + skip_missing=false, atol=0.2, rtol=0.0) for (sym, val) in zip(symbols, exact_vals) - E = val isa Real ? - mean(chain[sym].value) : - vec(mean(chain[sym].value, dims=[1])) + if skip_missing + E = val isa Real ? + mean(skipmissing(chain[sym].value)) : + vec(mean(skipmissing(chain[sym].value), dims=[1])) + else + E = val isa Real ? + mean(chain[sym].value) : + vec(mean(chain[sym].value, dims=[1])) + end @info (symbol=sym, exact=val, evaluated=E) @test E ≈ val atol=atol rtol=rtol end From 637fa75baa0914e7532f310cfc378eddbc19235d Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 21 May 2020 02:24:36 +1000 Subject: [PATCH 08/15] uncomment tests --- test/inference/Inference.jl | 2 -- test/runtests.jl | 8 -------- 2 files changed, 10 deletions(-) diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index 806f77816d..ecd6a1d376 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -21,7 +21,6 @@ function Random.rand(rng::Random.AbstractRNG, ::DynamicDist) end @testset "io.jl" begin - #= # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin @@ -128,7 +127,6 @@ end @test mean(x[:s][1] for x in chains) ≈ 3 atol=0.1 @test mean(x[:m][1] for x in chains) ≈ 0 atol=0.1 end - =# @testset "stochastic control flow" begin @model demo(p) = begin x ~ Categorical(p) diff --git a/test/runtests.jl b/test/runtests.jl index 6c0c0b0ae5..f119d12224 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,24 +10,19 @@ include("test_utils/AllUtils.jl") # Begin testing. @testset "Turing" begin - #= @testset "core" begin include("core/ad.jl") include("core/container.jl") end - =# Turing.setadbackend(:forwarddiff) @testset "inference" begin @testset "samplers" begin include("inference/gibbs.jl") - #= include("inference/is.jl") include("inference/mh.jl") include("inference/ess.jl") include("inference/AdvancedSMC.jl") - =# include("inference/Inference.jl") - #= test_adbackends = if VERSION >= v"1.2" [:forwarddiff, :tracker, :reversediff] else @@ -44,11 +39,9 @@ include("test_utils/AllUtils.jl") include("variational/advi.jl") end end - =# end end - #= Turing.setadbackend(:forwarddiff) @testset "variational optimisers" begin @@ -63,5 +56,4 @@ include("test_utils/AllUtils.jl") @testset "utilities" begin # include("utilities/stan-interface.jl") end - =# end From 30a1f64375d610be5bb97a32680ccf24f95da76b Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 24 May 2020 02:54:09 +1000 Subject: [PATCH 09/15] add specialize_from kwarg --- src/contrib/inference/AdvancedSMCExtensions.jl | 11 ++++++----- src/inference/AdvancedSMC.jl | 16 ++++++++-------- src/inference/Inference.jl | 6 ++++-- src/inference/ess.jl | 4 ++-- src/inference/gibbs.jl | 8 ++++---- src/inference/hmc.jl | 5 +++-- 6 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/contrib/inference/AdvancedSMCExtensions.jl b/src/contrib/inference/AdvancedSMCExtensions.jl index 7adc20eacd..8516d99d5c 100644 --- a/src/contrib/inference/AdvancedSMCExtensions.jl +++ b/src/contrib/inference/AdvancedSMCExtensions.jl @@ -39,7 +39,7 @@ end PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), ()) -function Sampler(alg::PMMH, model::Model, s::Selector) +function Sampler(alg::PMMH, model::Model, s::Selector; specialize_after=1) info = Dict{Symbol, Any}() spl = Sampler(alg, info, s) @@ -118,7 +118,8 @@ function sample( model::Model, alg::PMMH; save_state=false, # flag for state saving resume_from=nothing, # chain to continue - reuse_spl_n=0 # flag for spl re-using + reuse_spl_n=0, # flag for spl re-using + specialize_after=1 ) spl = Sampler(alg, model) @@ -140,7 +141,7 @@ function sample( model::Model, # Init parameters vi = if resume_from === nothing - vi_ = VarInfo(model) + vi_ = VarInfo(model, specialize_after) else resume_from.info[:vi] end @@ -279,7 +280,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first: VarInfos[nodes_permutation] end -function sample(model::Model, alg::IPMCMC) +function sample(model::Model, alg::IPMCMC; specialize_after=1) spl = Sampler(alg) @@ -295,7 +296,7 @@ function sample(model::Model, alg::IPMCMC) end # Init parameters - vi = empty!(VarInfo(model)) + vi = empty!(VarInfo(model, specialize_after)) VarInfos = Array{VarInfo}(undef, spl.alg.n_nodes) for j in 1:spl.alg.n_nodes VarInfos[j] = deepcopy(vi) diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index ce857fd121..9718bc5280 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -97,16 +97,16 @@ mutable struct SMCState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSampler particles :: ParticleContainer end -function SMCState(model::Model) - vi = VarInfo(model) +function SMCState(model::Model; specialize_after=1) + vi = VarInfo(model, specialize_after) particles = ParticleContainer(Trace[]) return SMCState(vi, 0.0, particles) end -function Sampler(alg::SMC, model::Model, s::Selector) +function Sampler(alg::SMC, model::Model, s::Selector; specialize_after=1) dict = Dict{Symbol, Any}() - state = SMCState(model) + state = SMCState(model; specialize_after=specialize_after) return Sampler(alg, dict, s, state) end @@ -230,8 +230,8 @@ mutable struct PGState{V<:AbstractVarInfo, F<:AbstractFloat} <: AbstractSamplerS average_logevidence :: F end -function PGState(model::Model) - vi = VarInfo(model) +function PGState(model::Model; specialize_after=1) + vi = VarInfo(model, specialize_after) return PGState(vi, 0.0) end @@ -246,9 +246,9 @@ const CSMC = PG # type alias of PG as Conditional SMC Return a `Sampler` object for the PG algorithm. """ -function Sampler(alg::PG, model::Model, s::Selector) +function Sampler(alg::PG, model::Model, s::Selector; specialize_after=1) info = Dict{Symbol, Any}() - state = PGState(model) + state = PGState(model; specialize_after=specialize_after) return Sampler(alg, info, s, state) end diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 189151a369..9c6353c439 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -179,9 +179,10 @@ function AbstractMCMC.sample( model::AbstractModel, alg::InferenceAlgorithm, N::Integer; + specialize_after=1, kwargs... ) - return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...) + return AbstractMCMC.sample(rng, model, Sampler(alg, model; specialize_after=specialize_after), N; kwargs...) end function AbstractMCMC.sample( @@ -239,9 +240,10 @@ function AbstractMCMC.sample( parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, n_chains::Integer; + specialize_after=1, kwargs... ) - return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; + return AbstractMCMC.sample(rng, model, Sampler(alg, model; specialize_after=specialize_after), parallel, N, n_chains; kwargs...) end diff --git a/src/inference/ess.jl b/src/inference/ess.jl index a790c51cdd..238a9f5b79 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -32,9 +32,9 @@ function replace_varinfo(::ESSState, vi::AbstractVarInfo) return ESSState(vi) end -function Sampler(alg::ESS, model::Model, s::Selector) +function Sampler(alg::ESS, model::Model, s::Selector; specialize_after=1) # sanity check - vi = VarInfo(model) + vi = VarInfo(model, specialize_after) space = getspace(alg) vns = getvns(vi, s, Val(space)) length(vns) == 1 || diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index d327e099b1..94b2ba9390 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -52,15 +52,15 @@ mutable struct GibbsState{V<:AbstractVarInfo, S<:Tuple{Vararg{Sampler}}} <: Abst samplers::S end -function GibbsState(model::Model, samplers::Tuple{Vararg{Sampler}}) - return GibbsState(VarInfo(model), samplers) +function GibbsState(model::Model, samplers::Tuple{Vararg{Sampler}}; specialize_after=1) + return GibbsState(VarInfo(model, specialize_after), samplers) end function replace_varinfo(s::GibbsState, vi::AbstractVarInfo) return GibbsState(vi, s.samplers) end -function Sampler(alg::Gibbs, model::Model, s::Selector) +function Sampler(alg::Gibbs, model::Model, s::Selector; specialize_after=1) # sanity check for space space = getspace(alg) # create tuple of samplers @@ -74,7 +74,7 @@ function Sampler(alg::Gibbs, model::Model, s::Selector) end rerun = !(_alg isa MH) || prev_alg isa PG || prev_alg isa ESS selector = Selector(Symbol(typeof(_alg)), rerun) - Sampler(_alg, model, selector) + Sampler(_alg, model, selector; specialize_after=specialize_after) end varinfo = merge(ntuple(i -> samplers[i].state.vi, Val(length(samplers)))...) samplers = map(samplers) do sampler diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 9324f7f800..405b15909c 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -402,11 +402,12 @@ end function Sampler( alg::Union{StaticHamiltonian, AdaptiveHamiltonian}, model::Model, - s::Selector=Selector() + s::Selector=Selector(); + specialize_after=1 ) info = Dict{Symbol, Any}() # Create an empty sampler state that just holds a typed VarInfo. - varinfo = getspace(alg) === () ? TypedVarInfo(model) : VarInfo(model) + varinfo = getspace(alg) === () ? TypedVarInfo(model) : VarInfo(model, specialize_after) initial_state = SamplerState(varinfo) # Create an initial sampler, to get all the initialization out of the way. From e0241bd9e829e7c316829ec0c187edae80f69e1b Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Wed, 27 May 2020 03:17:43 +1000 Subject: [PATCH 10/15] UntypedVarInfo fix --- src/inference/Inference.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 9c6353c439..9d7be02789 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -576,6 +576,7 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) end floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T)) +floatof(::Type{Real}) = Real floatof(::Type) = Real # fallback if type inference failed function get_matching_type( From 90cf9fb0e6d99911c49f20a2993ab5b95ff592e7 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 29 May 2020 00:22:50 +1000 Subject: [PATCH 11/15] some fixes --- src/inference/Inference.jl | 15 +++++++++++++-- src/inference/mh.jl | 7 ++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 9d7be02789..02e35c902b 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -271,10 +271,21 @@ function AbstractMCMC.sample( n_chains::Integer; chain_type=MCMCChains.Chains, progress=PROGRESS[], + specialize_after=1, kwargs... ) - return AbstractMCMC.sample(rng, model, SampleFromPrior(), parallel, N, n_chains; - chain_type=chain_type, progress=progress, kwargs...) + vi = VarInfo(model, specialize_after) + return AbstractMCMC.sample( + rng, + model, + SampleFromPrior(vi), + parallel, + N, + n_chains; + chain_type=chain_type, + progress=progress, + kwargs..., + ) end function AbstractMCMC.sample_init!( diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 02905b1f59..993f90b95f 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -36,13 +36,14 @@ end function Sampler( alg::MH, model::Model, - s::Selector=Selector() + s::Selector=Selector(); + specialize_after=1, ) # Set up info dict. info = Dict{Symbol, Any}() # Set up state struct. - state = SamplerState(VarInfo(model)) + state = SamplerState(VarInfo(model, specialize_after)) # Generate a sampler. return Sampler(alg, info, s, state) @@ -98,7 +99,7 @@ Returns two `NamedTuples`. The first `NamedTuple` has symbols as keys and distri The second `NamedTuple` has model symbols as keys and their stored values as values. """ function dist_val_tuple(spl::Sampler{<:MH}) - vi = spl.state.vi + vi = TypedVarInfo(spl.state.vi) vns = getvns(vi, spl) dt = _dist_tuple(spl.alg.proposals, vi, vns) vt = _val_tuple(vi, vns) From c8a8b99385f3ad4d8d7d0ca2ceb1175495c587fe Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 29 May 2020 18:12:24 +1000 Subject: [PATCH 12/15] use first VarInfo for gibbs --- src/inference/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index 94b2ba9390..46db2cf078 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -76,7 +76,7 @@ function Sampler(alg::Gibbs, model::Model, s::Selector; specialize_after=1) selector = Selector(Symbol(typeof(_alg)), rerun) Sampler(_alg, model, selector; specialize_after=specialize_after) end - varinfo = merge(ntuple(i -> samplers[i].state.vi, Val(length(samplers)))...) + varinfo = samplers[1].state.vi samplers = map(samplers) do sampler Sampler( sampler.alg, From a959a6a5a8e663164bc0c39ba8bb3fcd8f7c4bed Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 30 May 2020 12:30:19 +1000 Subject: [PATCH 13/15] only remove deleted when length is different --- src/inference/AdvancedSMC.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index 9718bc5280..ff4c2ee1bf 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -344,9 +344,15 @@ function DynamicPPL.assume( r = rand(dist) push!(vi, vn, r, dist, spl) elseif is_flagged(vi, vn, "del") - DynamicPPL.removedel!(vi) r = rand(dist) - push!(vi, vn, r, dist, spl) + if length(vi[vn, dist]) == length(r) + vi[vn, dist] = r + unset_flag!(vi, vn, "del") + updategid!(vi, vn, spl) + else + DynamicPPL.removedel!(vi) + push!(vi, vn, r, dist, spl) + end else updategid!(vi, vn, spl) r = vi[vn, dist] From 4e80a48073078b8a88e2aefcbf3c196a08eef08a Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Tue, 2 Jun 2020 11:34:25 +1000 Subject: [PATCH 14/15] comment out automatic TArray conversion --- src/inference/Inference.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 02e35c902b..3c7a767802 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -618,6 +618,7 @@ function get_matching_type( ) where {T, N, TV <: Array{T, N}} return Array{get_matching_type(spl, vi, T), N} end +#= function get_matching_type( spl::Sampler{<:Union{PG, SMC}}, vi, @@ -625,6 +626,7 @@ function get_matching_type( ) where {T, N, TV <: Array{T, N}} return TArray{T, N} end +=# ############## # Utilities # From 79934332b6c983fe7ea39c4c82a347eb92547335 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Wed, 3 Jun 2020 01:09:58 +1000 Subject: [PATCH 15/15] allow `UntypedVarInfo` to be used with pure HMC --- src/inference/hmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 405b15909c..1d827005a2 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -407,7 +407,7 @@ function Sampler( ) info = Dict{Symbol, Any}() # Create an empty sampler state that just holds a typed VarInfo. - varinfo = getspace(alg) === () ? TypedVarInfo(model) : VarInfo(model, specialize_after) + varinfo = getspace(alg) === () && specialize_after > 0 ? TypedVarInfo(model) : VarInfo(model, specialize_after) initial_state = SamplerState(varinfo) # Create an initial sampler, to get all the initialization out of the way.