diff --git a/HISTORY.md b/HISTORY.md
index 0a673decc..0188c4fce 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,6 +1,83 @@
 # 0.40.0

-[...]
+## Breaking changes
+
+**DynamicPPL 0.37**
+
+Turing.jl v0.40 updates DynamicPPL compatibility to 0.37.
+The summary of the changes provided here is intended for end-users of Turing.
+If you are a package developer, or would otherwise like to understand these changes in depth, please see [the DynamicPPL changelog](https://github.com/TuringLang/DynamicPPL.jl/blob/main/HISTORY.md#0370).
+
+  - **`@submodel`** is now completely removed; please use `to_submodel` instead.
+
+  - **Prior and likelihood calculations** are now completely separated in Turing. Previously, the log-density was accumulated in a single field, so there was no clear way to separate the prior and likelihood components.
+
+    + **`@addlogprob! f`**, where `f` is a float, now adds to the likelihood by default.
+    + You can instead use **`@addlogprob! (; logprior=x, loglikelihood=y)`** to control which log-density component to add to.
+
+    This means that `PriorContext` and `LikelihoodContext` are no longer needed, and they have now been removed.
+  - The special **`__context__`** variable has been removed. If you still need to access the evaluation context, it is now available as `__model__.context`.
+
+**Log-density in chains**
+
+When sampling from a Turing model, the resulting `MCMCChains.Chains` object now contains not only the log-joint (accessible via `chain[:lp]`) but also the log-prior and log-likelihood (`chain[:logprior]` and `chain[:loglikelihood]`, respectively).
+
+These values now correspond to the log-density of the sampled variables exactly as per the model definition / user parameterisation, and thus ignore any linking (transformation to unconstrained space).
+For example, if the model is `@model f() = x ~ LogNormal()`, `chain[:lp]` always contains the value of `logpdf(LogNormal(), x)` for each sampled value of `x`.
+Previously, these values could be incorrect if linking had occurred: some samplers would return `logpdf(Normal(), log(x))`, i.e., the log-density with respect to the transformed distribution.
+
+**Gibbs sampler**
+
+When using Turing's Gibbs sampler, e.g. `Gibbs(:x => MH(), :y => HMC(0.1, 20))`, the conditioned variables (for example `y` during the MH step, or `x` during the HMC step) are treated as true observations.
+Thus the log-density associated with them is added to the likelihood.
+Previously, these were effectively added to the prior (in the sense that if `LikelihoodContext` was used, they would be ignored).
+This is unlikely to affect users, but we mention it here to be explicit.
+This change only affects the log probabilities as the Gibbs component samplers see them; the resulting chain will include the usual log prior, likelihood, and joint, as described above.
+
+**Particle Gibbs**
+
+Previously, only 'true' observations (i.e., `x ~ dist` where `x` is a model argument or conditioned upon) would trigger resampling of particles.
+Specifically, there were two cases where resampling would not be triggered:
+
+  - Calls to `@addlogprob!`
+  - Gibbs-conditioned variables: e.g. `y` in `Gibbs(:x => PG(20), :y => MH())`
+
+Turing 0.40 changes this such that both of the above cause resampling.
+(The second case follows from the changes to the Gibbs sampler; see above.)
+
+This release also fixes a bug where, if the model ended with one of these statements, its contribution to the particle weight was ignored, leading to incorrect results.
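+
+To illustrate, here is a minimal sketch of such a model (the model and its likelihood term are made up purely for this example).
+The float form of `@addlogprob!` adds to the likelihood, so the final statement below is treated like an observation by PG:
+
+```julia
+using Turing
+
+@model function g()
+    x ~ Normal()
+    # Final statement: a likelihood contribution via @addlogprob!
+    @addlogprob! -abs2(x)
+end
+
+sample(g(), PG(20), 100)
+```
+
+On v0.39 and earlier, this final `@addlogprob!` statement neither triggered a resampling step nor (because of the bug just described) contributed to the particle weight; on v0.40 it does both.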
+
+The changes above also mean that certain models that previously worked with PG-within-Gibbs may now error.
+Specifically, this is likely to happen when the dimension of the model is variable.
+For example:
+
+```julia
+@model function f()
+    x ~ Bernoulli()
+    if x
+        y1 ~ Normal()
+    else
+        y1 ~ Normal()
+        y2 ~ Normal()
+    end
+    # (some likelihood term...)
+end
+sample(f(), Gibbs(:x => PG(20), (:y1, :y2) => MH()), 100)
+```
+
+This sampler can no longer be used for this model because the number of observations differs depending on which branch is taken.
+To use PG-within-Gibbs, the number of observations that the PG component sampler sees must be constant.
+Thus, for example, this will still work if `x`, `y1`, and `y2` are grouped together under the PG component sampler.
+
+If you absolutely require the old behaviour, we recommend using Turing.jl v0.39, but also thinking very carefully about what the expected behaviour of the model is, and checking that Turing is sampling from it correctly (note that the behaviour on v0.39 may in general be incorrect, because Gibbs-conditioned variables did not trigger resampling).
+We would also welcome any GitHub issues highlighting such problems.
+Our support for dynamic models is incomplete and is liable to undergo further changes.
+
+## Other changes
+
+  - Sampling using `Prior()` should now be about twice as fast because we now avoid evaluating the model twice on every iteration.
+  - `Turing.Inference.Transition` now has different fields.
+    If `t isa Turing.Inference.Transition`, `t.stat` is always a NamedTuple, not `nothing` (if it genuinely has no information, it is an empty NamedTuple).
+    Furthermore, `t.lp` has now been split up into `t.logprior` and `t.loglikelihood` (see also the 'Log-density in chains' section above).

 # 0.39.9
diff --git a/Project.toml b/Project.toml
index f159147ff..b0504e367 100644
--- a/Project.toml
+++ b/Project.toml
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.36.3"
+DynamicPPL = "0.37"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.3"
diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl
index 5718e3855..2c4bd0898 100644
--- a/ext/TuringDynamicHMCExt.jl
+++ b/ext/TuringDynamicHMCExt.jl
@@ -58,15 +58,12 @@ function DynamicPPL.initialstep(
     # Ensure that initial sample is in unconstrained space.
     if !DynamicPPL.islinked(vi)
         vi = DynamicPPL.link!!(vi, model)
-        vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
+        vi = last(DynamicPPL.evaluate!!(model, vi))
     end

     # Define log-density function.
     ℓ = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-        adtype=spl.alg.adtype,
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
     )

     # Perform initial step.
@@ -76,12 +73,9 @@
     steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
     Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
-
     # Create first sample and state.
- sample = Turing.Inference.Transition(model, vi) + vi = DynamicPPL.unflatten(vi, Q.q) + sample = Turing.Inference.Transition(model, vi, nothing) state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ) return sample, state @@ -100,12 +94,9 @@ function AbstractMCMC.step( steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize) Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache) - # Update the variables. - vi = DynamicPPL.unflatten(vi, Q.q) - vi = DynamicPPL.setlogp!!(vi, Q.ℓq) - # Create next sample and state. - sample = Turing.Inference.Transition(model, vi) + vi = DynamicPPL.unflatten(vi, Q.q) + sample = Turing.Inference.Transition(model, vi, nothing) newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize) return sample, newstate diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index ad8fdad44..0f755988e 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -34,8 +34,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _mle_optimize(model, init_vals, optimizer, options; kwargs...) @@ -57,8 +56,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -74,8 +72,8 @@ function Optim.optimize( end function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) + return _optimize(f, args...; kwargs...) end """ @@ -104,8 +102,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) @@ -127,8 +124,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -144,9 +140,10 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint) + return _optimize(f, args...; kwargs...) end + """ _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...) 
@@ -166,7 +163,9 @@ function _optimize( # whether initialisation is really necessary at all vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals) vi = DynamicPPL.link(vi, f.ldf.model) - f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype) + f = Optimisation.OptimLogDensity( + f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype + ) init_vals = DynamicPPL.getparams(f.ldf) # Optimize! @@ -184,7 +183,7 @@ function _optimize( vi = f.ldf.varinfo vi_optimum = DynamicPPL.unflatten(vi, M.minimizer) logdensity_optimum = Optimisation.OptimLogDensity( - f.ldf.model, vi_optimum, f.ldf.context + f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype ) vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) diff --git a/src/Turing.jl b/src/Turing.jl index 1ff231017..0cdbe2458 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -71,7 +71,6 @@ using DynamicPPL: unfix, prefix, conditioned, - @submodel, to_submodel, LogDensityFunction, @addlogprob! @@ -81,7 +80,6 @@ using OrderedCollections: OrderedDict # Turing essentials - modelling macros and inference algorithms export # DEPRECATED - @submodel, generated_quantities, # Modelling - AbstractPPL and DynamicPPL @model, diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0951026aa..d6e9afcbb 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -17,7 +17,8 @@ using DynamicPPL: setindex!!, push!!, setlogp!!, - getlogp, + getlogjoint, + getlogjoint_internal, VarName, getsym, getdist, @@ -26,9 +27,6 @@ using DynamicPPL: SampleFromPrior, SampleFromUniform, DefaultContext, - PriorContext, - LikelihoodContext, - SamplingContext, set_flag!, unset_flag! using Distributions, Libtask, Bijectors @@ -125,37 +123,110 @@ end ###################### # Default Transition # ###################### -# Default -getstats(t) = nothing +getstats(::Any) = NamedTuple() +getstats(nt::NamedTuple) = nt -abstract type AbstractTransition end - -struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition +struct Transition{T,F<:AbstractFloat,N<:NamedTuple} θ::T - lp::F # TODO: merge `lp` with `stat` - stat::S -end + logprior::F + loglikelihood::F + stat::N + + """ + Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true) + + Construct a new `Turing.Inference.Transition` object using the outputs of a + sampler step. + + Here, `vi` represents a VarInfo _for which the appropriate parameters have + already been set_. However, the accumulators (e.g. logp) may in general + have junk contents. The role of this method is to re-evaluate `model` and + thus set the accumulators to the correct values. + + `stats` is any object on which `Turing.Inference.getstats` can be called to + return a NamedTuple of statistics. This could be, for example, the transition + returned by an (unwrapped) external sampler. Or alternatively, it could + simply be a NamedTuple itself (for which `getstats` acts as the identity). + + By default, the model is re-evaluated in order to obtain values of: + - the values of the parameters as per user parameterisation (`vals_as_in_model`) + - the various components of the log joint probability (`logprior`, `loglikelihood`) + that are guaranteed to be correct. + + If you **know** for a fact that the VarInfo `vi` already contains this information, + then you can set `reevaluate=false` to skip the re-evaluation step. + + !!! 
warning + Note that in general this is unsafe and may lead to wrong results. + + If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that + the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`, + and `LogLikelihoodAccumulator` set up with the correct values. Note that the + `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it + must be set up to track `x := y` statements. + """ + function Transition( + model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true + ) + # Avoid mutating vi as it may be used later e.g. when constructing + # sampler states. + vi = deepcopy(vi) + if reevaluate + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + ), + ) + _, vi = DynamicPPL.evaluate!!(model, vi) + end -Transition(θ, lp) = Transition(θ, lp, nothing) -function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t) - θ = getparams(model, vi) - lp = getlogp(vi) - return Transition(θ, lp, getstats(t)) -end + # Extract all the information we need + vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + logprior = DynamicPPL.getlogprior(vi) + loglikelihood = DynamicPPL.getloglikelihood(vi) -function metadata(t::Transition) - stat = t.stat - if stat === nothing - return (lp=t.lp,) - else - return merge((lp=t.lp,), stat) + # Get additional statistics + stats = getstats(stats) + return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}( + vals_as_in_model, logprior, loglikelihood, stats + ) end -end -DynamicPPL.getlogp(t::Transition) = t.lp + function Transition( + model::DynamicPPL.Model, + untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}, + stats; + reevaluate=true, + ) + # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's + # much faster to convert it to a typed varinfo first, hence this method. + # https://github.com/TuringLang/Turing.jl/issues/2604 + return Transition( + model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate + ) + end +end -# Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) +function getstats_with_lp(t::Transition) + return merge( + t.stat, + ( + lp=t.logprior + t.loglikelihood, + logprior=t.logprior, + loglikelihood=t.loglikelihood, + ), + ) +end +function getstats_with_lp(vi::AbstractVarInfo) + return ( + lp=DynamicPPL.getlogjoint(vi), + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + ) +end ########################## # Chain making utilities # @@ -166,7 +237,7 @@ metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) # => value maps) from AbstractMCMC.getparams (defined for any sampler transition, # returns vector). """ - Turing.Inference.getparams(model::Any, t::Any) + Turing.Inference.getparams(model::DynamicPPL.Model, t::Any) Return a vector of parameter values from the given sampler transition `t` (i.e., the first return value of AbstractMCMC.step). By default, returns the `t.θ` field. @@ -175,35 +246,16 @@ the first return value of AbstractMCMC.step). By default, returns the `t.θ` fie This method only needs to be implemented for external samplers. It will be removed in future releases and replaced with `AbstractMCMC.getparams`. """ -getparams(model, t) = t.θ +getparams(::DynamicPPL.Model, t) = t.θ """ Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo) Return a key-value map of parameters from the varinfo. 
""" function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo) - # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used. - # Unfortunately, using `invlink` can cause issues in scenarios where the constraints - # of the parameters change depending on the realizations. Hence we have to use - # `values_as_in_model`, which re-runs the model and extracts the parameters - # as they are seen in the model, i.e. in the constrained space. Moreover, - # this means that the code below will work both of linked and invlinked `vi`. - # Ref: https://github.com/TuringLang/Turing.jl/issues/2195 - # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`. - return DynamicPPL.values_as_in_model(model, true, deepcopy(vi)) -end -function getparams( - model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata} -) - # values_as_in_model is unconscionably slow for untyped VarInfo. It's - # much faster to convert it to a typed varinfo before calling getparams. - # https://github.com/TuringLang/Turing.jl/issues/2604 - return getparams(model, DynamicPPL.typed_varinfo(untyped_vi)) + t = Transition(model, vi, nothing) + return getparams(model, t) end -function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}}) - return Dict{VarName,Any}() -end - function _params_to_array(model::DynamicPPL.Model, ts::Vector) names_set = OrderedSet{VarName}() # Extract the parameter names and values from each transition. @@ -219,7 +271,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) mapreduce(collect, vcat, iters) end - nms = map(first, nms_and_vs) vs = map(last, nms_and_vs) for nm in nms @@ -234,14 +285,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) return names, vals end -function get_transition_extras(ts::AbstractVector{<:VarInfo}) - valmat = reshape([getlogp(t) for t in ts], :, 1) - return [:lp], valmat -end - function get_transition_extras(ts::AbstractVector) - # Extract all metadata. - extra_data = map(metadata, ts) + # Extract stats + log probabilities from each transition or VarInfo + extra_data = map(getstats_with_lp, ts) return names_values(extra_data) end @@ -271,7 +317,7 @@ getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. # This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, + ts::Vector{<:Union{Transition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, @@ -334,7 +380,7 @@ end # This is type piracy (for SampleFromPrior). 
function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, + ts::Vector{<:Union{Transition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, @@ -350,7 +396,7 @@ function AbstractMCMC.bundle_samples( vals = map(values(sym_to_vns)) do vns map(Base.Fix1(getindex, params), vns) end - return merge(NamedTuple(zip(keys(sym_to_vns), vals)), metadata(t)) + return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t)) end end @@ -412,84 +458,4 @@ function DynamicPPL.get_matching_type( return Array{T,N} end -############## -# Utilities # -############## - -""" - - transitions_from_chain( - [rng::AbstractRNG,] - model::Model, - chain::MCMCChains.Chains; - sampler = DynamicPPL.SampleFromPrior() - ) - -Execute `model` conditioned on each sample in `chain`, and return resulting transitions. - -The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`. - -# Details - -In a bit more detail, the process is as follows: -1. For every `sample` in `chain` - 1. For every `variable` in `sample` - 1. Set `variable` in `model` to its value in `sample` - 2. Execute `model` with variables fixed as above, sampling variables NOT present - in `chain` using `SampleFromPrior` - 3. Return sampled variables and log-joint - -# Example -```julia-repl -julia> using Turing - -julia> @model function demo() - m ~ Normal(0, 1) - x ~ Normal(m, 1) - end; - -julia> m = demo(); - -julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m` - -julia> transitions = Turing.Inference.transitions_from_chain(m, chain); - -julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints -2-element Array{Float64,1}: - -3.6294991938628374 - -2.5697948166987845 - -julia> [first(t.θ.x) for t in transitions] # extract samples for `x` -2-element Array{Array{Float64,1},1}: - [-2.0844148956440796] - [-1.704630494695469] -``` -""" -function transitions_from_chain( - model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs... -) - return transitions_from_chain(Random.default_rng(), model, chain; kwargs...) -end - -function transitions_from_chain( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - chain::MCMCChains.Chains; - sampler=DynamicPPL.SampleFromPrior(), -) - vi = Turing.VarInfo(model) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - transitions = map(iters) do (sample_idx, chain_idx) - # Set variables present in `chain` and mark those NOT present in chain to be resampled. - DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx) - model(rng, vi, sampler) - - # Convert `VarInfo` into `NamedTuple` and save. - Transition(model, vi) - end - - return transitions -end - end # module diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index fd4d441bd..4522875b4 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -1,7 +1,9 @@ # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. 
function _check_model(model::DynamicPPL.Model) - return DynamicPPL.check_model(model; error_on_failure=true) + # TODO(DPPL0.38/penelopeysm): use InitContext + spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context)) + return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) end function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm) return _check_model(model) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index dfd1fc0d3..98ed20b40 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -53,22 +53,26 @@ function AbstractMCMC.step( length(initial_params) == n || throw(ArgumentError("initial parameters have to be specified for each walker")) vis = map(vis, initial_params) do vi, init + # TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!! vi = DynamicPPL.initialize_parameters!!(vi, init, model) # Update log joint probability. - last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior())) + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context) + ) + last(DynamicPPL.evaluate!!(spl_model, vi)) end end # Compute initial transition and states. - transition = map(Base.Fix1(Transition, model), vis) + transition = [Transition(model, vi, nothing) for vi in vis] # TODO: Make compatible with immutable `AbstractVarInfo`. state = EmceeState( vis[1], map(vis) do vi vi = DynamicPPL.link!!(vi, model) - AMH.Transition(vi[:], getlogp(vi), false) + AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false) end, ) @@ -81,17 +85,19 @@ function AbstractMCMC.step( # Generate a log joint function. vi = state.vi densitymodel = AMH.DensityModel( - Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(model, vi)) + Base.Fix1( + LogDensityProblems.logdensity, + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), + ), ) # Compute the next states. - states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)) + t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states) # Compute the next transition and state. transition = map(states) do _state vi = DynamicPPL.unflatten(vi, _state.params) - t = Transition(getparams(model, vi), _state.lp) - return t + return Transition(model, vi, t) end newstate = EmceeState(vi, states) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 544817348..3afd91607 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -24,56 +24,48 @@ struct ESS <: InferenceAlgorithm end # always accept in the first step function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... ) for vn in keys(vi) dist = getdist(vi, vn) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("ESS only supports Gaussian prior distributions") end - return Transition(model, vi), vi + return Transition(model, vi, nothing), vi end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... 
) # obtain previous sample f = vi[:] # define previous sampler state # (do not use cache to avoid in-place sampling from prior) - oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing) + oldstate = EllipticalSliceSampling.ESSState(f, DynamicPPL.getloglikelihood(vi), nothing) # compute next state sample, state = AbstractMCMC.step( rng, - EllipticalSliceSampling.ESSModel( - ESSPrior(model, spl, vi), - DynamicPPL.LogDensityFunction( - model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()) - ), - ), + EllipticalSliceSampling.ESSModel(ESSPrior(model, vi), ESSLikelihood(model, vi)), EllipticalSliceSampling.ESS(), oldstate, ) # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setlogp!!(vi, state.loglikelihood) + vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) - return Transition(model, vi), vi + return Transition(model, vi, nothing), vi end # Prior distribution of considered random variable -struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} +struct ESSPrior{M<:Model,V<:AbstractVarInfo,T} model::M - sampler::S varinfo::V μ::T - function ESSPrior{M,S,V}( - model::M, sampler::S, varinfo::V - ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} + function ESSPrior(model::Model, varinfo::AbstractVarInfo) vns = keys(varinfo) μ = mapreduce(vcat, vns) do vn dist = getdist(varinfo, vn) @@ -81,47 +73,48 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} error("[ESS] only supports Gaussian prior distributions") DynamicPPL.tovec(mean(dist)) end - return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ) + return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ) end end -function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo) - return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo) -end - # Ensure that the prior is a Gaussian distribution (checked in the constructor) EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true # Only define out-of-place sampling function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) - sampler = p.sampler varinfo = p.varinfo # TODO: Surely there's a better way of doing this now that we have `SamplingContext`? + # TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model, + # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason + # why we had to use the 'del' flag before this was because + # SampleFromPrior() wouldn't overwrite existing variables. + # The main problem I'm rather unsure about is ESS-within-Gibbs. The + # current implementation I think makes sure to only resample the variables + # that 'belong' to the current ESS sampler. InitContext on the other hand + # would resample all variables in the model (??) Need to think about this + # carefully. vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") end - p.model(rng, varinfo, sampler) + p.model(rng, varinfo) return varinfo[:] end # Mean of prior distribution Distributions.mean(p::ESSPrior) = p.μ -# Evaluate log-likelihood of proposals -const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} +# Evaluate log-likelihood of proposals. We need this struct because +# EllipticalSliceSampling.jl expects a callable struct / a function as its +# likelihood. 
+struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} + ldf::L -(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f) - -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi -) - return DynamicPPL.tilde_assume( - rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi - ) + # Force usage of `getloglikelihood` in inner constructor + function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) + return new{typeof(ldf)}(ldf) + end end -function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi) -end +(ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 1e0481b50..af31e0243 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -26,8 +26,6 @@ There are a few more optional functions which you can implement to improve the i - `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`. - `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space. - -- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method. """ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <: InferenceAlgorithm @@ -89,27 +87,21 @@ function externalsampler( return ExternalSampler(sampler, adtype, Val(unconstrained)) end -""" - getlogp_external(external_transition, external_state) - -Get the log probability density associated with the external sampler's -transition and state. Returns `missing` by default; in this case, an extra -model evaluation will be needed to calculate the correct log density. -""" -getlogp_external(::Any, ::Any) = missing -getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp -getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density - -struct TuringState{S,V1<:AbstractVarInfo,M,V,C} +# TODO(penelopeysm): Can't we clean this up somehow? +struct TuringState{S,V1,M,V} state::S - # Note that this varinfo has the correct parameters and logp obtained from - # the state, whereas `ldf.varinfo` will in general have junk inside it. + # Note that this varinfo must have the correct parameters set; but logp + # does not matter as it will be re-evaluated varinfo::V1 - ldf::DynamicPPL.LogDensityFunction{M,V,C} + # Note that in general the VarInfo inside this LogDensityFunction will have + # junk parameters and logp. 
It only exists to provide structure + ldf::DynamicPPL.LogDensityFunction{M,V} end -varinfo(state::TuringState) = state.varinfo -varinfo(state::AbstractVarInfo) = state +# get_varinfo should return something from which the correct parameters can be +# obtained, hence we use state.varinfo rather than state.ldf.varinfo +get_varinfo(state::TuringState) = state.varinfo +get_varinfo(state::AbstractVarInfo) = state getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState) @@ -119,24 +111,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params -function make_updated_varinfo( - f::DynamicPPL.LogDensityFunction, external_transition, external_state -) - # Set the parameters. - # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!) - # The latter uses the state rather than the transition. - # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead - new_parameters = getparams(f.model, external_transition) - new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters) - # Set (or recalculate, if needed) the log density. - new_logp = getlogp_external(external_transition, external_state) - return if ismissing(new_logp) - last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context)) - else - DynamicPPL.setlogp!!(new_varinfo, new_logp) - end -end - # TODO: Do we also support `resume`, etc? function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -163,7 +137,9 @@ function AbstractMCMC.step( end # Construct LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype + ) # Then just call `AbstractMCMC.step` with the right arguments. if initial_state === nothing @@ -181,13 +157,13 @@ function AbstractMCMC.step( ) end - # Get the parameters and log density, and set them in the varinfo. - new_varinfo = make_updated_varinfo(f, transition_inner, state_inner) - - # Update the `state` + # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!) + # The latter uses the state rather than the transition. + # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead + new_parameters = Turing.Inference.getparams(f.model, transition_inner) + new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) return ( - Transition(f.model, new_varinfo, transition_inner), - TuringState(state_inner, new_varinfo, f), + Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f) ) end @@ -206,12 +182,12 @@ function AbstractMCMC.step( rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs... ) - # Get the parameters and log density, and set them in the varinfo. - new_varinfo = make_updated_varinfo(f, transition_inner, state_inner) - - # Update the `state` + # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!) + # The latter uses the state rather than the transition. 
+ # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead + new_parameters = Turing.Inference.getparams(f.model, transition_inner) + new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) return ( - Transition(f.model, new_varinfo, transition_inner), - TuringState(state_inner, new_varinfo, f), + Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f) ) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index f36cb9c36..692748767 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -33,7 +33,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context) # # Purpose: avoid triggering resampling of variables we're conditioning on. # - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`. # - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to # undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable # rather than only for the "true" observations. @@ -177,17 +177,21 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuit the tilde assume if `vn` is present in `context`. - value, lp, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, lp, vi + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # Note that tilde_observe!! will trigger resampling in particle methods + # for variables that are handled by other Gibbs component samplers. + val = get_conditioned_gibbs(context, vn) + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - value, lp, new_global_vi = DynamicPPL.tilde_assume( + value, new_global_vi = DynamicPPL.tilde_assume( child_context, DynamicPPL.SampleFromPrior(), right, @@ -195,7 +199,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end @@ -208,14 +212,26 @@ function DynamicPPL.tilde_assume( vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) return if is_target_varname(context, vn) + # This branch means that that `sampler` is supposed to handle + # this variable. We can thus use its default behaviour, with + # the 'local' sampler-specific VarInfo. DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, lp, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, lp, vi + # This branch means that a different sampler is supposed to handle this + # variable. 
From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # Note that tilde_observe!! will trigger resampling in particle methods + # for variables that are handled by other Gibbs component samplers. + val = get_conditioned_gibbs(context, vn) + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else - value, lp, new_global_vi = DynamicPPL.tilde_assume( + # If the varname has not been conditioned on, nor is it a target variable, it's + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, DynamicPPL.SampleFromPrior(), @@ -224,7 +240,7 @@ get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end @@ -327,7 +343,7 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} states::S end -varinfo(state::GibbsState) = state.vi +get_varinfo(state::GibbsState) = state.vi """ Initialise a VarInfo for the Gibbs sampler. """ function initial_varinfo(rng, model, spl, initial_params) # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) end return vi end @@ -374,7 +390,7 @@ function AbstractMCMC.step( initial_params=initial_params, kwargs..., ) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end function AbstractMCMC.step_warmup( @@ -399,7 +415,7 @@ initial_params=initial_params, kwargs..., ) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end """ @@ -449,7 +465,7 @@ function gibbs_initialstep_recursive( initial_params=initial_params_local, kwargs..., ) - new_vi_local = varinfo(new_state) + new_vi_local = get_varinfo(new_state) # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, get_global_varinfo(context)) @@ -477,7 +493,7 @@ function AbstractMCMC.step( state::GibbsState; kwargs..., ) - vi = varinfo(state) + vi = get_varinfo(state) alg = spl.alg varnames = alg.varnames samplers = alg.samplers @@ -487,7 +503,7 @@ vi, states = gibbs_step_recursive( rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs... ) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end function AbstractMCMC.step_warmup( @@ -497,7 +513,7 @@ state::GibbsState; kwargs..., ) - vi = varinfo(state) + vi = get_varinfo(state) alg = spl.alg varnames = alg.varnames samplers = alg.samplers @@ -507,7 +523,7 @@ vi, states = gibbs_step_recursive( rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end """ @@ -525,16 +541,11 @@ function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) end function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::Sampler{<:MH}, - state::AbstractVarInfo, - params::AbstractVarInfo, + model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::MHState, params::AbstractVarInfo ) - # The state is already a VarInfo, so we can just return `params`, but first we need to - # update its logprob. - # NOTE: Using `leafcontext(model.context)` here is a no-op, as it will be concatenated - # with `model.context` before hitting `model.f`. - return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) + # Re-evaluate to update the logprob. + new_vi = last(DynamicPPL.evaluate!!(model, params)) + return MHState(new_vi, DynamicPPL.getlogjoint_internal(new_vi)) end function setparams_varinfo!!( @@ -544,10 +555,8 @@ model::DynamicPPL.Model, sampler::Sampler{<:ESS}, state::AbstractVarInfo, params::AbstractVarInfo, ) # The state is already a VarInfo, so we can just return `params`, but first we need to - # update its logprob. To do this, we have to call evaluate!! with the sampler, rather - # than just a context, because ESS is peculiar in how it uses LikelihoodContext for - # some variables and DefaultContext for others. - return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler))) + # update its logprob. + return last(DynamicPPL.evaluate!!(model, params)) end function setparams_varinfo!!( @@ -557,7 +566,7 @@ model::DynamicPPL.Model, sampler::Sampler{<:ExternalSampler}, state::TuringState, params::AbstractVarInfo, ) logdensity = DynamicPPL.LogDensityFunction( - model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype + model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.alg.adtype ) new_inner_state = setparams_varinfo!!( AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params ) @@ -596,7 +605,7 @@ state for this sampler. This is relevant when multiple samplers are sampling the same variables, and one might need it to be linked while the other doesn't. """ function match_linking!!(varinfo_local, prev_state_local, model) - prev_varinfo_local = varinfo(prev_state_local) + prev_varinfo_local = get_varinfo(prev_state_local) was_linked = DynamicPPL.istrans(prev_varinfo_local) is_linked = DynamicPPL.istrans(varinfo_local) if was_linked && !is_linked @@ -678,10 +687,10 @@ function gibbs_step_recursive( # Take a step with the local sampler. new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...)) - new_vi_local = varinfo(new_state) + new_vi_local = get_varinfo(new_state) # Merge the latest values for all the variables in the current sampler. new_global_vi = merge(get_global_varinfo(context), new_vi_local) - new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local)) + new_global_vi = DynamicPPL.setlogp!!(new_global_vi, DynamicPPL.getlogp(new_vi_local)) new_states = (new_states..., new_state) return gibbs_step_recursive( diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index b5f51587b..d80502f7e 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -25,7 +25,7 @@ end ### Hamiltonian Monte Carlo samplers. ### -varinfo(state::HMCState) = state.vi +get_varinfo(state::HMCState) = state.vi """ HMC(ϵ::Float64, n_leapfrog::Int; adtype::ADTypes.AbstractADType = AutoForwardDiff()) @@ -162,7 +162,9 @@ function find_initial_params( # Resample and try again.
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space varinfo = last( - DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform()) + DynamicPPL.evaluate_and_sample!!( + rng, model, varinfo, DynamicPPL.SampleFromUniform() + ), ) end @@ -191,14 +193,7 @@ function DynamicPPL.initialstep( metricT = getmetricT(spl.alg) metric = metricT(length(theta)) ldf = DynamicPPL.LogDensityFunction( - model, - vi, - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we - # need to pass in the sampler? (In fact LogDensityFunction defaults to - # using leafcontext(model.context) so could we just remove the argument - # entirely?) - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -213,9 +208,6 @@ function DynamicPPL.initialstep( end theta = vi[:] - # Cache current log density. - log_density_old = getlogp(vi) - # Find good eps if not provided one if iszero(spl.alg.ϵ) ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta) @@ -239,14 +231,13 @@ function DynamicPPL.initialstep( ) end - # Update `vi` based on acceptance - if t.stat.is_accept - vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # Update VarInfo parameters based on acceptance + new_params = if t.stat.is_accept + t.z.θ else - vi = DynamicPPL.unflatten(vi, theta) - vi = setlogp!!(vi, log_density_old) + theta end + vi = DynamicPPL.unflatten(vi, new_params) transition = Transition(model, vi, t) state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) @@ -290,7 +281,6 @@ function AbstractMCMC.step( vi = state.vi if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) end # Compute next transition and state. @@ -303,14 +293,7 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) ldf = DynamicPPL.LogDensityFunction( - model, - vi, - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we - # need to pass in the sampler? (In fact LogDensityFunction defaults to - # using leafcontext(model.context) so could we just remove the argument - # entirely?) - DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context)); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -516,10 +499,6 @@ function DynamicPPL.assume( return DynamicPPL.assume(dist, vn, vi) end -function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi) - return DynamicPPL.observe(d, value, vi) -end - #### #### Default HMC stepsize and mass matrix adaptor #### diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index d83abd173..319e424fc 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -31,19 +31,19 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs... ) - return Transition(model, vi), nothing + return Transition(model, vi, nothing), nothing end function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs... 
) vi = VarInfo(rng, model, spl) - return Transition(model, vi), nothing + return Transition(model, vi, nothing), nothing end # Calculate evidence. function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state) - return logsumexp(map(x -> x.lp, samples)) - log(length(samples)) + return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples)) end function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName, vi) @@ -53,9 +53,6 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName r = rand(rng, dist) vi = push!!(vi, vn, r, dist) end - return r, 0, vi -end - -function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi) - return logpdf(dist, value), vi + vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist) + return r, vi end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index fb50c5f58..863db559c 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -153,10 +153,33 @@ function MH(model::Model; proposal_type=AMH.StaticProposal) return AMH.MetropolisHastings(priors) end +""" + MHState(varinfo::AbstractVarInfo, logjoint_internal::Real) + +State for Metropolis-Hastings sampling. + +`varinfo` must have the correct parameters set inside it, but its other fields +(e.g. accumulators, which track logp) can in general be missing or incorrect. + +`logjoint_internal` is the log joint probability of the model, evaluated using +the parameters and linking status of `varinfo`. It should be equal to +`DynamicPPL.getlogjoint_internal(varinfo)`. This information is returned by the +MH sampler so we store this here to avoid re-evaluating the model +unnecessarily. +""" +struct MHState{V<:AbstractVarInfo,L<:Real} + varinfo::V + logjoint_internal::L +end + +get_varinfo(s::MHState) = s.varinfo + ##################### # Utility functions # ##################### +# TODO(DPPL0.38/penelopeysm): This function should no longer be needed +# once InitContext is merged. """ set_namedtuple!(vi::VarInfo, nt::NamedTuple) @@ -181,21 +204,19 @@ function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTup end end -""" - MHLogDensityFunction - -A log density function for the MH sampler. - -This variant uses the `set_namedtuple!` function to update the `VarInfo`. -""" -const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} - -function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple) +# NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems +# interface in that it gets evaluated with a NamedTuple. Hence we need this +# method just to deal with MH. +# TODO(DPPL0.38/penelopeysm): Check the extent to which this method is actually +# needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, +# ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). +# In general, we should much prefer to either (1) conform to the +# LogDensityProblems interface or (2) use VarNames anyway. +function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) set_namedtuple!(vi, x) - vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context)) - lj = getlogp(vi_new) + vi_new = last(DynamicPPL.evaluate!!(f.model, vi)) + lj = f.getlogdensity(vi_new) return lj end @@ -297,64 +318,71 @@ end # Make a proposal if we don't have a covariance proposal matrix (the default). 
function propose!!( - rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal + rng::AbstractRNG, prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal ) + vi = prev_state.varinfo # Retrieve distribution and value NamedTuples. dt, vt = dist_val_tuple(spl, vi) # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(dt) - prev_trans = AMH.Transition(vt, getlogp(vi), false) + prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false) # Make a new transition. + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, spl, model.context) + ) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), - ), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) - - # TODO: Make this compatible with immutable `VarInfo`. - # Update the values in the VarInfo. + # trans.params isa NamedTuple set_namedtuple!(vi, trans.params) - return setlogp!!(vi, trans.lp) + # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know + # how to set this back inside vi (without re-evaluating). However, the next + # MH step will require this information to calculate the acceptance + # probability, so we return it together with vi. + return MHState(vi, trans.lp) end # Make a proposal if we DO have a covariance proposal matrix. function propose!!( rng::AbstractRNG, - vi::AbstractVarInfo, + prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal::AdvancedMH.RandomWalkProposal, ) + vi = prev_state.varinfo # If this is the case, we can just draw directly from the proposal # matrix. vals = vi[:] # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) - prev_trans = AMH.Transition(vals, getlogp(vi), false) + prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false) # Make a new transition. + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, spl, model.context) + ) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), - ), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) - - return setlogp!!(DynamicPPL.unflatten(vi, trans.params), trans.lp) + # trans.params isa AbstractVector + vi = DynamicPPL.unflatten(vi, trans.params) + # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know + # how to set this back inside vi (without re-evaluating). However, the next + # MH step will require this information to calculate the acceptance + # probability, so we return it together with vi. + return MHState(vi, trans.lp) end function DynamicPPL.initialstep( @@ -368,18 +396,18 @@ function DynamicPPL.initialstep( # just link everything before sampling. vi = maybe_link!!(vi, spl, spl.alg.proposals, model) - return Transition(model, vi), vi + return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi)) end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, state::MHState; kwargs... 
) # Cases: # 1. A covariance proposal matrix # 2. A bunch of NamedTuples that specify the proposal space - vi = propose!!(rng, vi, model, spl, spl.alg.proposals) + new_state = propose!!(rng, state, model, spl, spl.alg.proposals) - return Transition(model, vi), vi + return Transition(model, new_state.varinfo, nothing), new_state end #### @@ -392,7 +420,3 @@ function DynamicPPL.assume( retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) return retval end - -function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi) - return DynamicPPL.observe(SampleFromPrior(), d, value, vi) -end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index a81f436c8..ab2add975 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -4,6 +4,38 @@ ### AdvancedPS models and interface +""" + set_all_del!(vi::AbstractVarInfo) + +Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for +resampling. +""" +function set_all_del!(vi::AbstractVarInfo) + # TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we + # could either: + # - keep a boolean 'resample' flag on the trace, or + # - modify the model context appropriately. + # However, this refactoring will have to wait until InitContext is + # merged into DPPL. + for vn in keys(vi) + DynamicPPL.set_flag!(vi, vn, "del") + end + return nothing +end + +""" + unset_all_del!(vi::AbstractVarInfo) + +Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing +them from being resampled. +""" +function unset_all_del!(vi::AbstractVarInfo) + for vn in keys(vi) + DynamicPPL.unset_flag!(vi, vn, "del") + end + return nothing +end + struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel model::M @@ -18,43 +50,44 @@ function TracedModel( varinfo::AbstractVarInfo, rng::Random.AbstractRNG, ) - context = SamplingContext(rng, sampler, DefaultContext()) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) + spl_model = DynamicPPL.contextualize(model, spl_context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) if kwargs !== nothing && !isempty(kwargs) error( "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", ) end - evaluator = (model.f, args...) - return TracedModel(model, sampler, varinfo, evaluator) + evaluator = (spl_model.f, args...) + return TracedModel(spl_model, sampler, varinfo, evaluator) end function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) # Make sure we load/reset the rng in the new replaying mechanism - DynamicPPL.increment_num_produce!(trace.model.f.varinfo) isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) - if score === nothing - return nothing - else - return score + DynamicPPL.getlogp(trace.model.f.varinfo) - end + return score end function AdvancedPS.delete_retained!(trace::TracedModel) - DynamicPPL.set_retained_vns_del!(trace.varinfo) + # This method is called if, during a CSMC update, we perform a resampling + # and choose the reference particle as the trajectory to carry on from. + # In such a case, we need to ensure that when we continue sampling (i.e. + # the next time we hit tilde_assume), we don't use the values in the + # reference particle but rather sample new values. 
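+    #
+    # As a rough sketch of the flag mechanism (hypothetical usage, not part
+    # of this method): for a VarInfo `vi` that already contains `@varname(x)`,
+    #
+    #     set_all_del!(vi)
+    #     DynamicPPL.is_flagged(vi, @varname(x), "del")  # now true
+    #
+    # so the next `tilde_assume` hit for `x` samples a fresh value instead of
+    # reusing the one stored in `vi` (see `DynamicPPL.assume` further below).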
+ # + # Here, we indiscriminately set the 'del' flag for all variables in the + # VarInfo. This is slightly overkill: it is not necessary to set the 'del' + # flag for variables that were already sampled. However, it allows us to + # avoid keeping track of which variables were sampled, which leads to many + # simplifications in the VarInfo data structure. + set_all_del!(trace.varinfo) return trace end function AdvancedPS.reset_model(trace::TracedModel) - DynamicPPL.reset_num_produce!(trace.varinfo) - return trace -end - -function AdvancedPS.reset_logprob!(trace::TracedModel) - DynamicPPL.resetlogp!!(trace.model.varinfo) return trace end @@ -102,29 +135,6 @@ function SMC(threshold::Real) return SMC(AdvancedPS.resample_systematic, threshold) end -struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition - "The parameters for any given sample." - θ::T - "The joint log probability of the sample (NOTE: does not work, always set to zero)." - lp::F - "The weight of the particle the sample was retrieved from." - weight::F -end - -function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) - theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. - lp = getlogp(vi) - - return SMCTransition(theta, lp, weight) -end - -metadata(t::SMCTransition) = (lp=t.lp, weight=t.weight) - -DynamicPPL.getlogp(t::SMCTransition) = t.lp - struct SMCState{P,F<:AbstractFloat} particles::P particleindex::Int @@ -182,10 +192,9 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del!(vi) - DynamicPPL.resetlogp!!(vi) - DynamicPPL.empty!!(vi) + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) + set_all_del!(vi) + vi = DynamicPPL.empty!!(vi) # Create a new set of particles. particles = AdvancedPS.ParticleContainer( @@ -202,7 +211,8 @@ function DynamicPPL.initialstep( weight = AdvancedPS.getweight(particles, 1) # Compute the first transition and the first state. - transition = SMCTransition(model, particle.model.f.varinfo, weight) + stats = (; weight=weight, logevidence=logevidence) + transition = Transition(model, particle.model.f.varinfo, stats) state = SMCState(particles, 2, logevidence) return transition, state @@ -220,7 +230,8 @@ function AbstractMCMC.step( weight = AdvancedPS.getweight(particles, index) # Compute the transition and the next state. - transition = SMCTransition(model, particle.model.f.varinfo, weight) + stats = (; weight=weight, logevidence=state.average_logevidence) + transition = Transition(model, particle.model.f.varinfo, stats) nextstate = SMCState(state.particles, index + 1, state.average_logevidence) return transition, nextstate @@ -274,38 +285,28 @@ Equivalent to [`PG`](@ref). """ const CSMC = PG # type alias of PG as Conditional SMC -struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition - "The parameters for any given sample." - θ::T - "The joint log probability of the sample (NOTE: does not work, always set to zero)." - lp::F - "The log evidence of the sample." - logevidence::F -end - struct PGState vi::AbstractVarInfo rng::Random.AbstractRNG end -varinfo(state::PGState) = state.vi - -function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) - theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. 
- lp = getlogp(vi) +get_varinfo(state::PGState) = state.vi - return PGTransition(theta, lp, logevidence) -end - -metadata(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence) - -DynamicPPL.getlogp(t::PGTransition) = t.lp - -function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState) - return mean(x.logevidence for x in samples) +function getlogevidence( + transitions::AbstractVector{<:Turing.Inference.Transition}, + sampler::Sampler{<:PG}, + state::PGState, +) + logevidences = map(transitions) do t + if haskey(t.stat, :logevidence) + return t.stat.logevidence + else + # This should not really happen, but if it does we can handle it + # gracefully + return missing + end + end + return mean(logevidences) end function DynamicPPL.initialstep( @@ -315,10 +316,9 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del!(vi) - DynamicPPL.resetlogp!!(vi) + set_all_del!(vi) # Create a new set of particles num_particles = spl.alg.nparticles @@ -338,7 +338,7 @@ function DynamicPPL.initialstep( # Compute the first transition. _vi = reference.model.f.varinfo - transition = PGTransition(model, _vi, logevidence) + transition = Transition(model, _vi, (; logevidence=logevidence)) return transition, PGState(_vi, reference.rng) end @@ -348,14 +348,14 @@ function AbstractMCMC.step( ) # Reset the VarInfo before new sweep. vi = state.vi - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Create reference particle for which the samples will be retained. + unset_all_del!(vi) reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) # For all other particles, do not retain the variables but resample them. - DynamicPPL.set_retained_vns_del!(vi) + set_all_del!(vi) # Create a new set of particles. num_particles = spl.alg.nparticles @@ -378,83 +378,124 @@ function AbstractMCMC.step( # Compute the transition. _vi = newreference.model.f.varinfo - transition = PGTransition(model, _vi, logevidence) + transition = Transition(model, _vi, (; logevidence=logevidence)) return transition, PGState(_vi, newreference.rng) end function DynamicPPL.use_threadsafe_eval( - ::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo + ::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo ) return false end -function trace_local_varinfo_maybe(varinfo) - try - trace = Libtask.get_taped_globals(Any).other - return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo +""" + get_trace_local_varinfo_maybe(vi::AbstractVarInfo) + +Get the `Trace` local varinfo if one exists. + +If executed within a `TapedTask`, return the `varinfo` stored in the "taped globals" of the +task, otherwise return `vi`. +""" +function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo) + trace = try + Libtask.get_taped_globals(Any).other catch e - # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:task_variable) - return varinfo - else - rethrow(e) - end + e == KeyError(:task_variable) ? nothing : rethrow(e) end + return (trace === nothing ? 
varinfo : trace.model.f.varinfo)::AbstractVarInfo
+end
 
-function trace_local_rng_maybe(rng::Random.AbstractRNG)
-    try
-        return Libtask.get_taped_globals(Any).rng
+"""
+    get_trace_local_rng_maybe(rng::Random.AbstractRNG)
+
+Get the `Trace` local rng if one exists.
+
+If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the
+task, otherwise return `rng`.
+"""
+function get_trace_local_rng_maybe(rng::Random.AbstractRNG)
+    return try
+        Libtask.get_taped_globals(Any).rng
     catch e
-        # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
-        if e == KeyError(:task_variable)
-            return rng
-        else
-            rethrow(e)
-        end
+        e == KeyError(:task_variable) ? rng : rethrow(e)
+    end
+end
+
+"""
+    set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+
+Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
+
+If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the
+task. Otherwise do nothing.
+"""
+function set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+    # TODO(mhauru) This should be done in a try-catch block, as in the commented out code.
+    # However, Libtask currently can't handle this block.
+    trace = #try
+        Libtask.get_taped_globals(Any).other
+    # catch e
+    #     e == KeyError(:task_variable) ? nothing : rethrow(e)
+    # end
+    if trace !== nothing
+        model = trace.model
+        model = Accessors.@set model.f.varinfo = vi
+        trace.model = model
     end
+    return nothing
 end
 
 function DynamicPPL.assume(
-    rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo
+    rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo
 )
-    vi = trace_local_varinfo_maybe(_vi)
-    trng = trace_local_rng_maybe(rng)
+    arg_vi_id = objectid(vi)
+    vi = get_trace_local_varinfo_maybe(vi)
+    using_local_vi = objectid(vi) == arg_vi_id
+
+    trng = get_trace_local_rng_maybe(rng)
 
     if ~haskey(vi, vn)
         r = rand(trng, dist)
-        push!!(vi, vn, r, dist)
+        vi = push!!(vi, vn, r, dist)
     elseif DynamicPPL.is_flagged(vi, vn, "del")
         DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
+        # TODO(mhauru):
+        # The below is the only line that differs from assume called on SampleFromPrior.
+        # Could we just call assume on SampleFromPrior with a specific rng?
         r = rand(trng, dist)
         vi[vn] = DynamicPPL.tovec(r)
-        DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi))
     else
         r = vi[vn]
     end
-    # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something?
-    lp = 0
-    return r, lp, vi
-end
-function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
-    # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
-    return logpdf(dist, value), trace_local_varinfo_maybe(vi)
-end
+    vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist)
 
-function DynamicPPL.acclogp!!(
-    context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
-)
-    varinfo_trace = trace_local_varinfo_maybe(varinfo)
-    return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)
+    # TODO(mhauru) Rather than this if-block, we should use try-catch within
+    # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
+    # hence this.
+    if !using_local_vi
+        set_trace_local_varinfo_maybe(vi)
+    end
+    return r, vi
 end
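+# A minimal sketch of the taped-globals mechanism that the helpers above rely
+# on (illustrative only; `fallback_vi` is a stand-in for whatever argument the
+# caller passed in):
+#
+#     taped = Libtask.get_taped_globals(Any)  # has fields `rng` and `other`
+#     vi = taped.other === nothing ? fallback_vi : taped.other.model.f.varinfo
+#
+# Outside of a `TapedTask` the lookup throws `KeyError(:task_variable)`, in
+# which case the helpers simply fall back to the argument they were given.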
 
-function DynamicPPL.acclogp_observe!!(
-    context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
+function DynamicPPL.tilde_observe!!(
+    ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi
 )
-    Libtask.produce(logp)
-    return trace_local_varinfo_maybe(varinfo)
+    arg_vi_id = objectid(vi)
+    vi = get_trace_local_varinfo_maybe(vi)
+    using_local_vi = objectid(vi) == arg_vi_id
+
+    left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi)
+
+    # TODO(mhauru) Rather than this if-block, we should use try-catch within
+    # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
+    # hence this.
+    if !using_local_vi
+        set_trace_local_varinfo_maybe(vi)
+    end
+    return left, vi
 end
 
 # Convenient constructor
@@ -465,20 +506,74 @@ function AdvancedPS.Trace(
     rng::AdvancedPS.TracedRNG,
 )
     newvarinfo = deepcopy(varinfo)
-    DynamicPPL.reset_num_produce!(newvarinfo)
-
     tmodel = TracedModel(model, sampler, newvarinfo, rng)
     newtrace = AdvancedPS.Trace(tmodel, rng)
     return newtrace
 end
 
+"""
+    ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T}
+
+Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` whenever its value
+changes.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T}
+    "the scalar log likelihood value"
+    logp::T
+end
+
+# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
+# can be used in a given VarInfo.
+DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
+DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp
+
+function DynamicPPL.acclogp(acc1::ProduceLogLikelihoodAccumulator, val)
+    # The below line is the only difference from `LogLikelihoodAccumulator`.
+    Libtask.produce(val)
+    return ProduceLogLikelihoodAccumulator(acc1.logp + val)
+end
+
+function DynamicPPL.accumulate_assume!!(
+    acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right
+)
+    return acc
+end
+function DynamicPPL.accumulate_observe!!(
+    acc::ProduceLogLikelihoodAccumulator, right, left, vn
+)
+    return DynamicPPL.acclogp(acc, Distributions.loglikelihood(right, left))
+end
+
 # We need to tell Libtask which calls may have `produce` calls within them. In practice most
 # of these won't be needed, because of inlining and the fact that `might_produce` is only
 # called on `:invoke` expressions rather than `:call`s, but since those are implementation
 # details of the compiler, we set a bunch of methods as might_produce = true. We start with
-# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
-Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
+# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
+# call stack.
+Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
+function Libtask.might_produce(
+    ::Type{
+        <:Tuple{
+            typeof(Base.:+),
+            ProduceLogLikelihoodAccumulator,
+            DynamicPPL.LogLikelihoodAccumulator,
+        },
+    },
+)
+    return true
+end
+function Libtask.might_produce(
+    ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
+)
+    return true
+end
 Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
+# Could the next two have tighter type bounds on the arguments, namely a GibbsContext?
+# That's the only thing that makes tilde_assume calls result in tilde_observe calls. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index c7a5cc573..2ead40ced 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,14 +12,19 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - vi = last( - DynamicPPL.evaluate!!( - model, - VarInfo(), - SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()), + # TODO(DPPL0.38/penelopeysm): replace with init!! + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) + ) + vi = VarInfo() + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), ), ) - return vi, nothing + _, vi = DynamicPPL.evaluate!!(sampling_model, vi) + return Transition(model, vi, nothing; reevaluate=false), nothing end - -DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 0c322244e..34d7cf9d8 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -58,19 +58,15 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - # Transform the samples to unconstrained space and compute the joint log probability. + # Transform the samples to unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) end # Compute initial sample and state. - sample = Transition(model, vi) + sample = Transition(model, vi, nothing) ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGHMCState(ℓ, vi, zero(vi[:])) @@ -98,12 +94,11 @@ function AbstractMCMC.step( α = spl.alg.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) - # Save new variables and recompute log density. + # Save new variables. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) # Compute next sample and state. - sample = Transition(model, vi) + sample = Transition(model, vi, nothing) newstate = SGHMCState(ℓ, vi, newv) return sample, newstate @@ -189,25 +184,6 @@ function SGLD(; return SGLD(stepsize, adtype) end -struct SGLDTransition{T,F<:Real} <: AbstractTransition - "The parameters for any given sample." - θ::T - "The joint log probability of the sample." - lp::F - "The stepsize that was used to obtain the sample." 
- stepsize::F -end - -function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize) - theta = getparams(model, vi) - lp = getlogp(vi) - return SGLDTransition(theta, lp, stepsize) -end - -metadata(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize) - -DynamicPPL.getlogp(t::SGLDTransition) = t.lp - struct SGLDState{L,V<:AbstractVarInfo} logdensity::L vi::V @@ -221,23 +197,19 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - # Transform the samples to unconstrained space and compute the joint log probability. + # Transform the samples to unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) end # Create first sample and state. - sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0))) + transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0)))) ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGLDState(ℓ, vi, 1) - return sample, state + return transition, state end function AbstractMCMC.step( @@ -252,13 +224,12 @@ function AbstractMCMC.step( stepsize = spl.alg.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) - # Save new variables and recompute log density. + # Save new variables. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) # Compute next sample and state. - sample = SGLDTransition(model, vi, stepsize) + transition = Transition(model, vi, (; SGLD_stepsize=stepsize)) newstate = SGLDState(ℓ, vi, state.step + 1) - return sample, newstate + return transition, newstate end diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index fedc2510d..19c52c381 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -43,75 +43,24 @@ Concrete type for maximum a posteriori estimation. Only used for the Optim.jl in """ struct MAP <: ModeEstimator end -""" - OptimizationContext{C<:AbstractContext} <: AbstractContext - -The `OptimizationContext` transforms variables to their constrained space, but -does not use the density with respect to the transformation. This context is -intended to allow an optimizer to sample in R^n freely. 
-""" -struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - context::C - - function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext} - if !( - context isa Union{ - DynamicPPL.DefaultContext, - DynamicPPL.LikelihoodContext, - DynamicPPL.PriorContext, - } - ) - msg = """ - `OptimizationContext` supports only leaf contexts of type - `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, - and `DynamicPPL.PriorContext` (given: `$(typeof(context)))` - """ - throw(ArgumentError(msg)) - end - return new{C}(context) - end -end - -OptimizationContext(ctx::DynamicPPL.AbstractContext) = OptimizationContext{typeof(ctx)}(ctx) - -DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() - -function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) - r = vi[vn, dist] - lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} - # MAP - Distributions.logpdf(dist, r) - else - # MLE - 0 - end - return r, lp, vi -end - -function DynamicPPL.tilde_observe( - ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args... -) - return DynamicPPL.tilde_observe(ctx.context, args...) -end - """ OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, - AD<:ADTypes.AbstractADType + F<:Function, + V<:DynamicPPL.AbstractVarInfo, + AD<:ADTypes.AbstractADType, } A struct that wraps a single LogDensityFunction. Can be invoked either using ```julia -OptimLogDensity(model, varinfo, ctx; adtype=adtype) +OptimLogDensity(model, varinfo; adtype=adtype) ``` or ```julia -OptimLogDensity(model, ctx; adtype=adtype) +OptimLogDensity(model; adtype=adtype) ``` If not specified, `adtype` defaults to `AutoForwardDiff()`. @@ -129,37 +78,35 @@ the underlying LogDensityFunction at the point `z`. This is done to satisfy the Optim.jl interface. 
```julia -optim_ld = OptimLogDensity(model, varinfo, ctx) +optim_ld = OptimLogDensity(model, varinfo) optim_ld(z) # returns -logp ``` """ -struct OptimLogDensity{ - M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, - AD<:ADTypes.AbstractADType, -} - ldf::DynamicPPL.LogDensityFunction{M,V,C,AD} -end +struct OptimLogDensity{L<:DynamicPPL.LogDensityFunction} + ldf::L -function OptimLogDensity( - model::DynamicPPL.Model, - vi::DynamicPPL.VarInfo, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi, ctx; adtype=adtype)) -end - -# No varinfo -function OptimLogDensity( - model::DynamicPPL.Model, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity( - DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype) + function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function, + vi::DynamicPPL.AbstractVarInfo; + adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, + ) + ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + return new{typeof(ldf)}(ldf) + end + function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function; + adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) + # No varinfo + return OptimLogDensity( + model, + getlogdensity, + DynamicPPL.ldf_default_varinfo(model, getlogdensity); + adtype=adtype, + ) + end end """ @@ -325,10 +272,13 @@ function StatsBase.informationmatrix( # Convert the values to their unconstrained states to make sure the # Hessian is computed with respect to the untransformed parameters. - linked = DynamicPPL.istrans(m.f.ldf.varinfo) + old_ldf = m.f.ldf + linked = DynamicPPL.istrans(old_ldf.varinfo) if linked - new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) + new_f = OptimLogDensity( + old_ldf.model, old_ldf.getlogdensity, new_vi; adtype=old_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -339,8 +289,11 @@ function StatsBase.informationmatrix( # Link it back if we invlinked it. if linked - new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + invlinked_ldf = m.f.ldf + new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model) + new_f = OptimLogDensity( + invlinked_ldf.model, old_ldf.getlogdensity, new_vi; adtype=invlinked_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -553,7 +506,12 @@ function estimate_mode( ub=nothing, kwargs..., ) - check_model && DynamicPPL.check_model(model; error_on_failure=true) + if check_model + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + DynamicPPL.check_model(spl_model, DynamicPPL.VarInfo(); error_on_failure=true) + end constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons) initial_params = generate_initial_params(model, initial_params, constraints) @@ -563,19 +521,17 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. 
- inner_context = if estimator isa MAP - DynamicPPL.DefaultContext() - else - DynamicPPL.LikelihoodContext() - end - ctx = OptimizationContext(inner_context) + # Note that we use `getlogjoint` rather than `getlogjoint_internal`: this + # is intentional, because even though the VarInfo may be linked, the + # optimisation target should not take the Jacobian term into account. + getlogdensity = estimator isa MAP ? DynamicPPL.getlogjoint : DynamicPPL.getloglikelihood # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the # varinfo are completely ignored. The parameters only matter if you are calling evaluate!! # directly on the fields of the LogDensityFunction - vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.ldf_default_varinfo(model, getlogdensity) vi = DynamicPPL.unflatten(vi, initial_params) # Link the varinfo if needed. @@ -588,7 +544,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi, ctx) + log_density = OptimLogDensity(model, getlogdensity, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index b9428af11..d51631968 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -17,12 +17,6 @@ export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian include("deprecated.jl") -function make_logdensity(model::DynamicPPL.Model) - weight = 1.0 - ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) -end - """ q_initialize_scale( [rng::Random.AbstractRNG,] @@ -68,7 +62,7 @@ function q_initialize_scale( num_max_trials::Int=10, reduce_factor::Real=one(eltype(scale)) / 2, ) - prob = make_logdensity(model) + prob = LogDensityFunction(model) ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) varinfo = DynamicPPL.VarInfo(model) @@ -309,7 +303,7 @@ function vi( ) return AdvancedVI.optimize( rng, - make_logdensity(model), + LogDensityFunction(model), objective, q, n_iterations; diff --git a/test/Project.toml b/test/Project.toml index 46817c6c5..b10be0140 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.36.12" +DynamicPPL = "0.37" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" diff --git a/test/ad.jl b/test/ad.jl index 2f645fab5..dcfe4ef46 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -155,35 +155,33 @@ end # child context. 
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) - return value, logp, vi + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi ) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) check_adtype(context, vi) - return value, logp, vi + return value, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) + left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) check_adtype(context, vi) - return logp, vi + return left, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe( +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) + left, vi = DynamicPPL.tilde_observe!!( DynamicPPL.childcontext(context), sampler, right, left, vi ) check_adtype(context, vi) - return logp, vi + return left, vi end """ @@ -239,32 +237,8 @@ end end end -@testset verbose = true "AD / SamplingContext" begin - # AD tests for gradient-based samplers need to be run with SamplingContext - # because samplers can potentially use this to define custom behaviour in - # the tilde-pipeline and thus change the code executed during model - # evaluation. - @testset "adtype=$adtype" for adtype in ADTYPES - @testset "alg=$alg" for alg in [ - HMC(0.1, 10; adtype=adtype), - HMCDA(0.8, 0.75; adtype=adtype), - NUTS(1000, 0.8; adtype=adtype), - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), - ] - @info "Testing AD for $alg" - - @testset "model=$(model.f)" for model in DEMO_MODELS - rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any - end - end - end -end - @testset verbose = true "AD / GibbsContext" begin - # Gibbs sampling also needs extra AD testing because the models are + # Gibbs sampling needs some extra AD testing because the models are # executed with GibbsContext and a subsetted varinfo. (see e.g. 
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in # src/mcmc/gibbs.jl -- the code here mimics what happens in those @@ -283,8 +257,7 @@ end model, varnames, deepcopy(global_vi) ) rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + @test run_ad(model, adtype; test=true, benchmark=false) isa Any end end end diff --git a/test/essential/container.jl b/test/essential/container.jl index cbd7a6fe2..124637aab 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -19,6 +19,7 @@ using Turing @testset "constructor" begin vi = DynamicPPL.VarInfo() + vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = test() trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) @@ -27,14 +28,11 @@ using Turing @test trace.model.ctask.taped_globals.other === trace res = AdvancedPS.advance!(trace, false) - @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1 @test res ≈ -log(2) # Catch broken copy, espetially for RNG / VarInfo newtrace = AdvancedPS.fork(trace) res2 = AdvancedPS.advance!(trace) - @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 2 - @test DynamicPPL.get_num_produce(newtrace.model.f.varinfo) == 1 end @testset "fork" begin @@ -46,6 +44,7 @@ using Turing return a, b end vi = DynamicPPL.VarInfo() + vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = normal() diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index c3ac571cb..2cc7e4bc0 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -5,7 +5,7 @@ using ..NumericalTests: check_gdemo, check_numerical using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample import DynamicPPL -using DynamicPPL: Sampler, getlogp +using DynamicPPL: Sampler import ForwardDiff using LinearAlgebra: I import MCMCChains @@ -113,44 +113,12 @@ using Turing check_gdemo(chn3_contd) end - @testset "Contexts" begin - # Test LikelihoodContext - @model function testmodel1(x) - a ~ Beta() - lp1 = getlogp(__varinfo__) - x[1] ~ Bernoulli(a) - return global loglike = getlogp(__varinfo__) - lp1 - end - model = testmodel1([1.0]) - varinfo = DynamicPPL.VarInfo(model) - model(varinfo, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - @test getlogp(varinfo) == loglike - - # Test MiniBatchContext - @model function testmodel2(x) - a ~ Beta() - return x[1] ~ Bernoulli(a) - end - model = testmodel2([1.0]) - varinfo1 = DynamicPPL.VarInfo(model) - varinfo2 = deepcopy(varinfo1) - model(varinfo1, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - model( - varinfo2, - DynamicPPL.SampleFromPrior(), - DynamicPPL.MiniBatchContext(DynamicPPL.LikelihoodContext(), 10), - ) - @test isapprox(getlogp(varinfo2) / getlogp(varinfo1), 10) - end - @testset "Prior" begin N = 10_000 - # Note that all chains contain 3 values per sample: 2 variables + log probability @testset "Single-threaded vanilla" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), N) @test chains isa MCMCChains.Chains - @test size(chains) == (N, 3, 1) @test mean(chains, :s) ≈ 3 atol = 0.11 @test mean(chains, :m) ≈ 0 atol = 0.1 end @@ -158,7 +126,6 @@ using Turing @testset "Multi-threaded" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), MCMCThreads(), N, 4) @test chains isa MCMCChains.Chains - @test size(chains) == (N, 3, 4) @test mean(chains, 
:s) ≈ 3 atol = 0.11 @test mean(chains, :m) ≈ 0 atol = 0.1 end @@ -169,25 +136,34 @@ using Turing ) @test chains isa Vector{<:NamedTuple} @test length(chains) == N - @test all(length(x) == 3 for x in chains) @test all(haskey(x, :lp) for x in chains) + @test all(haskey(x, :logprior) for x in chains) + @test all(haskey(x, :loglikelihood) for x in chains) @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end - @testset "#2169" begin - # Not exactly the same as the issue, but similar. - @model function issue2169_model() - if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext - x ~ Normal(0, 1) - else - x ~ Normal(1000, 1) - end + @testset "accumulators are set correctly" begin + # Prior() uses `reevaluate=false` when constructing a + # `Turing.Inference.Transition`, so we had better make sure that it + # does capture colon-eq statements, as we can't rely on the default + # `Transition` constructor to do this for us. + @model function coloneq() + x ~ Normal() + 10.0 ~ Normal(x) + z := 1.0 + return nothing end - - model = issue2169_model() - chain = sample(StableRNG(seed), model, Prior(), 10) - @test all(mean(chain[:x]) .< 5) + chain = sample(coloneq(), Prior(), N) + @test chain isa MCMCChains.Chains + @test all(x -> x == 1.0, chain[:z]) + # And for the same reason we should also make sure that the logp + # components are correctly calculated. + @test isapprox(chain[:logprior], logpdf.(Normal(), chain[:x])) + @test isapprox(chain[:loglikelihood], logpdf.(Normal.(chain[:x]), 10.0)) + @test isapprox(chain[:lp], chain[:logprior] .+ chain[:loglikelihood]) + # And that the outcome is not influenced by the likelihood + @test mean(chain, :x) ≈ 0.0 atol = 0.1 end end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index e918b3a51..1e1be9b45 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -54,7 +54,7 @@ using Turing @testset "gdemo with CSMC + ESS" begin alg = Gibbs(:s => CSMC(15), :m => ESS()) - chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 2000) + chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index 54a62f434..38b9b0660 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -98,7 +98,9 @@ function initialize_nuts(model::DynamicPPL.Model) linked_vi = DynamicPPL.link!!(vi, model) # Create a LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, linked_vi; adtype=Turing.DEFAULT_ADTYPE) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, linked_vi; adtype=Turing.DEFAULT_ADTYPE + ) # Choose parameter dimensionality and initial parameter value D = LogDensityProblems.dimension(f) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc..dc8cd42d0 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -207,8 +207,8 @@ end val ~ Normal(s, 1) 1.0 ~ Normal(s + m, 1) - n := m + 1 - xs = M(undef, n) + n := m + xs = M(undef, 5) for i in eachindex(xs) xs[i] ~ Beta(0.5, 0.5) end @@ -295,7 +295,7 @@ end vi::T end - Turing.Inference.varinfo(state::VarInfoState) = state.vi + Turing.Inference.get_varinfo(state::VarInfoState) = state.vi function Turing.Inference.setparams_varinfo!!( ::DynamicPPL.Model, ::DynamicPPL.Sampler, @@ -312,8 +312,8 @@ end kwargs..., ) spl.alg.non_warmup_init_count += 1 - return Turing.Inference.Transition(nothing, 0.0), - VarInfoState(DynamicPPL.VarInfo(model)) + vi = 
DynamicPPL.VarInfo(model)
+        return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi))
     end
 
     function AbstractMCMC.step_warmup(
@@ -323,30 +323,30 @@ end
         kwargs...,
     )
         spl.alg.warmup_init_count += 1
-        return Turing.Inference.Transition(nothing, 0.0),
-        VarInfoState(DynamicPPL.VarInfo(model))
+        vi = DynamicPPL.VarInfo(model)
+        return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi))
     end
 
     function AbstractMCMC.step(
         ::Random.AbstractRNG,
-        ::DynamicPPL.Model,
+        model::DynamicPPL.Model,
         spl::DynamicPPL.Sampler{<:WarmupCounter},
         s::VarInfoState;
         kwargs...,
     )
         spl.alg.non_warmup_count += 1
-        return Turing.Inference.Transition(nothing, 0.0), s
+        return Turing.Inference.Transition(model, s.vi, nothing), s
     end
 
     function AbstractMCMC.step_warmup(
         ::Random.AbstractRNG,
-        ::DynamicPPL.Model,
+        model::DynamicPPL.Model,
         spl::DynamicPPL.Sampler{<:WarmupCounter},
         s::VarInfoState;
         kwargs...,
    )
         spl.alg.warmup_count += 1
-        return Turing.Inference.Transition(nothing, 0.0), s
+        return Turing.Inference.Transition(model, s.vi, nothing), s
     end
 
     @model f() = x ~ Normal()
@@ -565,40 +565,98 @@ end
         end
     end
 
-    # The below test used to sample incorrectly before
-    # https://github.com/TuringLang/Turing.jl/pull/2328
-    @testset "dynamic model with ESS" begin
-        @model function dynamic_model_for_ess()
-            b ~ Bernoulli()
-            x_length = b ? 1 : 2
-            x = Vector{Float64}(undef, x_length)
-            for i in 1:x_length
-                x[i] ~ Normal(i, 1.0)
+    @testset "PG with variable number of observations" begin
+        # When sampling from a model with Particle Gibbs, it is mandatory for
+        # the number of observations to be the same in all particles, since the
+        # observations trigger particle resampling.
+        #
+        # Up until Turing v0.39, `x ~ dist` statements where `x` was the
+        # responsibility of a different (non-PG) Gibbs subsampler used to not
+        # count as an observation. Instead, the log-probability `logpdf(dist, x)`
+        # would be manually added to the VarInfo's `logp` field and included in the
+        # weighting for the _following_ observation.
+        #
+        # In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!!
+        # which thus triggers resampling. Thus, for example, the following model
+        # does not work any more:
+        #
+        # @model function f()
+        #     a ~ Poisson(2.0)
+        #     x = Vector{Float64}(undef, a)
+        #     for i in eachindex(x)
+        #         x[i] ~ Normal()
+        #     end
+        # end
+        # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
+        #
+        # because the number of observations in each particle depends on the value
+        # of `a`.
+        #
+        # This testset checks ways of working around such a situation.
+
+        function test_dynamic_bernoulli(chain)
+            means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
+            stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
+            for vn in keys(means)
+                @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.15)
+                @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.15)
             end
         end
-        m = dynamic_model_for_ess()
-        chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
-        means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
-        stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
-        for vn in keys(means)
-            @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
-            @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
+        # TODO(DPPL0.37/penelopeysm): decide what to do with these tests
+        @testset "Coalescing multiple observations into one" begin
+            # Instead of observing x[1] and x[2] separately, we lump them into a
+            # single distribution.
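+            # With `x ~ product_distribution(dists)`, the whole vector
+            # contributes a single observe event during the PG sweep for `b`
+            # (where `x` is Gibbs-conditioned and hence treated as an
+            # observation), so every particle sees the same number of
+            # observations whichever branch is taken.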
+ @model function dynamic_bernoulli() + b ~ Bernoulli() + if b + dists = [Normal(1.0)] + else + dists = [Normal(1.0), Normal(2.0)] + end + return x ~ product_distribution(dists) + end + model = dynamic_bernoulli() + # This currently fails because if the global varinfo has `x` with length 2, + # and the particle sampler has `b = true`, it attempts to calculate the + # log-likelihood of a length-2 vector with respect to a length-1 + # distribution. + @test_throws DimensionMismatch chain = sample( + StableRNG(468), + model, + Gibbs(:b => PG(10), :x => ESS()), + 2000; + discard_initial=100, + ) + # test_dynamic_bernoulli(chain) end - end - @testset "dynamic model with dot tilde" begin - @model function dynamic_model_with_dot_tilde( - num_zs=10, (::Type{M})=Vector{Float64} - ) where {M} - z = Vector{Int}(undef, num_zs) - z .~ Poisson(1.0) - num_ms = sum(z) - m = M(undef, num_ms) - return m .~ Normal(1.0, 1.0) - end - model = dynamic_model_with_dot_tilde() - sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), 100) + @testset "Inserting @addlogprob!" begin + # On top of observing x[i], we also add in extra 'observations' + @model function dynamic_bernoulli_2() + b ~ Bernoulli() + x_length = b ? 1 : 2 + x = Vector{Float64}(undef, x_length) + for i in 1:x_length + x[i] ~ Normal(i, 1.0) + end + if length(x) == 1 + # This value is the expectation value of logpdf(Normal(), x) where x ~ Normal(). + # See discussion in + # https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817 + @addlogprob!(-1.418849) + end + end + model = dynamic_bernoulli_2() + chain = sample( + StableRNG(468), + model, + Gibbs(:b => PG(20), :x => ESS()), + 2000; + discard_initial=100, + ) + test_dynamic_bernoulli(chain) + end end @testset "Demo model" begin @@ -828,7 +886,9 @@ end function check_logp_correct(sampler) @testset "logp is set correctly" begin @model logp_check() = x ~ Normal() - chn = sample(logp_check(), Gibbs(@varname(x) => sampler), 100) + chn = sample( + logp_check(), Gibbs(@varname(x) => sampler), 100; progress=false + ) @test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp]) end end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 8832e5fe7..839dffbbe 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -171,23 +171,6 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end - @testset "prior" begin - # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance - # which means that it's _very_ difficult to find a good tolerance in the test below:) - prior_dist = truncated(Normal(3, 1); lower=0) - - @model function demo_hmc_prior() - s ~ prior_dist - return m ~ Normal(0, sqrt(s)) - end - alg = NUTS(1000, 0.8) - gdemo_default_prior = DynamicPPL.contextualize( - demo_hmc_prior(), DynamicPPL.PriorContext() - ) - chain = sample(gdemo_default_prior, alg, 5_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(prior_dist), 0]; atol=0.2) - end - @testset "warning for difficult init params" begin attempt = 0 @model function demo_warn_initial_params() diff --git a/test/mcmc/is.jl b/test/mcmc/is.jl index 44fbe9201..2811e9c86 100644 --- a/test/mcmc/is.jl +++ b/test/mcmc/is.jl @@ -47,11 +47,11 @@ using Turing Random.seed!(seed) chain = sample(model, alg, n; check_model=false) - sampled = get(chain, [:a, :b, :lp]) + sampled = get(chain, [:a, :b, :loglikelihood]) @test vec(sampled.a) == ref.as @test vec(sampled.b) == ref.bs - @test vec(sampled.lp) == ref.logps + @test vec(sampled.loglikelihood) == ref.logps @test chain.logevidence == ref.logevidence end diff 
--git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index add2e7404..70810e164 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -262,24 +262,6 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @test !DynamicPPL.islinked(vi) end - @testset "prior" begin - alg = MH() - gdemo_default_prior = DynamicPPL.contextualize( - gdemo_default, DynamicPPL.PriorContext() - ) - burnin = 10_000 - n = 10_000 - chain = sample( - StableRNG(seed), - gdemo_default_prior, - alg, - n; - discard_initial=burnin, - thinning=10, - ) - check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0]; atol=0.3) - end - @testset "`filldist` proposal (issue #2180)" begin @model demo_filldist_issue2180() = x ~ MvNormal(zeros(3), I) chain = sample( diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 7a2f5fe1c..ad7373b85 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -1,10 +1,11 @@ module ParticleMCMCTests using ..Models: gdemo_default -#using ..Models: MoGtest, MoGtest_default +using ..SamplerTestUtils: test_chain_logp_metadata using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial using Distributions: Bernoulli, Beta, Gamma, Normal, sample using Random: Random +using StableRNGs: StableRNG using Test: @test, @test_throws, @testset using Turing @@ -49,6 +50,10 @@ using Turing @test_throws ErrorException sample(fail_smc(), SMC(), 100) end + @testset "chain log-density metadata" begin + test_chain_logp_metadata(SMC()) + end + @testset "logevidence" begin Random.seed!(100) @@ -65,7 +70,10 @@ using Turing chains_smc = sample(test(), SMC(), 100) @test all(isone, chains_smc[:x]) + # the chain itself has a logevidence field @test chains_smc.logevidence ≈ -2 * log(2) + # but each transition also contains the logevidence + @test chains_smc[:logevidence] ≈ fill(chains_smc.logevidence, 100) end end @@ -88,6 +96,10 @@ end @test s.resampler === resample_systematic end + @testset "chain log-density metadata" begin + test_chain_logp_metadata(PG(10)) + end + @testset "logevidence" begin Random.seed!(100) @@ -105,6 +117,7 @@ end @test all(isone, chains_pg[:x]) @test chains_pg.logevidence ≈ -2 * log(2) atol = 0.01 + @test chains_pg[:logevidence] ≈ fill(chains_pg.logevidence, 100) end # https://github.com/TuringLang/Turing.jl/issues/1598 @@ -114,6 +127,24 @@ end @test length(unique(c[:s])) == 1 end + @testset "addlogprob leads to reweighting" begin + # Make sure that PG takes @addlogprob! into account. It didn't use to: + # https://github.com/TuringLang/Turing.jl/issues/1996 + @model function addlogprob_demo() + x ~ Normal(0, 1) + if x < 0 + @addlogprob! -10.0 + else + # Need a balanced number of addlogprobs in all branches, or + # else PG will error + @addlogprob! 0.0 + end + end + c = sample(StableRNG(468), addlogprob_demo(), PG(10), 100) + # Result should be biased towards x > 0. 
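+        # Back-of-envelope check (not load-bearing for the test): the prior
+        # is x ~ Normal(0, 1) and the extra weight is exp(-10) for x < 0, so
+        # the posterior mass below zero is exp(-10) / (1 + exp(-10)) ≈ 4.5e-5,
+        # and the posterior mean is close to E[x | x > 0] = sqrt(2/pi) ≈ 0.798.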
+ @test mean(c[:x]) > 0.7 + end + # https://github.com/TuringLang/Turing.jl/issues/2007 @testset "keyword arguments not supported" begin @model kwarg_demo(; x=2) = return x diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 1671362ed..ee943270c 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -2,6 +2,7 @@ module SGHMCTests using ..Models: gdemo_default using ..NumericalTests: check_gdemo +using ..SamplerTestUtils: test_chain_logp_metadata using DynamicPPL.TestUtils.AD: run_ad using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL: DynamicPPL @@ -32,6 +33,10 @@ using Turing chain = sample(rng, gdemo_default, alg, 10_000) check_gdemo(chain; atol=0.1) end + + @testset "chain log-density metadata" begin + test_chain_logp_metadata(SGHMC(; learning_rate=0.02, momentum_decay=0.5)) + end end @testset "Testing sgld.jl" begin @@ -46,6 +51,7 @@ end sampler = DynamicPPL.Sampler(alg) @test sampler isa DynamicPPL.Sampler{<:SGLD} end + @testset "sgld inference" begin rng = StableRNG(1) @@ -59,6 +65,10 @@ end @test s_weighted ≈ 49 / 24 atol = 0.2 @test m_weighted ≈ 7 / 6 atol = 0.2 end + + @testset "chain log-density metadata" begin + test_chain_logp_metadata(SGLD(; stepsize=PolynomialStepsize(0.25))) + end end end diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 2acb7edc5..6b93e7629 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -24,28 +24,7 @@ using Turing hasstats(result) = result.optim_result.stats !== nothing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 - @testset "OptimizationContext" begin - # Used for testing how well it works with nested contexts. - struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext - context::C - logprior_weight::T1 - loglikelihood_weight::T2 - end - DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent() - DynamicPPL.childcontext(parent::OverrideContext) = parent.context - DynamicPPL.setchildcontext(parent::OverrideContext, child) = - OverrideContext(child, parent.logprior_weight, parent.loglikelihood_weight) - - # Only implement what we need for the models above. - function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, context.logprior_weight, vi - end - function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return context.loglikelihood_weight, vi - end - + @testset "OptimLogDensity and contexts" begin @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -62,48 +41,36 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + # Doesn't matter if we use getlogjoint or getlogjoint_internal since the + # VarInfo isn't linked. 
+ ld1 = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint_internal) + @test ld1(w) == ld2(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + ld1 = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint_internal) + @test ld1(w) == ld2(w) end - @testset "Weighted" begin - function override(model) - return DynamicPPL.contextualize( - model, OverrideContext(model.context, 100, 1) - ) - end - m1 = override(model1(x)) - m2 = override(model2() | (x=x,)) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) - end - - @testset "Default, Likelihood, Prior Contexts" begin + @testset "Joint, prior, and likelihood" begin m1 = model1(x) - defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext()) a = [0.3] + ld_joint = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld_prior = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior) + ld_likelihood = Turing.Optimisation.OptimLogDensity( + m1, DynamicPPL.getloglikelihood + ) + @test ld_joint(a) == ld_prior(a) + ld_likelihood(a) - @test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) == - Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) + - Turing.Optimisation.OptimLogDensity(m1, prictx)(a) - - # test that PriorContext is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) ≈ + # test that the prior accumulator is calculating the right thing + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -651,8 +618,7 @@ using Turing return nothing end m = saddle_model() - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - optim_ld = Turing.Optimisation.OptimLogDensity(m, ctx) + optim_ld = Turing.Optimisation.OptimLogDensity(m, DynamicPPL.getloglikelihood) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m) diff --git a/test/runtests.jl b/test/runtests.jl index 9fec2f737..5fb6b2141 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ seed!(23) include("test_utils/models.jl") include("test_utils/numerical_tests.jl") +include("test_utils/sampler.jl") Turing.setprogress!(false) included_paths, excluded_paths = parse_args(ARGS) diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl new file mode 100644 index 000000000..32a3647f9 --- /dev/null +++ b/test/test_utils/sampler.jl @@ -0,0 +1,27 @@ +module SamplerTestUtils + +using Turing +using Test + +""" +Check that when sampling with `spl`, the resulting chain contains log-density +metadata that is correct. 
+""" +function test_chain_logp_metadata(spl) + @model function f() + # some prior term (but importantly, one that is constrained, i.e., can + # be linked with non-identity transform) + x ~ LogNormal() + # some likelihood term + return 1.0 ~ Normal(x) + end + chn = sample(f(), spl, 100) + # Check that the log-prior term is calculated in unlinked space. + @test chn[:logprior] ≈ logpdf.(LogNormal(), chn[:x]) + @test chn[:loglikelihood] ≈ logpdf.(Normal.(chn[:x]), 1.0) + # This should always be true, but it also indirectly checks that the + # log-joint is also calculated in unlinked space. + @test chn[:lp] ≈ chn[:logprior] + chn[:loglikelihood] +end + +end