diff --git a/HISTORY.md b/HISTORY.md index 90864508b..abefc1e36 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -61,9 +61,11 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead. `loadstate` is exported from DynamicPPL. -### Change of default keytype of `pointwise_logdensities` +### Change of output type for `pointwise_logdensities` -The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`. +The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods`, when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects by default instead of dictionaries of matrices. + +If you want the old behaviour, you can pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chain, OrderedDict)`. **Other changes** diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7886ad468..771dd664f 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -292,4 +292,290 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha end end +""" + DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains, + ::Val{whichlogprob}=Val(:both), + ) + +Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where +the log-density of each variable at each sample is stored (rather than its value). + +`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or +`:likelihood`. 
+ +You can pass `OrderedDict` as the third positional argument (`Tout`) to get the result as +an `OrderedDict{VarName, Matrix{Float64}}` instead. + +See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), +[`DynamicPPL.pointwise_prior_logdensities`](@ref). + +# Examples + +```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) +julia> using MCMCChains + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end +demo (generic function with 2 methods) + +julia> # Example observations. + model = demo([1.0, 2.0, 3.0], [4.0]); + +julia> # A chain with 3 iterations. + chain = Chains( + reshape(1.:6., 3, 2), + [:s, :m]; + info=(varname_to_symbol=Dict( + @varname(s) => :s, + @varname(m) => :m, + ),), + ); + +julia> plds = pointwise_logdensities(model, chain) +Chains MCMC chain (3×6×1 Array{Float64, 3}): + +Iterations = 1:1:3 +Number of chains = 1 +Samples per chain = 3 +parameters = s, m, xs[1], xs[2], xs[3], y +[...] 
+ +julia> plds[:s] +2-dimensional AxisArray{Float64,2,...} with axes: + :iter, 1:1:3 + :chain, 1:1 +And data, a 3×1 Matrix{Float64}: + -0.8027754226637804 + -1.3822169643436162 + -2.0986122886681096 + +julia> # The above is the same as: + logpdf.(InverseGamma(2, 3), chain[:s]) +3×1 Matrix{Float64}: + -0.8027754226637804 + -1.3822169643436162 + -2.0986122886681096 +``` + +julia> # Alternatively: + plds_dict = pointwise_logdensities(model, chain, OrderedDict) +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] +""" +function DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains, + ::Val{whichlogprob}=Val(:both), +) where {whichlogprob,Tout} + vi = DynamicPPL.VarInfo(model) + acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() + accname = DynamicPPL.accumulator_name(acc) + vi = DynamicPPL.setaccs!!(vi, (acc,)) + parameter_only_chain = MCMCChains.get_sections(chain, :parameters) + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + pointwise_logps = map(iters) do (sample_idx, chain_idx) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Re-evaluate the model + _, vi = DynamicPPL.init!!( + model, + vi, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + DynamicPPL.getacc(vi, Val(accname)).logps + end + + # pointwise_logps is a matrix of OrderedDicts + all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() + for d in pointwise_logps + union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) + end + # this is a 3D array: (iterations, variables, chains) + new_data = [ + get(pointwise_logps[iter, chain], 
k, missing) for + iter in 1:size(pointwise_logps, 1), k in all_keys, + chain in 1:size(pointwise_logps, 2) + ] + + if Tout == MCMCChains.Chains + return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) + elseif Tout <: AbstractDict + return Tout{DynamicPPL.VarName,Matrix{Float64}}( + k => new_data[:, i, :] for (i, k) in enumerate(all_keys) + ) + end +end + +""" + DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains + ) + +Compute the pointwise log-likelihoods of the model given the chain. This is the same as +`pointwise_logdensities(model, chain)`, but only including the likelihood terms. + +See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). +""" +function DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains +) where {Tout} + return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood)) +end + +""" + DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains + ) + +Compute the pointwise log-prior-densities of the model given the chain. This is the same as +`pointwise_logdensities(model, chain)`, but only including the prior terms. + +See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref). +""" +function DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains +) where {Tout} + return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior)) +end + +""" + logjoint(model::Model, chain::MCMCChains.Chains) + +Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. 
+ +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. + chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); + +julia> logjoint(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -5.440428709758045 + -5.440428709758045 + -5.440428709758045 +``` +""" +function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) + ) + DynamicPPL.logjoint(model, argvals_dict) + end +end + +""" + loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) + +Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. 
+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); + +julia> loglikelihood(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -2.1447298858494 + -2.1447298858494 + -2.1447298858494 +``` +""" +function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) + ) + DynamicPPL.loglikelihood(model, argvals_dict) + end +end + +""" + logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) + +Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. 
+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); + +julia> logprior(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -3.2956988239086447 + -3.2956988239086447 + -3.2956988239086447 +``` +""" +function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) + ) + DynamicPPL.logprior(model, argvals_dict) + end +end + end diff --git a/src/model.jl b/src/model.jl index 6c7e8de94..d6682416b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1058,42 +1058,6 @@ function logjoint(model::Model, varinfo::AbstractVarInfo) return getlogjoint(last(evaluate!!(model, varinfo))) end -""" - logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> logjoint(demo_model([1., 2.]), chain); -``` -""" -function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - logjoint(model, argvals_dict) - end -end - """ logprior(model::Model, varinfo::AbstractVarInfo) @@ -1116,42 +1080,6 @@ function logprior(model::Model, varinfo::AbstractVarInfo) return getlogprior(last(evaluate!!(model, varinfo))) end -""" - logprior(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> logprior(demo_model([1., 2.]), chain); -``` -""" -function logprior(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - logprior(model, argvals_dict) - end -end - """ loglikelihood(model::Model, varinfo::AbstractVarInfo) @@ -1170,42 +1098,6 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getloglikelihood(last(evaluate!!(model, varinfo))) end -""" - loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> loglikelihood(demo_model([1., 2.]), chain); -``` -""" -function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - loglikelihood(model, argvals_dict) - end -end - # Implemented & documented in DynamicPPLMCMCChainsExt function predict end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 47ca62530..848ecb1f0 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,35 +1,22 @@ """ - PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator + PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator An accumulator that stores the log-probabilities of each variable in a model. -Internally this accumulator stores the log-probabilities in a dictionary, where -the keys are the variable names and the values are vectors of -log-probabilities. Each element in a vector corresponds to one execution of the -model. +Internally this accumulator stores the log-probabilities in a dictionary, where the keys are +the variable names and the values are log-probabilities. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies -which log-probabilities to store in the accumulator. 
`KeyType` is the type by which variable -names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary -used internally to store the log-probabilities, by default -`OrderedDict{KeyType, Vector{LogProbType}}`. +which log-probabilities to store in the accumulator. """ -struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: - AbstractAccumulator - logps::D -end - -function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) -end +struct PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator + logps::OrderedDict{VarName,LogProbType} -function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob,VarName}() -end - -function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} - logps = OrderedDict{KeyType,Vector{LogProbType}}() - return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) + function PointwiseLogProbAccumulator{whichlogprob}( + d::OrderedDict{VarName,LogProbType}=OrderedDict{VarName,LogProbType}() + ) where {whichlogprob} + return new{whichlogprob}(d) + end end function Base.:(==)( @@ -42,28 +29,14 @@ function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichl return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) end -function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) - logps = acc.logps - # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
- T = last(fieldtypes(eltype(logps))) - logpvec = get!(logps, vn, T()) - return push!(logpvec, logp) -end - -function Base.push!( - acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp -) where {whichlogprob} - return push!(acc, string(vn), logp) -end - function accumulator_name( ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} ) where {whichlogprob} return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) +function _zero(::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}() end reset(acc::PointwiseLogProbAccumulator) = _zero(acc) split(acc::PointwiseLogProbAccumulator) = _zero(acc) @@ -71,21 +44,14 @@ function combine( acc::PointwiseLogProbAccumulator{whichlogprob}, acc2::PointwiseLogProbAccumulator{whichlogprob}, ) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(+, acc.logps, acc2.logps)) end function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right ) where {whichlogprob} if whichlogprob == :both || whichlogprob == :prior - # T is the element type of the vectors that are the values of `acc.logps`. Usually - # it's LogProbType. - T = eltype(last(fieldtypes(eltype(acc.logps)))) - # Note that in only accumulating LogPrior, we effectively ignore logjac - # (since we want to return log densities that don't depend on the - # linking status of the VarInfo). 
- subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) - push!(acc, vn, subacc.logp) + acc.logps[vn] = logpdf(right, val) end return acc end @@ -99,172 +65,11 @@ function accumulate_observe!!( return acc end if whichlogprob == :both || whichlogprob == :likelihood - # T is the element type of the vectors that are the values of `acc.logps`. Usually - # it's LogProbType. - T = eltype(last(fieldtypes(eltype(acc.logps)))) - subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) - push!(acc, vn, subacc.logp) + acc.logps[vn] = loglikelihood(right, left) end return acc end -""" - pointwise_logdensities( - model::Model, - chain::Chains, - keytype=String, - ::Val{whichlogprob}=Val(:both), - ) - -Runs `model` on each sample in `chain` returning a `OrderedDict{VarName, Matrix{Float64}}` -with keys being model variables and values being matrices of shape -`(num_chains, num_samples)`. - -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported, with `VarName` being the default. -`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or -`:likelihood`. - -See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). - -# Notes -Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -*observation*) statements can be implemented in three ways: -1. using a `for` loop: -```julia -for i in eachindex(y) - y[i] ~ Normal(μ, σ) -end -``` -2. using `.~`: -```julia -y .~ Normal(μ, σ) -``` -3. using `MvNormal`: -```julia -y ~ MvNormal(fill(μ, n), σ^2 * I) -``` - -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. 
- -This is important to keep in mind, in particular if the computation is used -for downstream computations. - -# Examples -## From chain -```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) -julia> using MCMCChains - -julia> @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - y ~ Normal(m, √s) - end -demo (generic function with 2 methods) - -julia> # Example observations. - model = demo([1.0, 2.0, 3.0], [4.0]); - -julia> # A chain with 3 iterations. - chain = Chains( - reshape(1.:6., 3, 2), - [:s, :m] - ); - -julia> pointwise_logdensities(model, chain) -OrderedDict{VarName, Matrix{Float64}} with 6 entries: - s => [-0.802775; -1.38222; -2.09861;;] - m => [-8.91894; -7.51551; -7.46824;;] - xs[1] => [-5.41894; -5.26551; -5.63491;;] - xs[2] => [-2.91894; -3.51551; -4.13491;;] - xs[3] => [-1.41894; -2.26551; -2.96824;;] - y => [-0.918939; -1.51551; -2.13491;;] - -julia> pointwise_logdensities(model, chain, String) -OrderedDict{String, Matrix{Float64}} with 6 entries: - "s" => [-0.802775; -1.38222; -2.09861;;] - "m" => [-8.91894; -7.51551; -7.46824;;] - "xs[1]" => [-5.41894; -5.26551; -5.63491;;] - "xs[2]" => [-2.91894; -3.51551; -4.13491;;] - "xs[3]" => [-1.41894; -2.26551; -2.96824;;] - "y" => [-0.918939; -1.51551; -2.13491;;] - -julia> pointwise_logdensities(model, chain, VarName) -OrderedDict{VarName, Matrix{Float64}} with 6 entries: - s => [-0.802775; -1.38222; -2.09861;;] - m => [-8.91894; -7.51551; -7.46824;;] - xs[1] => [-5.41894; -5.26551; -5.63491;;] - xs[2] => [-2.91894; -3.51551; -4.13491;;] - xs[3] => [-1.41894; -2.26551; -2.96824;;] - y => [-0.918939; -1.51551; -2.13491;;] -``` - -## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. 
- -```jldoctest; setup = :(using Distributions) -julia> @model function demo(x) - x .~ Normal() - end; - -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) --1.4189385332046727 - -julia> m = demo([1.0; 1.0]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) -``` -""" -function pointwise_logdensities( - model::Model, chain, ::Type{KeyType}=VarName, ::Val{whichlogprob}=Val(:both) -) where {KeyType,whichlogprob} - # Get the data by executing the model once - vi = VarInfo(model) - - # This accumulator tracks the pointwise log-probabilities in a single iteration. - AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} - vi = setaccs!!(vi, (AccType(),)) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - - # Maintain a separate accumulator that isn't tied to a VarInfo but rather - # tracks _all_ iterations. - all_logps = AccType() - for (sample_idx, chain_idx) in iters - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - - # Execute model - vi = last(evaluate!!(model, vi)) - - # Get the log-probabilities - this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps - - # Merge into main acc - for (varname, this_lp) in this_iter_logps - # Because `this_lp` is obtained from one model execution, it should only - # contain one variable, hence `only()`. 
- push!(all_logps, varname, only(this_lp)) - end - end - - niters = size(chain, 1) - nchains = size(chain, 3) - logdensities = OrderedDict( - varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps - ) - return logdensities -end - function pointwise_logdensities( model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) ) where {whichlogprob} @@ -274,38 +79,10 @@ function pointwise_logdensities( return getacc(varinfo, Val(accumulator_name(AccType))).logps end -""" - pointwise_loglikelihoods(model, chain[, keytype]) - -Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain)`, but only -including the likelihood terms. - -See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). -""" -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=VarName) where {T} - return pointwise_logdensities(model, chain, T, Val(:likelihood)) -end - function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) return pointwise_logdensities(model, varinfo, Val(:likelihood)) end -""" - pointwise_prior_logdensities(model, chain[, keytype]) - -Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain)`, but only -including the prior terms. - -See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
-""" -function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=VarName -) where {T} - return pointwise_logdensities(model, chain, T, Val(:prior)) -end - function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d6c0cbcad..2ba25f142 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -510,7 +510,7 @@ function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} end """ - logjoint(model::Model, θ) + logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log joint probability of variables `θ` for the probabilistic `model`. @@ -539,10 +539,11 @@ julia> # Truth. -9902.33787706641 ``` """ -logjoint(model::Model, θ) = logjoint(model, SimpleVarInfo(θ)) +logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) = + logjoint(model, SimpleVarInfo(θ)) """ - logprior(model::Model, θ) + logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log prior probability of variables `θ` for the probabilistic `model`. @@ -571,10 +572,11 @@ julia> # Truth. -5000.918938533205 ``` """ -logprior(model::Model, θ) = logprior(model, SimpleVarInfo(θ)) +logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) = + logprior(model, SimpleVarInfo(θ)) """ - loglikelihood(model::Model, θ) + loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log likelihood of variables `θ` for the probabilistic `model`. @@ -603,7 +605,8 @@ julia> # Truth. -4901.418938533205 ``` """ -Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) +Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) = + loglikelihood(model, SimpleVarInfo(θ)) # Allow usage of `NamedBijector` too. 
function link!!( diff --git a/src/varinfo.jl b/src/varinfo.jl index 417766653..734bf3db5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1783,72 +1783,6 @@ function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) return missing_keys end -""" - setval!(vi::VarInfo, x) - setval!(vi::VarInfo, values, keys) - setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - -Set the values in `vi` to the provided values and leave those which are not present in -`x` or `chains` unchanged. - -## Notes -This is rather limited for two reasons: -1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood, - and therefore suffers from the same limitations as [`subsumes_string`](@ref). -2. It will set every `vn` present in `keys`. It will NOT however - set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`, - representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0]))` will - be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0]))`. 
- -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1]` - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 -``` -""" -setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) -setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) -function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) -end - -function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - set_transformed!!(vi, false, vn) - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index aac59380c..be5f20010 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,4 +1,4 @@ -@testset "logdensities_likelihoods.jl" begin +@testset "pointwise_logdensities.jl" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -39,32 +39,35 @@ end @testset "pointwise_logdensities chain" begin - # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested 
extensively, - # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the version on `Chains`. model = DynamicPPL.TestUtils.demo_assume_index_observe() - # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced - # an impl of this for containers. - # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. vns = DynamicPPL.TestUtils.varnames(model) # Get some random `NamedTuple` samples from the prior. num_iters = 3 vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] # Concatenate the vector representations and create a `Chains` from it. vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + chain = Chains( + permutedims(vals_arr), + map(Symbol, vns); + info=(varname_to_symbol=Dict(vn => Symbol(vn) for vn in vns),), + ) # Compute the different pointwise logdensities. logjoints_pointwise = pointwise_logdensities(model, chain) logpriors_pointwise = pointwise_prior_logdensities(model, chain) loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + # Check output type + @test logjoints_pointwise isa MCMCChains.Chains + @test logpriors_pointwise isa MCMCChains.Chains + @test loglikelihoods_pointwise isa MCMCChains.Chains + # Check that they contain the correct variables. 
- @test all(vn in keys(logjoints_pointwise) for vn in vns) - @test all(vn in keys(logpriors_pointwise) for vn in vns) - @test !any(Base.Fix1(subsumes, @varname(x)), keys(logpriors_pointwise)) - @test !any(vn in keys(loglikelihoods_pointwise) for vn in vns) - @test all(Base.Fix1(subsumes, @varname(x)), keys(loglikelihoods_pointwise)) + @test all(Symbol(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(Symbol(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix1(startswith, "x"), String.(keys(logpriors_pointwise))) + @test !any(Symbol(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix1(startswith, "x"), String.(keys(loglikelihoods_pointwise))) # Get the sum of the logjoints for each of the iterations. logjoints = [ diff --git a/test/varinfo.jl b/test/varinfo.jl index 5b541e1dd..6b31fbe91 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -317,140 +317,13 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval!" begin - @model function testmodel(x) - n = length(x) - s ~ truncated(Normal(); lower=0) - m ~ MvNormal(zeros(n), I) - return x ~ MvNormal(m, s^2 * I) - end - - @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV} - n = length(x) - s ~ truncated(Normal(); lower=0) - - m = TV(undef, n) - for i in eachindex(m) - m[i] ~ Normal() - end - - for i in eachindex(x) - x[i] ~ Normal(m[i], s) - end - end - - x = randn(5) - model_mv = testmodel(x) - model_uv = testmodel_univariate(x) - - for model in [model_uv, model_mv] - m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) - s_vns = @varname(s) - - vi_typed = DynamicPPL.typed_varinfo(model) - vi_untyped = DynamicPPL.untyped_varinfo(model) - vi_vnv = DynamicPPL.untyped_vector_varinfo(model) - vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) - - model_name = model == model_uv ? 
"univariate" : "multivariate" - @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ - vi_untyped, vi_typed, vi_vnv, vi_vnv_typed - ] - Random.seed!(23) - vicopy = deepcopy(vi) - - ### `setval` ### - # TODO(mhauru) The interface here seems inconsistent between Metadata and - # VarNamedVector. I'm lazy to fix it though, because I think we need to - # rework it soon anyway. - if vi in [vi_vnv, vi_vnv_typed] - DynamicPPL.setval!(vicopy, zeros(5), m_vns) - else - DynamicPPL.setval!(vicopy, (m=zeros(5),)) - end - # Setting `m` fails for univariate due to limitations of `setval!`. - # See docstring of `setval!` for more info. - if model == model_uv && vi in [vi_untyped, vi_typed] - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] == vi[s_vns] - - # Ordering is NOT preserved => fails for multivariate model. - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] == vi[s_vns] - - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - DynamicPPL.setval!(vicopy, (s=42,)) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] == 42 - end - end - - # https://github.com/TuringLang/DynamicPPL.jl/issues/250 - @model function demo() - return x ~ filldist(MvNormal([1, 100], I), 2) - end - - vi = VarInfo(demo()) - vals_prev = vi.metadata.x.vals - ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] - DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals - end - - @testset "setval! on chain" begin - # Define a helper function - """ - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - - Test `setval!` on `model` and `chain`. - - Worth noting that this only supports models containing symbols of the forms - `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. 
- """ - function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - θ_old = var_info[:] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[:] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end - end - + @testset "returned on MCMCChains.Chains" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS chain = make_chain_from_prior(model, 10) # A simple way of checking that the computation is determinstic: run twice and compare. res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) end end