diff --git a/HISTORY.md b/HISTORY.md
index 428766b79..d11a41681 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,3 +1,12 @@
+# 0.40.0
+
+TODO
+
+ - DynamicPPL 0.37 stuff
+
+ - pMCMC and Gibbs?
+ - Prior is faster
+
 # 0.39.9
 
 Revert a bug introduced in 0.39.5 in the external sampler interface.
diff --git a/Project.toml b/Project.toml
index 7e952e7c0..05c5f6380 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.39.9"
+version = "0.40.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.36.3"
+DynamicPPL = "0.37"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
 [extras]
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+
+[sources]
+DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"}
diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl
index 5718e3855..2c4bd0898 100644
--- a/ext/TuringDynamicHMCExt.jl
+++ b/ext/TuringDynamicHMCExt.jl
@@ -58,15 +58,12 @@ function DynamicPPL.initialstep(
     # Ensure that initial sample is in unconstrained space.
     if !DynamicPPL.islinked(vi)
         vi = DynamicPPL.link!!(vi, model)
-        vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
+        vi = last(DynamicPPL.evaluate!!(model, vi))
     end
 
     # Define log-density function.
     ℓ = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-        adtype=spl.alg.adtype,
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
     )
 
     # Perform initial step.
@@ -76,12 +73,9 @@ function DynamicPPL.initialstep(
     steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
     Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
 
-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
-
     # Create first sample and state.
-    sample = Turing.Inference.Transition(model, vi)
+    vi = DynamicPPL.unflatten(vi, Q.q)
+    sample = Turing.Inference.Transition(model, vi, nothing)
     state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
 
     return sample, state
@@ -100,12 +94,9 @@ function AbstractMCMC.step(
     steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
     Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
 
-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
-
     # Create next sample and state.
-    sample = Turing.Inference.Transition(model, vi)
+    vi = DynamicPPL.unflatten(vi, Q.q)
+    sample = Turing.Inference.Transition(model, vi, nothing)
     newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
 
     return sample, newstate
diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl
index ad8fdad44..0f755988e 100644
--- a/ext/TuringOptimExt.jl
+++ b/ext/TuringOptimExt.jl
@@ -34,8 +34,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
     init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
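[Editor's note] The hunk above and the ones that follow all make the same API change: `OptimLogDensity` now takes a log-density getter function instead of a wrapped `OptimizationContext`, with `getloglikelihood` selecting the MLE objective and `getlogjoint` the MAP objective. A minimal sketch of the user-facing behaviour this preserves (the model is hypothetical, and the internals shown in comments are as per this patch):

```julia
using Turing, Optim

@model function gdemo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = gdemo(1.5)
# MLE maximises only the likelihood; MAP maximises prior + likelihood.
mle_est = Optim.optimize(model, MLE())  # internally: OptimLogDensity(model, DynamicPPL.getloglikelihood)
map_est = Optim.optimize(model, MAP())  # internally: OptimLogDensity(model, DynamicPPL.getlogjoint)
```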
@@ -57,8 +56,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
     init_vals = DynamicPPL.getparams(f.ldf)
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -74,8 +72,8 @@ function Optim.optimize(
 end
 
 function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
+    return _optimize(f, args...; kwargs...)
 end
 
 """
@@ -104,8 +102,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
     init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +124,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
     init_vals = DynamicPPL.getparams(f.ldf)
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -144,9 +140,10 @@ function Optim.optimize(
 end
 
 function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
+    return _optimize(f, args...; kwargs...)
 end
+
 """
     _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
 
@@ -166,7 +163,9 @@ function _optimize(
     # whether initialisation is really necessary at all
     vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
     vi = DynamicPPL.link(vi, f.ldf.model)
-    f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
+    f = Optimisation.OptimLogDensity(
+        f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype
+    )
     init_vals = DynamicPPL.getparams(f.ldf)
 
     # Optimize!
@@ -184,7 +183,7 @@ function _optimize(
     vi = f.ldf.varinfo
     vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
     logdensity_optimum = Optimisation.OptimLogDensity(
-        f.ldf.model, vi_optimum, f.ldf.context
+        f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
     )
     vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
     iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
diff --git a/src/Turing.jl b/src/Turing.jl
index 1ff231017..0cdbe2458 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -71,7 +71,6 @@ using DynamicPPL:
     unfix,
     prefix,
     conditioned,
-    @submodel,
     to_submodel,
     LogDensityFunction,
     @addlogprob!
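[Editor's note] The import hunk above and the export hunk below together drop the deprecated `@submodel` macro; `to_submodel` (already imported) is the replacement. A hedged migration sketch (hypothetical model names):

```julia
using Turing

@model inner() = x ~ Normal(0, 1)

@model function outer()
    # Previously: @submodel a = inner()
    # Now: `a` is bound to inner()'s return value, and the submodel's
    # variables are prefixed, e.g. `a.x`.
    a ~ to_submodel(inner())
end
```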
@@ -81,7 +80,6 @@ using OrderedCollections: OrderedDict
 # Turing essentials - modelling macros and inference algorithms
 export
     # DEPRECATED
-    @submodel,
     generated_quantities,
     # Modelling - AbstractPPL and DynamicPPL
     @model,
diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index 0951026aa..07c8311b4 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -17,7 +17,8 @@ using DynamicPPL:
     setindex!!,
     push!!,
     setlogp!!,
-    getlogp,
+    getlogjoint,
+    getlogjoint_internal,
     VarName,
     getsym,
     getdist,
@@ -26,9 +27,6 @@ using DynamicPPL:
     SampleFromPrior,
     SampleFromUniform,
     DefaultContext,
-    PriorContext,
-    LikelihoodContext,
-    SamplingContext,
     set_flag!,
     unset_flag!
 using Distributions, Libtask, Bijectors
@@ -125,37 +123,111 @@ end
 ######################
 # Default Transition #
 ######################
-# Default
-getstats(t) = nothing
+getstats(::Any) = NamedTuple()
 
+# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
+# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
 abstract type AbstractTransition end
 
-struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition
+struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
     θ::T
-    lp::F # TODO: merge `lp` with `stat`
-    stat::S
-end
+    logprior::F
+    loglikelihood::F
+    stat::N
+
+    """
+        Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)
+
+    Construct a new `Turing.Inference.Transition` object using the outputs of a
+    sampler step.
+
+    Here, `vi` represents a VarInfo _for which the appropriate parameters have
+    already been set_. However, the accumulators (e.g. logp) may in general
+    have junk contents. The role of this method is to re-evaluate `model` and
+    thus set the accumulators to the correct values.
+
+    `sampler_transition` is the transition object returned by the sampler
+    itself and is only used to extract statistics of interest.
+
+    By default, the model is re-evaluated in order to obtain values of:
+    - the values of the parameters as per user parameterisation (`vals_as_in_model`)
+    - the various components of the log joint probability (`logprior`, `loglikelihood`)
+      that are guaranteed to be correct.
+
+    If you **know** for a fact that the VarInfo `vi` already contains this information,
+    then you can set `reevaluate=false` to skip the re-evaluation step.
+
+    !!! warning
+        Note that in general this is unsafe and may lead to wrong results.
+
+        If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
+        the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
+        and `LogLikelihoodAccumulator` set up with the correct values. Note that the
+        `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
+        must be set up to track `x := y` statements.
+ """ + function Transition( + model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true + ) + if reevaluate + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + ), + ) + _, vi = DynamicPPL.evaluate!!(model, vi) + end -Transition(θ, lp) = Transition(θ, lp, nothing) -function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t) - θ = getparams(model, vi) - lp = getlogp(vi) - return Transition(θ, lp, getstats(t)) -end + # Extract all the information we need + vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + logprior = DynamicPPL.getlogprior(vi) + loglikelihood = DynamicPPL.getloglikelihood(vi) -function metadata(t::Transition) - stat = t.stat - if stat === nothing - return (lp=t.lp,) - else - return merge((lp=t.lp,), stat) + # Get additional statistics + stats = getstats(sampler_transition) + return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}( + vals_as_in_model, logprior, loglikelihood, stats + ) end -end -DynamicPPL.getlogp(t::Transition) = t.lp + function Transition( + model::DynamicPPL.Model, + untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}, + sampler_transition; + reevaluate=true, + ) + # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's + # much faster to convert it to a typed varinfo first, hence this method. + # https://github.com/TuringLang/Turing.jl/issues/2604 + return Transition( + model, + DynamicPPL.typed_varinfo(untyped_vi), + sampler_transition; + reevaluate=reevaluate, + ) + end +end -# Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) +function getstats_with_lp(t::Transition) + return merge( + t.stat, + ( + lp=t.logprior + t.loglikelihood, + logprior=t.logprior, + loglikelihood=t.loglikelihood, + ), + ) +end +function getstats_with_lp(vi::AbstractVarInfo) + return ( + lp=DynamicPPL.getlogjoint(vi), + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + ) +end ########################## # Chain making utilities # @@ -166,7 +238,7 @@ metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) # => value maps) from AbstractMCMC.getparams (defined for any sampler transition, # returns vector). """ - Turing.Inference.getparams(model::Any, t::Any) + Turing.Inference.getparams(model::DynamicPPL.Model, t::Any) Return a vector of parameter values from the given sampler transition `t` (i.e., the first return value of AbstractMCMC.step). By default, returns the `t.θ` field. @@ -175,35 +247,16 @@ the first return value of AbstractMCMC.step). By default, returns the `t.θ` fie This method only needs to be implemented for external samplers. It will be removed in future releases and replaced with `AbstractMCMC.getparams`. """ -getparams(model, t) = t.θ +getparams(::DynamicPPL.Model, t) = t.θ """ Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo) Return a key-value map of parameters from the varinfo. """ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo) - # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used. - # Unfortunately, using `invlink` can cause issues in scenarios where the constraints - # of the parameters change depending on the realizations. Hence we have to use - # `values_as_in_model`, which re-runs the model and extracts the parameters - # as they are seen in the model, i.e. in the constrained space. 
-    # this means that the code below will work both of linked and invlinked `vi`.
-    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
-    # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-    return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
+    t = Transition(model, vi, nothing)
+    return getparams(model, t)
 end
-function getparams(
-    model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
-)
-    # values_as_in_model is unconscionably slow for untyped VarInfo. It's
-    # much faster to convert it to a typed varinfo before calling getparams.
-    # https://github.com/TuringLang/Turing.jl/issues/2604
-    return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
-end
-function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
-    return Dict{VarName,Any}()
-end
-
 function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     names_set = OrderedSet{VarName}()
     # Extract the parameter names and values from each transition.
@@ -219,7 +272,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
         iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
         mapreduce(collect, vcat, iters)
     end
-
     nms = map(first, nms_and_vs)
     vs = map(last, nms_and_vs)
     for nm in nms
@@ -234,14 +286,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     return names, vals
 end
 
-function get_transition_extras(ts::AbstractVector{<:VarInfo})
-    valmat = reshape([getlogp(t) for t in ts], :, 1)
-    return [:lp], valmat
-end
-
 function get_transition_extras(ts::AbstractVector)
-    # Extract all metadata.
-    extra_data = map(metadata, ts)
+    # Extract stats + log probabilities from each transition or VarInfo
+    extra_data = map(getstats_with_lp, ts)
     return names_values(extra_data)
 end
@@ -350,7 +397,7 @@ function AbstractMCMC.bundle_samples(
         vals = map(values(sym_to_vns)) do vns
             map(Base.Fix1(getindex, params), vns)
         end
-        return merge(NamedTuple(zip(keys(sym_to_vns), vals)), metadata(t))
+        return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t))
     end
 end
@@ -412,84 +459,4 @@ function DynamicPPL.get_matching_type(
     return Array{T,N}
 end
 
-##############
-# Utilities  #
-##############
-
-"""
-
-    transitions_from_chain(
-        [rng::AbstractRNG,]
-        model::Model,
-        chain::MCMCChains.Chains;
-        sampler = DynamicPPL.SampleFromPrior()
-    )
-
-Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
-
-The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
-
-# Details
-
-In a bit more detail, the process is as follows:
-1. For every `sample` in `chain`
-    1. For every `variable` in `sample`
-        1. Set `variable` in `model` to its value in `sample`
-    2. Execute `model` with variables fixed as above, sampling variables NOT present
-       in `chain` using `SampleFromPrior`
-    3. Return sampled variables and log-joint
-
-# Example
-```julia-repl
-julia> using Turing
-
-julia> @model function demo()
-           m ~ Normal(0, 1)
-           x ~ Normal(m, 1)
-       end;
-
-julia> m = demo();
-
-julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
-
-julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
-
-julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
-2-element Array{Float64,1}:
- -3.6294991938628374
- -2.5697948166987845
-
-julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
-2-element Array{Array{Float64,1},1}:
- [-2.0844148956440796]
- [-1.704630494695469]
-```
-"""
-function transitions_from_chain(
-    model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
-)
-    return transitions_from_chain(Random.default_rng(), model, chain; kwargs...)
-end
-
-function transitions_from_chain(
-    rng::Random.AbstractRNG,
-    model::DynamicPPL.Model,
-    chain::MCMCChains.Chains;
-    sampler=DynamicPPL.SampleFromPrior(),
-)
-    vi = Turing.VarInfo(model)
-
-    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
-    transitions = map(iters) do (sample_idx, chain_idx)
-        # Set variables present in `chain` and mark those NOT present in chain to be resampled.
-        DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
-        model(rng, vi, sampler)
-
-        # Convert `VarInfo` into `NamedTuple` and save.
-        Transition(model, vi)
-    end
-
-    return transitions
-end
-
 end # module
diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl
index fd4d441bd..4522875b4 100644
--- a/src/mcmc/abstractmcmc.jl
+++ b/src/mcmc/abstractmcmc.jl
@@ -1,7 +1,9 @@
 # TODO: Implement additional checks for certain samplers, e.g.
 # HMC not supporting discrete parameters.
 function _check_model(model::DynamicPPL.Model)
-    return DynamicPPL.check_model(model; error_on_failure=true)
+    # TODO(DPPL0.38/penelopeysm): use InitContext
+    spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context))
+    return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true)
 end
 function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
     return _check_model(model)
diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl
index dfd1fc0d3..98ed20b40 100644
--- a/src/mcmc/emcee.jl
+++ b/src/mcmc/emcee.jl
@@ -53,22 +53,26 @@ function AbstractMCMC.step(
         length(initial_params) == n ||
             throw(ArgumentError("initial parameters have to be specified for each walker"))
         vis = map(vis, initial_params) do vi, init
+            # TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!!
            vi = DynamicPPL.initialize_parameters!!(vi, init, model)
 
            # Update log joint probability.
-            last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))
+            spl_model = DynamicPPL.contextualize(
+                model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context)
+            )
+            last(DynamicPPL.evaluate!!(spl_model, vi))
        end
     end
 
     # Compute initial transition and states.
-    transition = map(Base.Fix1(Transition, model), vis)
+    transition = [Transition(model, vi, nothing) for vi in vis]
 
     # TODO: Make compatible with immutable `AbstractVarInfo`.
     state = EmceeState(
         vis[1],
         map(vis) do vi
             vi = DynamicPPL.link!!(vi, model)
-            AMH.Transition(vi[:], getlogp(vi), false)
+            AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false)
         end,
     )
 
@@ -81,17 +85,19 @@ function AbstractMCMC.step(
     # Generate a log joint function.
     vi = state.vi
     densitymodel = AMH.DensityModel(
-        Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(model, vi))
+        Base.Fix1(
+            LogDensityProblems.logdensity,
+            DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
+        ),
     )
 
     # Compute the next states.
-    states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
+    t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)
 
     # Compute the next transition and state.
     transition = map(states) do _state
         vi = DynamicPPL.unflatten(vi, _state.params)
-        t = Transition(getparams(model, vi), _state.lp)
-        return t
+        return Transition(model, vi, t)
     end
     newstate = EmceeState(vi, states)
diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl
index 544817348..3afd91607 100644
--- a/src/mcmc/ess.jl
+++ b/src/mcmc/ess.jl
@@ -24,56 +24,48 @@ struct ESS <: InferenceAlgorithm end
 
 # always accept in the first step
 function DynamicPPL.initialstep(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
+    rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
 )
     for vn in keys(vi)
         dist = getdist(vi, vn)
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("ESS only supports Gaussian prior distributions")
     end
-    return Transition(model, vi), vi
+    return Transition(model, vi, nothing), vi
 end
 
 function AbstractMCMC.step(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
+    rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
 )
     # obtain previous sample
     f = vi[:]
 
     # define previous sampler state
     # (do not use cache to avoid in-place sampling from prior)
-    oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)
+    oldstate = EllipticalSliceSampling.ESSState(f, DynamicPPL.getloglikelihood(vi), nothing)
 
     # compute next state
     sample, state = AbstractMCMC.step(
         rng,
-        EllipticalSliceSampling.ESSModel(
-            ESSPrior(model, spl, vi),
-            DynamicPPL.LogDensityFunction(
-                model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
-            ),
-        ),
+        EllipticalSliceSampling.ESSModel(ESSPrior(model, vi), ESSLikelihood(model, vi)),
         EllipticalSliceSampling.ESS(),
         oldstate,
     )
 
     # update sample and log-likelihood
     vi = DynamicPPL.unflatten(vi, sample)
-    vi = setlogp!!(vi, state.loglikelihood)
+    vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
 
-    return Transition(model, vi), vi
+    return Transition(model, vi, nothing), vi
 end
 
 # Prior distribution of considered random variable
-struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
+struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
     model::M
-    sampler::S
     varinfo::V
     μ::T
 
-    function ESSPrior{M,S,V}(
-        model::M, sampler::S, varinfo::V
-    ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
+    function ESSPrior(model::Model, varinfo::AbstractVarInfo)
         vns = keys(varinfo)
         μ = mapreduce(vcat, vns) do vn
             dist = getdist(varinfo, vn)
@@ -81,47 +73,48 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
             EllipticalSliceSampling.isgaussian(typeof(dist)) ||
                 error("[ESS] only supports Gaussian prior distributions")
             DynamicPPL.tovec(mean(dist))
         end
-        return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
+        return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ)
     end
 end
 
-function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
-    return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
-end
-
 # Ensure that the prior is a Gaussian distribution (checked in the constructor)
 EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
 
 # Only define out-of-place sampling
 function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
-    sampler = p.sampler
     varinfo = p.varinfo
-    # TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
+    # TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model,
+    # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
+    # why we had to use the 'del' flag before this was because
+    # SampleFromPrior() wouldn't overwrite existing variables.
+    # The main problem I'm rather unsure about is ESS-within-Gibbs. The
+    # current implementation I think makes sure to only resample the variables
+    # that 'belong' to the current ESS sampler. InitContext on the other hand
+    # would resample all variables in the model (??) Need to think about this
+    # carefully.
     vns = keys(varinfo)
     for vn in vns
         set_flag!(varinfo, vn, "del")
     end
-    p.model(rng, varinfo, sampler)
+    p.model(rng, varinfo)
     return varinfo[:]
 end
 
 # Mean of prior distribution
 Distributions.mean(p::ESSPrior) = p.μ
 
-# Evaluate log-likelihood of proposals
-const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
-    DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
+# Evaluate log-likelihood of proposals. We need this struct because
+# EllipticalSliceSampling.jl expects a callable struct / a function as its
+# likelihood.
+struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction}
+    ldf::L
 
-(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
-
-function DynamicPPL.tilde_assume(
-    rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
-)
-    return DynamicPPL.tilde_assume(
-        rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
-    )
+    # Force usage of `getloglikelihood` in inner constructor
+    function ESSLikelihood(model::Model, varinfo::AbstractVarInfo)
+        ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo)
+        return new{typeof(ldf)}(ldf)
+    end
 end
 
-function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
-    return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
-end
+(ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f)
diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl
index 1e0481b50..af31e0243 100644
--- a/src/mcmc/external_sampler.jl
+++ b/src/mcmc/external_sampler.jl
@@ -26,8 +26,6 @@ There are a few more optional functions which you can implement to improve the i
 - `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
 
 - `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.
-
-- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method.
""" struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <: InferenceAlgorithm @@ -89,27 +87,21 @@ function externalsampler( return ExternalSampler(sampler, adtype, Val(unconstrained)) end -""" - getlogp_external(external_transition, external_state) - -Get the log probability density associated with the external sampler's -transition and state. Returns `missing` by default; in this case, an extra -model evaluation will be needed to calculate the correct log density. -""" -getlogp_external(::Any, ::Any) = missing -getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp -getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density - -struct TuringState{S,V1<:AbstractVarInfo,M,V,C} +# TODO(penelopeysm): Can't we clean this up somehow? +struct TuringState{S,V1,M,V} state::S - # Note that this varinfo has the correct parameters and logp obtained from - # the state, whereas `ldf.varinfo` will in general have junk inside it. + # Note that this varinfo must have the correct parameters set; but logp + # does not matter as it will be re-evaluated varinfo::V1 - ldf::DynamicPPL.LogDensityFunction{M,V,C} + # Note that in general the VarInfo inside this LogDensityFunction will have + # junk parameters and logp. It only exists to provide structure + ldf::DynamicPPL.LogDensityFunction{M,V} end -varinfo(state::TuringState) = state.varinfo -varinfo(state::AbstractVarInfo) = state +# get_varinfo should return something from which the correct parameters can be +# obtained, hence we use state.varinfo rather than state.ldf.varinfo +get_varinfo(state::TuringState) = state.varinfo +get_varinfo(state::AbstractVarInfo) = state getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState) @@ -119,24 +111,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params -function make_updated_varinfo( - f::DynamicPPL.LogDensityFunction, external_transition, external_state -) - # Set the parameters. - # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!) - # The latter uses the state rather than the transition. - # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead - new_parameters = getparams(f.model, external_transition) - new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters) - # Set (or recalculate, if needed) the log density. - new_logp = getlogp_external(external_transition, external_state) - return if ismissing(new_logp) - last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context)) - else - DynamicPPL.setlogp!!(new_varinfo, new_logp) - end -end - # TODO: Do we also support `resume`, etc? function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -163,7 +137,9 @@ function AbstractMCMC.step( end # Construct LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype + ) # Then just call `AbstractMCMC.step` with the right arguments. if initial_state === nothing @@ -181,13 +157,13 @@ function AbstractMCMC.step( ) end - # Get the parameters and log density, and set them in the varinfo. - new_varinfo = make_updated_varinfo(f, transition_inner, state_inner) - - # Update the `state` + # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!) 
+    # The latter uses the state rather than the transition.
+    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
+    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     return (
-        Transition(f.model, new_varinfo, transition_inner),
-        TuringState(state_inner, new_varinfo, f),
+        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
     )
 end
 
@@ -206,12 +182,12 @@ function AbstractMCMC.step(
         rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
     )
 
-    # Get the parameters and log density, and set them in the varinfo.
-    new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
-
-    # Update the `state`
+    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
+    # The latter uses the state rather than the transition.
+    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
+    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     return (
-        Transition(f.model, new_varinfo, transition_inner),
-        TuringState(state_inner, new_varinfo, f),
+        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
     )
 end
diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl
index f36cb9c36..692748767 100644
--- a/src/mcmc/gibbs.jl
+++ b/src/mcmc/gibbs.jl
@@ -33,7 +33,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
 #
 # Purpose: avoid triggering resampling of variables we're conditioning on.
 # - Using standard `DynamicPPL.condition` results in conditioned variables being treated
-#   as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
+#   as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`.
 # - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
 #   undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
 #   rather than only for the "true" observations.
@@ -177,17 +177,21 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
         # Fall back to the default behavior.
         DynamicPPL.tilde_assume(child_context, right, vn, vi)
     elseif has_conditioned_gibbs(context, vn)
-        # Short-circuit the tilde assume if `vn` is present in `context`.
-        value, lp, _ = DynamicPPL.tilde_assume(
-            child_context, right, vn, get_global_varinfo(context)
-        )
-        value, lp, vi
+        # This branch means that a different sampler is supposed to handle this
+        # variable. From the perspective of this sampler, this variable is
+        # conditioned on, so we can just treat it as an observation.
+        # The only catch is that the value we need has to be obtained from
+        # the global VarInfo (since the local VarInfo has no knowledge of it).
+        # Note that tilde_observe!! will trigger resampling in particle methods
+        # for variables that are handled by other Gibbs component samplers.
+        val = get_conditioned_gibbs(context, vn)
+        DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
     else
         # If the varname has not been conditioned on, nor is it a target variable, its
         # presumably a new variable that should be sampled from its prior. We need to add
         # this new variable to the global `varinfo` of the context, but not to the local one
         # being used by the current sampler.
-        value, lp, new_global_vi = DynamicPPL.tilde_assume(
+        value, new_global_vi = DynamicPPL.tilde_assume(
             child_context,
             DynamicPPL.SampleFromPrior(),
             right,
             vn,
             get_global_varinfo(context),
         )
         set_global_varinfo!(context, new_global_vi)
-        value, lp, vi
+        value, vi
     end
 end
@@ -208,14 +212,26 @@ function DynamicPPL.tilde_assume(
     vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
 
     return if is_target_varname(context, vn)
+        # This branch means that the `sampler` is supposed to handle
+        # this variable. We can thus use its default behaviour, with
+        # the 'local' sampler-specific VarInfo.
         DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
     elseif has_conditioned_gibbs(context, vn)
-        value, lp, _ = DynamicPPL.tilde_assume(
-            child_context, right, vn, get_global_varinfo(context)
-        )
-        value, lp, vi
+        # This branch means that a different sampler is supposed to handle this
+        # variable. From the perspective of this sampler, this variable is
+        # conditioned on, so we can just treat it as an observation.
+        # The only catch is that the value we need has to be obtained from
+        # the global VarInfo (since the local VarInfo has no knowledge of it).
+        # Note that tilde_observe!! will trigger resampling in particle methods
+        # for variables that are handled by other Gibbs component samplers.
+        val = get_conditioned_gibbs(context, vn)
+        DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
     else
-        value, lp, new_global_vi = DynamicPPL.tilde_assume(
+        # If the varname has not been conditioned on, nor is it a target variable, it's
+        # presumably a new variable that should be sampled from its prior. We need to add
+        # this new variable to the global `varinfo` of the context, but not to the local one
+        # being used by the current sampler.
+        value, new_global_vi = DynamicPPL.tilde_assume(
             rng,
             child_context,
             DynamicPPL.SampleFromPrior(),
@@ -224,7 +240,7 @@ function DynamicPPL.tilde_assume(
             get_global_varinfo(context),
         )
         set_global_varinfo!(context, new_global_vi)
-        value, lp, vi
+        value, vi
     end
 end
 
@@ -327,7 +343,7 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
     states::S
 end
 
-varinfo(state::GibbsState) = state.vi
+get_varinfo(state::GibbsState) = state.vi
 
 """
 Initialise a VarInfo for the Gibbs sampler.
@@ -347,7 +363,7 @@ function initial_varinfo(rng, model, spl, initial_params)
         # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
         # and https://github.com/TuringLang/Turing.jl/issues/1563
         # to avoid that existing variables are resampled
-        vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext()))
+        vi = last(DynamicPPL.evaluate!!(model, vi))
     end
     return vi
 end
@@ -374,7 +390,7 @@ function AbstractMCMC.step(
         initial_params=initial_params,
         kwargs...,
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, nothing), GibbsState(vi, states)
 end
 
 function AbstractMCMC.step_warmup(
@@ -399,7 +415,7 @@ function AbstractMCMC.step_warmup(
         initial_params=initial_params,
         kwargs...,
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, nothing), GibbsState(vi, states)
 end
 
 """
@@ -449,7 +465,7 @@ function gibbs_initialstep_recursive(
         initial_params=initial_params_local,
         kwargs...,
     )
-    new_vi_local = varinfo(new_state)
+    new_vi_local = get_varinfo(new_state)
 
     # Merge in any new variables that were introduced during the step, but that
     # were not in the domain of the current sampler.
     vi = merge(vi, get_global_varinfo(context))
@@ -477,7 +493,7 @@ function AbstractMCMC.step(
     state::GibbsState;
     kwargs...,
 )
-    vi = varinfo(state)
+    vi = get_varinfo(state)
     alg = spl.alg
     varnames = alg.varnames
     samplers = alg.samplers
@@ -487,7 +503,7 @@ function AbstractMCMC.step(
     vi, states = gibbs_step_recursive(
         rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, nothing), GibbsState(vi, states)
 end
 
 function AbstractMCMC.step_warmup(
@@ -497,7 +513,7 @@ function AbstractMCMC.step_warmup(
     state::GibbsState;
     kwargs...,
 )
-    vi = varinfo(state)
+    vi = get_varinfo(state)
     alg = spl.alg
     varnames = alg.varnames
     samplers = alg.samplers
@@ -507,7 +523,7 @@ function AbstractMCMC.step_warmup(
     vi, states = gibbs_step_recursive(
         rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, nothing), GibbsState(vi, states)
 end
 
 """
@@ -525,16 +541,11 @@ function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo)
 end
 
 function setparams_varinfo!!(
-    model::DynamicPPL.Model,
-    sampler::Sampler{<:MH},
-    state::AbstractVarInfo,
-    params::AbstractVarInfo,
+    model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::MHState, params::AbstractVarInfo
 )
-    # The state is already a VarInfo, so we can just return `params`, but first we need to
-    # update its logprob.
-    # NOTE: Using `leafcontext(model.context)` here is a no-op, as it will be concatenated
-    # with `model.context` before hitting `model.f`.
-    return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context)))
+    # Re-evaluate to update the logprob.
+    new_vi = last(DynamicPPL.evaluate!!(model, params))
+    return MHState(new_vi, DynamicPPL.getlogjoint_internal(new_vi))
 end
 
 function setparams_varinfo!!(
@@ -544,10 +555,8 @@ function setparams_varinfo!!(
     params::AbstractVarInfo,
 )
     # The state is already a VarInfo, so we can just return `params`, but first we need to
-    # update its logprob. To do this, we have to call evaluate!! with the sampler, rather
-    # than just a context, because ESS is peculiar in how it uses LikelihoodContext for
-    # some variables and DefaultContext for others.
-    return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler)))
+    # update its logprob.
+    return last(DynamicPPL.evaluate!!(model, params))
 end
 
 function setparams_varinfo!!(
@@ -557,7 +566,7 @@ function setparams_varinfo!!(
     params::AbstractVarInfo,
 )
     logdensity = DynamicPPL.LogDensityFunction(
-        model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype
+        model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.alg.adtype
     )
     new_inner_state = setparams_varinfo!!(
         AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
@@ -596,7 +605,7 @@ state for this sampler. This is relevant when multilple samplers are sampling th
 """
 function match_linking!!(varinfo_local, prev_state_local, model)
-    prev_varinfo_local = varinfo(prev_state_local)
+    prev_varinfo_local = get_varinfo(prev_state_local)
     was_linked = DynamicPPL.istrans(prev_varinfo_local)
     is_linked = DynamicPPL.istrans(varinfo_local)
     if was_linked && !is_linked
@@ -678,10 +687,10 @@ function gibbs_step_recursive(
     # Take a step with the local sampler.
     new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...))
-    new_vi_local = varinfo(new_state)
+    new_vi_local = get_varinfo(new_state)
     # Merge the latest values for all the variables in the current sampler.
     new_global_vi = merge(get_global_varinfo(context), new_vi_local)
-    new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
+    new_global_vi = DynamicPPL.setlogp!!(new_global_vi, DynamicPPL.getlogp(new_vi_local))
 
     new_states = (new_states..., new_state)
     return gibbs_step_recursive(
diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl
index b5f51587b..d80502f7e 100644
--- a/src/mcmc/hmc.jl
+++ b/src/mcmc/hmc.jl
@@ -25,7 +25,7 @@ end
 ### Hamiltonian Monte Carlo samplers.
 ###
 
-varinfo(state::HMCState) = state.vi
+get_varinfo(state::HMCState) = state.vi
 
 """
     HMC(ϵ::Float64, n_leapfrog::Int; adtype::ADTypes.AbstractADType = AutoForwardDiff())
@@ -162,7 +162,9 @@ function find_initial_params(
         # Resample and try again.
         # NOTE: varinfo has to be linked to make sure this samples in unconstrained space
         varinfo = last(
-            DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform())
+            DynamicPPL.evaluate_and_sample!!(
+                rng, model, varinfo, DynamicPPL.SampleFromUniform()
+            ),
         )
     end
 
@@ -191,14 +193,7 @@ function DynamicPPL.initialstep(
     metricT = getmetricT(spl.alg)
     metric = metricT(length(theta))
     ldf = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
-        # need to pass in the sampler? (In fact LogDensityFunction defaults to
-        # using leafcontext(model.context) so could we just remove the argument
-        # entirely?)
-        DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context));
-        adtype=spl.alg.adtype,
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
     )
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
     lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -213,9 +208,6 @@ function DynamicPPL.initialstep(
     end
     theta = vi[:]
 
-    # Cache current log density.
-    log_density_old = getlogp(vi)
-
     # Find good eps if not provided one
     if iszero(spl.alg.ϵ)
         ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta)
@@ -239,14 +231,13 @@ function DynamicPPL.initialstep(
         )
     end
 
-    # Update `vi` based on acceptance
-    if t.stat.is_accept
-        vi = DynamicPPL.unflatten(vi, t.z.θ)
-        vi = setlogp!!(vi, t.stat.log_density)
+    # Update VarInfo parameters based on acceptance
+    new_params = if t.stat.is_accept
+        t.z.θ
     else
-        vi = DynamicPPL.unflatten(vi, theta)
-        vi = setlogp!!(vi, log_density_old)
+        theta
     end
+    vi = DynamicPPL.unflatten(vi, new_params)
 
     transition = Transition(model, vi, t)
     state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
@@ -290,7 +281,6 @@ function AbstractMCMC.step(
     vi = state.vi
     if t.stat.is_accept
         vi = DynamicPPL.unflatten(vi, t.z.θ)
-        vi = setlogp!!(vi, t.stat.log_density)
     end
 
     # Compute next transition and state.
@@ -303,14 +293,7 @@ end
 function get_hamiltonian(model, spl, vi, state, n)
     metric = gen_metric(n, spl, state)
     ldf = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
-        # need to pass in the sampler? (In fact LogDensityFunction defaults to
-        # using leafcontext(model.context) so could we just remove the argument
-        # entirely?)
-        DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context));
-        adtype=spl.alg.adtype,
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
     )
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
     lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -516,10 +499,6 @@ function DynamicPPL.assume(
     return DynamicPPL.assume(dist, vn, vi)
 end
 
-function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
-    return DynamicPPL.observe(d, value, vi)
-end
-
 ####
 #### Default HMC stepsize and mass matrix adaptor
 ####
diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl
index d83abd173..319e424fc 100644
--- a/src/mcmc/is.jl
+++ b/src/mcmc/is.jl
@@ -31,19 +31,19 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler
 function DynamicPPL.initialstep(
     rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs...
 )
-    return Transition(model, vi), nothing
+    return Transition(model, vi, nothing), nothing
 end
 
 function AbstractMCMC.step(
     rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs...
 )
     vi = VarInfo(rng, model, spl)
-    return Transition(model, vi), nothing
+    return Transition(model, vi, nothing), nothing
 end
 
 # Calculate evidence.
 function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state)
-    return logsumexp(map(x -> x.lp, samples)) - log(length(samples))
+    return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples))
 end
 
 function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
@@ -53,9 +53,6 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
         r = rand(rng, dist)
         vi = push!!(vi, vn, r, dist)
     end
-    return r, 0, vi
-end
-
-function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
-    return logpdf(dist, value), vi
+    vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist)
+    return r, vi
 end
diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl
index fb50c5f58..863db559c 100644
--- a/src/mcmc/mh.jl
+++ b/src/mcmc/mh.jl
@@ -153,10 +153,33 @@ function MH(model::Model; proposal_type=AMH.StaticProposal)
     return AMH.MetropolisHastings(priors)
 end
 
+"""
+    MHState(varinfo::AbstractVarInfo, logjoint_internal::Real)
+
+State for Metropolis-Hastings sampling.
+
+`varinfo` must have the correct parameters set inside it, but its other fields
+(e.g. accumulators, which track logp) can in general be missing or incorrect.
+
+`logjoint_internal` is the log joint probability of the model, evaluated using
+the parameters and linking status of `varinfo`. It should be equal to
+`DynamicPPL.getlogjoint_internal(varinfo)`. This information is returned by the
+MH sampler so we store this here to avoid re-evaluating the model
+unnecessarily.
+"""
+struct MHState{V<:AbstractVarInfo,L<:Real}
+    varinfo::V
+    logjoint_internal::L
+end
+
+get_varinfo(s::MHState) = s.varinfo
+
 #####################
 # Utility functions #
 #####################
 
+# TODO(DPPL0.38/penelopeysm): This function should no longer be needed
+# once InitContext is merged.
 """
     set_namedtuple!(vi::VarInfo, nt::NamedTuple)
 
@@ -181,21 +204,19 @@ function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTup
     end
 end
 
-"""
-    MHLogDensityFunction
-
-A log density function for the MH sampler.
-
-This variant uses the `set_namedtuple!` function to update the `VarInfo`.
-""" -const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} - -function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple) +# NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems +# interface in that it gets evaluated with a NamedTuple. Hence we need this +# method just to deal with MH. +# TODO(DPPL0.38/penelopeysm): Check the extent to which this method is actually +# needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, +# ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). +# In general, we should much prefer to either (1) conform to the +# LogDensityProblems interface or (2) use VarNames anyway. +function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) set_namedtuple!(vi, x) - vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context)) - lj = getlogp(vi_new) + vi_new = last(DynamicPPL.evaluate!!(f.model, vi)) + lj = f.getlogdensity(vi_new) return lj end @@ -297,64 +318,71 @@ end # Make a proposal if we don't have a covariance proposal matrix (the default). function propose!!( - rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal + rng::AbstractRNG, prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal ) + vi = prev_state.varinfo # Retrieve distribution and value NamedTuples. dt, vt = dist_val_tuple(spl, vi) # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(dt) - prev_trans = AMH.Transition(vt, getlogp(vi), false) + prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false) # Make a new transition. + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, spl, model.context) + ) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), - ), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) - - # TODO: Make this compatible with immutable `VarInfo`. - # Update the values in the VarInfo. + # trans.params isa NamedTuple set_namedtuple!(vi, trans.params) - return setlogp!!(vi, trans.lp) + # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know + # how to set this back inside vi (without re-evaluating). However, the next + # MH step will require this information to calculate the acceptance + # probability, so we return it together with vi. + return MHState(vi, trans.lp) end # Make a proposal if we DO have a covariance proposal matrix. function propose!!( rng::AbstractRNG, - vi::AbstractVarInfo, + prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal::AdvancedMH.RandomWalkProposal, ) + vi = prev_state.varinfo # If this is the case, we can just draw directly from the proposal # matrix. vals = vi[:] # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) - prev_trans = AMH.Transition(vals, getlogp(vi), false) + prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false) # Make a new transition. 
+    spl_model = DynamicPPL.contextualize(
+        model, DynamicPPL.SamplingContext(rng, spl, model.context)
+    )
     densitymodel = AMH.DensityModel(
         Base.Fix1(
             LogDensityProblems.logdensity,
-            DynamicPPL.LogDensityFunction(
-                model,
-                vi,
-                DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
-            ),
+            DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
         ),
     )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
-
-    # TODO: Make this compatible with immutable `VarInfo`.
-    # Update the values in the VarInfo.
+    # trans.params isa NamedTuple
     set_namedtuple!(vi, trans.params)
-    return setlogp!!(vi, trans.lp)
+    # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
+    # how to set this back inside vi (without re-evaluating). However, the next
+    # MH step will require this information to calculate the acceptance
+    # probability, so we return it together with vi.
+    return MHState(vi, trans.lp)
 end
 
 # Make a proposal if we DO have a covariance proposal matrix.
 function propose!!(
     rng::AbstractRNG,
-    vi::AbstractVarInfo,
+    prev_state::MHState,
     model::Model,
     spl::Sampler{<:MH},
     proposal::AdvancedMH.RandomWalkProposal,
 )
+    vi = prev_state.varinfo
     # If this is the case, we can just draw directly from the proposal
     # matrix.
     vals = vi[:]
 
     # Create a sampler and the previous transition.
     mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
-    prev_trans = AMH.Transition(vals, getlogp(vi), false)
+    prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false)
 
     # Make a new transition.
+    spl_model = DynamicPPL.contextualize(
+        model, DynamicPPL.SamplingContext(rng, spl, model.context)
+    )
     densitymodel = AMH.DensityModel(
         Base.Fix1(
             LogDensityProblems.logdensity,
-            DynamicPPL.LogDensityFunction(
-                model,
-                vi,
-                DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
-            ),
+            DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
         ),
     )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
-
-    return setlogp!!(DynamicPPL.unflatten(vi, trans.params), trans.lp)
+    # trans.params isa AbstractVector
+    vi = DynamicPPL.unflatten(vi, trans.params)
+    # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
+    # how to set this back inside vi (without re-evaluating). However, the next
+    # MH step will require this information to calculate the acceptance
+    # probability, so we return it together with vi.
+    return MHState(vi, trans.lp)
 end
 
 function DynamicPPL.initialstep(
@@ -368,18 +396,18 @@ function DynamicPPL.initialstep(
     # just link everything before sampling.
     vi = maybe_link!!(vi, spl, spl.alg.proposals, model)
 
-    return Transition(model, vi), vi
+    return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi))
 end
 
 function AbstractMCMC.step(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs...
+    rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, state::MHState; kwargs...
 )
     # Cases:
     # 1. A covariance proposal matrix
     # 2. A bunch of NamedTuples that specify the proposal space
-    vi = propose!!(rng, vi, model, spl, spl.alg.proposals)
+    new_state = propose!!(rng, state, model, spl, spl.alg.proposals)
 
-    return Transition(model, vi), vi
+    return Transition(model, new_state.varinfo, nothing), new_state
 end
 
 ####
@@ -392,7 +420,3 @@ function DynamicPPL.assume(
     retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
     return retval
 end
-
-function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
-    return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
-end
diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl
index a81f436c8..6959e22cc 100644
--- a/src/mcmc/particle_mcmc.jl
+++ b/src/mcmc/particle_mcmc.jl
@@ -4,6 +4,38 @@
 ### AdvancedPS models and interface
 
+"""
+    set_all_del!(vi::AbstractVarInfo)
+
+Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
+resampling.
+"""
+function set_all_del!(vi::AbstractVarInfo)
+    # TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we
+    # could either:
+    # - keep a boolean 'resample' flag on the trace, or
+    # - modify the model context appropriately.
+    # However, this refactoring will have to wait until InitContext is
+    # merged into DPPL.
+    for vn in keys(vi)
+        DynamicPPL.set_flag!(vi, vn, "del")
+    end
+    return nothing
+end
+
+"""
+    unset_all_del!(vi::AbstractVarInfo)
+
+Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing
+them from being resampled.
+""" +function unset_all_del!(vi::AbstractVarInfo) + for vn in keys(vi) + DynamicPPL.unset_flag!(vi, vn, "del") + end + return nothing +end + struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel model::M @@ -18,43 +50,44 @@ function TracedModel( varinfo::AbstractVarInfo, rng::Random.AbstractRNG, ) - context = SamplingContext(rng, sampler, DefaultContext()) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) + spl_model = DynamicPPL.contextualize(model, spl_context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) if kwargs !== nothing && !isempty(kwargs) error( "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", ) end - evaluator = (model.f, args...) - return TracedModel(model, sampler, varinfo, evaluator) + evaluator = (spl_model.f, args...) + return TracedModel(spl_model, sampler, varinfo, evaluator) end function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) # Make sure we load/reset the rng in the new replaying mechanism - DynamicPPL.increment_num_produce!(trace.model.f.varinfo) isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) - if score === nothing - return nothing - else - return score + DynamicPPL.getlogp(trace.model.f.varinfo) - end + return score end function AdvancedPS.delete_retained!(trace::TracedModel) - DynamicPPL.set_retained_vns_del!(trace.varinfo) + # This method is called if, during a CSMC update, we perform a resampling + # and choose the reference particle as the trajectory to carry on from. + # In such a case, we need to ensure that when we continue sampling (i.e. + # the next time we hit tilde_assume), we don't use the values in the + # reference particle but rather sample new values. + # + # Here, we indiscriminately set the 'del' flag for all variables in the + # VarInfo. This is slightly overkill: it is not necessary to set the 'del' + # flag for variables that were already sampled. However, it allows us to + # avoid keeping track of which variables were sampled, which leads to many + # simplifications in the VarInfo data structure. + set_all_del!(trace.varinfo) return trace end function AdvancedPS.reset_model(trace::TracedModel) - DynamicPPL.reset_num_produce!(trace.varinfo) - return trace -end - -function AdvancedPS.reset_logprob!(trace::TracedModel) - DynamicPPL.resetlogp!!(trace.model.varinfo) return trace end @@ -113,17 +146,11 @@ end function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. - lp = getlogp(vi) - + lp = DynamicPPL.getlogjoint_internal(vi) return SMCTransition(theta, lp, weight) end -metadata(t::SMCTransition) = (lp=t.lp, weight=t.weight) - -DynamicPPL.getlogp(t::SMCTransition) = t.lp +getstats_with_lp(t::SMCTransition) = (lp=t.lp, weight=t.weight) struct SMCState{P,F<:AbstractFloat} particles::P @@ -182,10 +209,10 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. 
-    DynamicPPL.reset_num_produce!(vi)
-    DynamicPPL.set_retained_vns_del!(vi)
-    DynamicPPL.resetlogp!!(vi)
-    DynamicPPL.empty!!(vi)
+    vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
+    set_all_del!(vi)
+    vi = DynamicPPL.resetlogp!!(vi)
+    vi = DynamicPPL.empty!!(vi)
 
     # Create a new set of particles.
     particles = AdvancedPS.ParticleContainer(
@@ -288,21 +315,15 @@ struct PGState
     rng::Random.AbstractRNG
 end
 
-varinfo(state::PGState) = state.vi
+get_varinfo(state::PGState) = state.vi
 
 function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
     theta = getparams(model, vi)
-
-    # This is pretty useless since we reset the log probability continuously in the
-    # particle sweep.
-    lp = getlogp(vi)
-
+    lp = DynamicPPL.getlogjoint_internal(vi)
     return PGTransition(theta, lp, logevidence)
 end
 
-metadata(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence)
-
-DynamicPPL.getlogp(t::PGTransition) = t.lp
+getstats_with_lp(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence)
 
 function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState)
     return mean(x.logevidence for x in samples)
@@ -315,10 +336,10 @@ function DynamicPPL.initialstep(
     vi::AbstractVarInfo;
     kwargs...,
 )
+    vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
     # Reset the VarInfo before new sweep
-    DynamicPPL.reset_num_produce!(vi)
-    DynamicPPL.set_retained_vns_del!(vi)
-    DynamicPPL.resetlogp!!(vi)
+    set_all_del!(vi)
+    vi = DynamicPPL.resetlogp!!(vi)
 
     # Create a new set of particles
     num_particles = spl.alg.nparticles
@@ -348,14 +369,15 @@ function AbstractMCMC.step(
 )
     # Reset the VarInfo before new sweep.
     vi = state.vi
-    DynamicPPL.reset_num_produce!(vi)
-    DynamicPPL.resetlogp!!(vi)
+    vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
+    vi = DynamicPPL.resetlogp!!(vi)
 
     # Create reference particle for which the samples will be retained.
+    unset_all_del!(vi)
     reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
 
     # For all other particles, do not retain the variables but resample them.
-    DynamicPPL.set_retained_vns_del!(vi)
+    set_all_del!(vi)
 
     # Create a new set of particles.
     num_particles = spl.alg.nparticles
@@ -384,77 +406,118 @@ function AbstractMCMC.step(
 end
 
 function DynamicPPL.use_threadsafe_eval(
-    ::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo
+    ::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo
 )
     return false
 end
 
-function trace_local_varinfo_maybe(varinfo)
-    try
-        trace = Libtask.get_taped_globals(Any).other
-        return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
+"""
+    get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+
+Get the `Trace` local varinfo if one exists.
+
+If executed within a `TapedTask`, return the `varinfo` stored in the "taped globals" of the
+task, otherwise return `vi`.
+"""
+function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo)
+    trace = try
+        Libtask.get_taped_globals(Any).other
     catch e
-        # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
-        if e == KeyError(:task_variable)
-            return varinfo
-        else
-            rethrow(e)
-        end
+        e == KeyError(:task_variable) ? nothing : rethrow(e)
     end
+    return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
 end
 
-function trace_local_rng_maybe(rng::Random.AbstractRNG)
-    try
-        return Libtask.get_taped_globals(Any).rng
+"""
+    get_trace_local_rng_maybe(rng::Random.AbstractRNG)
+
+Get the `Trace` local rng if one exists.
+
+If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the
+task, otherwise return `rng`.
+"""
+function get_trace_local_rng_maybe(rng::Random.AbstractRNG)
+    return try
+        Libtask.get_taped_globals(Any).rng
     catch e
-        # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
-        if e == KeyError(:task_variable)
-            return rng
-        else
-            rethrow(e)
-        end
+        e == KeyError(:task_variable) ? rng : rethrow(e)
     end
 end
 
+"""
+    set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+
+Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
+
+If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the
+task. Otherwise do nothing.
+"""
+function set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+    # TODO(mhauru) This should be done in a try-catch block, as in the commented out code.
+    # However, Libtask currently can't handle this block.
+    trace = #try
        Libtask.get_taped_globals(Any).other
+    # catch e
+    #     e == KeyError(:task_variable) ? nothing : rethrow(e)
+    # end
+    if trace !== nothing
+        model = trace.model
+        model = Accessors.@set model.f.varinfo = vi
+        trace.model = model
+    end
+    return nothing
+end
+
 function DynamicPPL.assume(
-    rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo
+    rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo
 )
-    vi = trace_local_varinfo_maybe(_vi)
-    trng = trace_local_rng_maybe(rng)
+    arg_vi_id = objectid(vi)
+    vi = get_trace_local_varinfo_maybe(vi)
+    using_local_vi = objectid(vi) == arg_vi_id
+
+    trng = get_trace_local_rng_maybe(rng)
 
     if ~haskey(vi, vn)
        r = rand(trng, dist)
-        push!!(vi, vn, r, dist)
+        vi = push!!(vi, vn, r, dist)
    elseif DynamicPPL.is_flagged(vi, vn, "del")
        DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
+        # TODO(mhauru):
+        # The below is the only line that differs from assume called on SampleFromPrior.
+        # Could we just call assume on SampleFromPrior with a specific rng?
        r = rand(trng, dist)
        vi[vn] = DynamicPPL.tovec(r)
-        DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi))
    else
        r = vi[vn]
    end
 
-    # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something?
-    lp = 0
-    return r, lp, vi
-end
-
-function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
-    # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
-    return logpdf(dist, value), trace_local_varinfo_maybe(vi)
-end
+    vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist)
 
-function DynamicPPL.acclogp!!(
-    context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
-)
-    varinfo_trace = trace_local_varinfo_maybe(varinfo)
-    return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)
+    # TODO(mhauru) Rather than this if-block, we should use try-catch within
+    # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
+    # hence this.
+    if !using_local_vi
+        set_trace_local_varinfo_maybe(vi)
+    end
+    return r, vi
 end
 
-function DynamicPPL.acclogp_observe!!(
-    context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
+function DynamicPPL.tilde_observe!!(
+    ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi
 )
-    Libtask.produce(logp)
-    return trace_local_varinfo_maybe(varinfo)
+    arg_vi_id = objectid(vi)
+    vi = get_trace_local_varinfo_maybe(vi)
+    using_local_vi = objectid(vi) == arg_vi_id
+
+    left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi)
+
+    # TODO(mhauru) Rather than this if-block, we should use try-catch within
+    # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
+    # hence this.
+    if !using_local_vi
+        set_trace_local_varinfo_maybe(vi)
+    end
+    return left, vi
 end
 
 # Convenient constructor
@@ -465,20 +528,74 @@ function AdvancedPS.Trace(
     rng::AdvancedPS.TracedRNG,
 )
     newvarinfo = deepcopy(varinfo)
-    DynamicPPL.reset_num_produce!(newvarinfo)
-
     tmodel = TracedModel(model, sampler, newvarinfo, rng)
     newtrace = AdvancedPS.Trace(tmodel, rng)
     return newtrace
 end
 
+"""
+    ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
+
+Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T}
+    "the scalar log likelihood value"
+    logp::T
+end
+
+# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
+# can be used in a given VarInfo.
+DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
+DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp
+
+function DynamicPPL.acclogp(acc1::ProduceLogLikelihoodAccumulator, val)
+    # The below line is the only difference from `LogLikelihoodAccumulator`.
+    Libtask.produce(val)
+    return ProduceLogLikelihoodAccumulator(acc1.logp + val)
+end
+
+function DynamicPPL.accumulate_assume!!(
+    acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right
+)
+    return acc
+end
+function DynamicPPL.accumulate_observe!!(
+    acc::ProduceLogLikelihoodAccumulator, right, left, vn
+)
+    return DynamicPPL.acclogp(acc, Distributions.loglikelihood(right, left))
+end
+
 # We need to tell Libtask which calls may have `produce` calls within them. In practice most
 # of these won't be needed, because of inlining and the fact that `might_produce` is only
 # called on `:invoke` expressions rather than `:call`s, but since those are implementation
 # details of the compiler, we set a bunch of methods as might_produce = true. We start with
-# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
-Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
+# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
+# call stack.
+Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
+function Libtask.might_produce(
+    ::Type{
+        <:Tuple{
+            typeof(Base.:+),
+            ProduceLogLikelihoodAccumulator,
+            DynamicPPL.LogLikelihoodAccumulator,
+        },
+    },
+)
+    return true
+end
+function Libtask.might_produce(
+    ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
+)
    return true
end
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
+# Could the next two have tighter type bounds on the arguments, namely a GibbsContext?
+# That's the only thing that makes tilde_assume calls result in tilde_observe calls. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index c7a5cc573..db4e0466d 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -8,18 +8,53 @@ struct Prior <: InferenceAlgorithm end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:Prior}, - state=nothing; + sampler::DynamicPPL.Sampler{<:Prior}; kwargs..., ) - vi = last( - DynamicPPL.evaluate!!( - model, - VarInfo(), - SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()), + # TODO(DPPL0.38/penelopeysm): replace with init!! + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) + ) + vi = VarInfo() + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), ), ) - return vi, nothing + _, vi = DynamicPPL.evaluate!!(sampling_model, vi) + vi = DynamicPPL.typed_varinfo(vi) + return Transition(model, vi, nothing; reevaluate=false), vi end -DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:Prior}, + vi::AbstractVarInfo; + kwargs..., +) + # TODO(DPPL0.38/penelopeysm): replace this entire thing with init!! + # + # `vi` is a VarInfo from the previous step so already has all the + # right accumulators and stuff. The only thing we need to change is to make + # sure that the old values are overwritten when we resample. + # + # Note also that the values in the Transition (and hence the chain) are not + # obtained from the VarInfo's metadata itself, but are instead obtained + # from the ValuesAsInModelAccumulator, which is cleared in the evaluate!! + # call. Thus, the actual values in the VarInfo's metadata don't matter: + # we only set the del flag here to make sure that new values are sampled + # (and thus new values enter VAIMAcc), rather than the old ones being + # reused during the evaluation. Yes, SampleFromPrior really sucks. + for vn in keys(vi) + DynamicPPL.set_flag!(vi, vn, "del") + end + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) + ) + _, vi = DynamicPPL.evaluate!!(sampling_model, vi) + return Transition(model, vi, nothing; reevaluate=false), vi +end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 0c322244e..5ca351643 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -58,19 +58,15 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - # Transform the samples to unconstrained space and compute the joint log probability. + # Transform the samples to unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) end # Compute initial sample and state. 
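As a side note on the particle-MCMC plumbing above, the following minimal sketch shows the `"del"`-flag and accumulator setup that the PG/SMC `initialstep` performs before a sweep. The toy model `gdemo` is an assumption for illustration, not part of this changeset:

```julia
using Turing
using DynamicPPL, Distributions

# `gdemo` is a hypothetical stand-in model.
@model function gdemo()
    s ~ InverseGamma(2, 3)
    return m ~ Normal(0, sqrt(s))
end

vi = DynamicPPL.VarInfo(gdemo())
# Attach the produce-on-observe accumulator, as the PG/SMC initial step does.
vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator())
# Setting the "del" flag on every variable makes the next tilde_assume draw a
# fresh value instead of reusing the stored one (cf. `set_all_del!`).
for vn in keys(vi)
    DynamicPPL.set_flag!(vi, vn, "del")
end
@assert all(DynamicPPL.is_flagged(vi, vn, "del") for vn in keys(vi))
# Clearing the flags (cf. `unset_all_del!`) makes the stored values be reused,
# which is what the CSMC reference particle needs.
for vn in keys(vi)
    DynamicPPL.unset_flag!(vi, vn, "del")
end
```

Setting the flag indiscriminately trades a little redundant resampling for a much simpler VarInfo, as the comment in `delete_retained!` above notes.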
- sample = Transition(model, vi) + sample = Transition(model, vi, nothing) ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGHMCState(ℓ, vi, zero(vi[:])) @@ -98,12 +94,11 @@ function AbstractMCMC.step( α = spl.alg.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) - # Save new variables and recompute log density. + # Save new variables. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) # Compute next sample and state. - sample = Transition(model, vi) + sample = Transition(model, vi, nothing) newstate = SGHMCState(ℓ, vi, newv) return sample, newstate @@ -200,13 +195,11 @@ end function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize) theta = getparams(model, vi) - lp = getlogp(vi) + lp = DynamicPPL.getlogjoint_internal(vi) return SGLDTransition(theta, lp, stepsize) end -metadata(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize) - -DynamicPPL.getlogp(t::SGLDTransition) = t.lp +getstats_with_lp(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize) struct SGLDState{L,V<:AbstractVarInfo} logdensity::L @@ -221,19 +214,15 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - # Transform the samples to unconstrained space and compute the joint log probability. + # Transform the samples to unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) end # Create first sample and state. sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0))) ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGLDState(ℓ, vi, 1) @@ -252,9 +241,8 @@ function AbstractMCMC.step( stepsize = spl.alg.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) - # Save new variables and recompute log density. + # Save new variables. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) # Compute next sample and state. sample = SGLDTransition(model, vi, stepsize) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index fedc2510d..19c52c381 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -43,75 +43,24 @@ Concrete type for maximum a posteriori estimation. Only used for the Optim.jl in """ struct MAP <: ModeEstimator end -""" - OptimizationContext{C<:AbstractContext} <: AbstractContext - -The `OptimizationContext` transforms variables to their constrained space, but -does not use the density with respect to the transformation. This context is -intended to allow an optimizer to sample in R^n freely. 
-""" -struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - context::C - - function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext} - if !( - context isa Union{ - DynamicPPL.DefaultContext, - DynamicPPL.LikelihoodContext, - DynamicPPL.PriorContext, - } - ) - msg = """ - `OptimizationContext` supports only leaf contexts of type - `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, - and `DynamicPPL.PriorContext` (given: `$(typeof(context)))` - """ - throw(ArgumentError(msg)) - end - return new{C}(context) - end -end - -OptimizationContext(ctx::DynamicPPL.AbstractContext) = OptimizationContext{typeof(ctx)}(ctx) - -DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() - -function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) - r = vi[vn, dist] - lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} - # MAP - Distributions.logpdf(dist, r) - else - # MLE - 0 - end - return r, lp, vi -end - -function DynamicPPL.tilde_observe( - ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args... -) - return DynamicPPL.tilde_observe(ctx.context, args...) -end - """ OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, - AD<:ADTypes.AbstractADType + F<:Function, + V<:DynamicPPL.AbstractVarInfo, + AD<:ADTypes.AbstractADType, } A struct that wraps a single LogDensityFunction. Can be invoked either using ```julia -OptimLogDensity(model, varinfo, ctx; adtype=adtype) +OptimLogDensity(model, varinfo; adtype=adtype) ``` or ```julia -OptimLogDensity(model, ctx; adtype=adtype) +OptimLogDensity(model; adtype=adtype) ``` If not specified, `adtype` defaults to `AutoForwardDiff()`. @@ -129,37 +78,35 @@ the underlying LogDensityFunction at the point `z`. This is done to satisfy the Optim.jl interface. 
 ```julia
-optim_ld = OptimLogDensity(model, varinfo, ctx)
+optim_ld = OptimLogDensity(model, getlogdensity, varinfo)
 optim_ld(z) # returns -logp
 ```
 """
-struct OptimLogDensity{
-    M<:DynamicPPL.Model,
-    V<:DynamicPPL.VarInfo,
-    C<:OptimizationContext,
-    AD<:ADTypes.AbstractADType,
-}
-    ldf::DynamicPPL.LogDensityFunction{M,V,C,AD}
-end
+struct OptimLogDensity{L<:DynamicPPL.LogDensityFunction}
+    ldf::L
 
-function OptimLogDensity(
-    model::DynamicPPL.Model,
-    vi::DynamicPPL.VarInfo,
-    ctx::OptimizationContext;
-    adtype::ADTypes.AbstractADType=AutoForwardDiff(),
-)
-    return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi, ctx; adtype=adtype))
-end
-
-# No varinfo
-function OptimLogDensity(
-    model::DynamicPPL.Model,
-    ctx::OptimizationContext;
-    adtype::ADTypes.AbstractADType=AutoForwardDiff(),
-)
-    return OptimLogDensity(
-        DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype)
+    function OptimLogDensity(
+        model::DynamicPPL.Model,
+        getlogdensity::Function,
+        vi::DynamicPPL.AbstractVarInfo;
+        adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
+    )
+        ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype)
+        return new{typeof(ldf)}(ldf)
+    end
+    function OptimLogDensity(
+        model::DynamicPPL.Model,
+        getlogdensity::Function;
+        adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
     )
+        # No varinfo
+        return OptimLogDensity(
+            model,
+            getlogdensity,
+            DynamicPPL.ldf_default_varinfo(model, getlogdensity);
+            adtype=adtype,
+        )
+    end
 end
 
 """
@@ -325,10 +272,13 @@ function StatsBase.informationmatrix(
     # Convert the values to their unconstrained states to make sure the
     # Hessian is computed with respect to the untransformed parameters.
-    linked = DynamicPPL.istrans(m.f.ldf.varinfo)
+    old_ldf = m.f.ldf
+    linked = DynamicPPL.istrans(old_ldf.varinfo)
     if linked
-        new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model)
-        new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context)
+        new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model)
+        new_f = OptimLogDensity(
+            old_ldf.model, old_ldf.getlogdensity, new_vi; adtype=old_ldf.adtype
+        )
         m = Accessors.@set m.f = new_f
     end
 
@@ -339,8 +289,11 @@ function StatsBase.informationmatrix(
     # Link it back if we invlinked it.
     if linked
-        new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model)
-        new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context)
+        invlinked_ldf = m.f.ldf
+        new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model)
+        new_f = OptimLogDensity(
+            invlinked_ldf.model, old_ldf.getlogdensity, new_vi; adtype=invlinked_ldf.adtype
+        )
         m = Accessors.@set m.f = new_f
     end
 
@@ -553,7 +506,12 @@ function estimate_mode(
     ub=nothing,
     kwargs...,
 )
-    check_model && DynamicPPL.check_model(model; error_on_failure=true)
+    if check_model
+        spl_model = DynamicPPL.contextualize(
+            model, DynamicPPL.SamplingContext(model.context)
+        )
+        DynamicPPL.check_model(spl_model, DynamicPPL.VarInfo(); error_on_failure=true)
+    end
 
     constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons)
     initial_params = generate_initial_params(model, initial_params, constraints)
@@ -563,19 +521,17 @@ function estimate_mode(
 
     # Create an OptimLogDensity object that can be used to evaluate the objective function,
     # i.e. the negative log density.
- inner_context = if estimator isa MAP - DynamicPPL.DefaultContext() - else - DynamicPPL.LikelihoodContext() - end - ctx = OptimizationContext(inner_context) + # Note that we use `getlogjoint` rather than `getlogjoint_internal`: this + # is intentional, because even though the VarInfo may be linked, the + # optimisation target should not take the Jacobian term into account. + getlogdensity = estimator isa MAP ? DynamicPPL.getlogjoint : DynamicPPL.getloglikelihood # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the # varinfo are completely ignored. The parameters only matter if you are calling evaluate!! # directly on the fields of the LogDensityFunction - vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.ldf_default_varinfo(model, getlogdensity) vi = DynamicPPL.unflatten(vi, initial_params) # Link the varinfo if needed. @@ -588,7 +544,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi, ctx) + log_density = OptimLogDensity(model, getlogdensity, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index b9428af11..d51631968 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -17,12 +17,6 @@ export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian include("deprecated.jl") -function make_logdensity(model::DynamicPPL.Model) - weight = 1.0 - ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) -end - """ q_initialize_scale( [rng::Random.AbstractRNG,] @@ -68,7 +62,7 @@ function q_initialize_scale( num_max_trials::Int=10, reduce_factor::Real=one(eltype(scale)) / 2, ) - prob = make_logdensity(model) + prob = LogDensityFunction(model) ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) varinfo = DynamicPPL.VarInfo(model) @@ -309,7 +303,7 @@ function vi( ) return AdvancedVI.optimize( rng, - make_logdensity(model), + LogDensityFunction(model), objective, q, n_iterations; diff --git a/test/Project.toml b/test/Project.toml index 46817c6c5..149c7336b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.36.12" +DynamicPPL = "0.37" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" @@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"} diff --git a/test/ad.jl b/test/ad.jl index 2f645fab5..dcfe4ef46 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -155,35 +155,33 @@ end # child context. 
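To illustrate the optimisation interface above: the estimator now only chooses which log-density component to negate, with `getloglikelihood` giving MLE and `getlogjoint` giving MAP. A hedged sketch, where `demo_opt` is a made-up model (the `OptimLogDensity` calls mirror the tests below):

```julia
using Turing
using DynamicPPL, Distributions

# `demo_opt` is an assumed toy model, not part of this changeset.
@model function demo_opt(x)
    μ ~ Uniform(0, 2)
    return x ~ LogNormal(μ, 1)
end
model = demo_opt(1.0)

# Calling an OptimLogDensity at a parameter vector returns the negative log density.
mle_obj = Turing.Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
map_obj = Turing.Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
mle_obj([0.3])  # -loglikelihood at μ = 0.3
map_obj([0.3])  # -(logprior + loglikelihood) at μ = 0.3
```

Both objectives deliberately exclude any link-transformation Jacobian, per the `getlogjoint` (rather than `getlogjoint_internal`) choice explained above.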
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) - return value, logp, vi + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi ) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) check_adtype(context, vi) - return value, logp, vi + return value, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) + left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) check_adtype(context, vi) - return logp, vi + return left, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe( +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) + left, vi = DynamicPPL.tilde_observe!!( DynamicPPL.childcontext(context), sampler, right, left, vi ) check_adtype(context, vi) - return logp, vi + return left, vi end """ @@ -239,32 +237,8 @@ end end end -@testset verbose = true "AD / SamplingContext" begin - # AD tests for gradient-based samplers need to be run with SamplingContext - # because samplers can potentially use this to define custom behaviour in - # the tilde-pipeline and thus change the code executed during model - # evaluation. - @testset "adtype=$adtype" for adtype in ADTYPES - @testset "alg=$alg" for alg in [ - HMC(0.1, 10; adtype=adtype), - HMCDA(0.8, 0.75; adtype=adtype), - NUTS(1000, 0.8; adtype=adtype), - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), - ] - @info "Testing AD for $alg" - - @testset "model=$(model.f)" for model in DEMO_MODELS - rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any - end - end - end -end - @testset verbose = true "AD / GibbsContext" begin - # Gibbs sampling also needs extra AD testing because the models are + # Gibbs sampling needs some extra AD testing because the models are # executed with GibbsContext and a subsetted varinfo. (see e.g. 
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in # src/mcmc/gibbs.jl -- the code here mimics what happens in those @@ -283,8 +257,7 @@ end model, varnames, deepcopy(global_vi) ) rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + @test run_ad(model, adtype; test=true, benchmark=false) isa Any end end end diff --git a/test/essential/container.jl b/test/essential/container.jl index cbd7a6fe2..124637aab 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -19,6 +19,7 @@ using Turing @testset "constructor" begin vi = DynamicPPL.VarInfo() + vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = test() trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) @@ -27,14 +28,11 @@ using Turing @test trace.model.ctask.taped_globals.other === trace res = AdvancedPS.advance!(trace, false) - @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1 @test res ≈ -log(2) # Catch broken copy, espetially for RNG / VarInfo newtrace = AdvancedPS.fork(trace) res2 = AdvancedPS.advance!(trace) - @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 2 - @test DynamicPPL.get_num_produce(newtrace.model.f.varinfo) == 1 end @testset "fork" begin @@ -46,6 +44,7 @@ using Turing return a, b end vi = DynamicPPL.VarInfo() + vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = normal() diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index c3ac571cb..2cc7e4bc0 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -5,7 +5,7 @@ using ..NumericalTests: check_gdemo, check_numerical using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample import DynamicPPL -using DynamicPPL: Sampler, getlogp +using DynamicPPL: Sampler import ForwardDiff using LinearAlgebra: I import MCMCChains @@ -113,44 +113,12 @@ using Turing check_gdemo(chn3_contd) end - @testset "Contexts" begin - # Test LikelihoodContext - @model function testmodel1(x) - a ~ Beta() - lp1 = getlogp(__varinfo__) - x[1] ~ Bernoulli(a) - return global loglike = getlogp(__varinfo__) - lp1 - end - model = testmodel1([1.0]) - varinfo = DynamicPPL.VarInfo(model) - model(varinfo, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - @test getlogp(varinfo) == loglike - - # Test MiniBatchContext - @model function testmodel2(x) - a ~ Beta() - return x[1] ~ Bernoulli(a) - end - model = testmodel2([1.0]) - varinfo1 = DynamicPPL.VarInfo(model) - varinfo2 = deepcopy(varinfo1) - model(varinfo1, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - model( - varinfo2, - DynamicPPL.SampleFromPrior(), - DynamicPPL.MiniBatchContext(DynamicPPL.LikelihoodContext(), 10), - ) - @test isapprox(getlogp(varinfo2) / getlogp(varinfo1), 10) - end - @testset "Prior" begin N = 10_000 - # Note that all chains contain 3 values per sample: 2 variables + log probability @testset "Single-threaded vanilla" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), N) @test chains isa MCMCChains.Chains - @test size(chains) == (N, 3, 1) @test mean(chains, :s) ≈ 3 atol = 0.11 @test mean(chains, :m) ≈ 0 atol = 0.1 end @@ -158,7 +126,6 @@ using Turing @testset "Multi-threaded" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), MCMCThreads(), N, 4) @test chains isa MCMCChains.Chains - @test size(chains) == (N, 3, 4) @test mean(chains, 
:s) ≈ 3 atol = 0.11 @test mean(chains, :m) ≈ 0 atol = 0.1 end @@ -169,25 +136,34 @@ using Turing ) @test chains isa Vector{<:NamedTuple} @test length(chains) == N - @test all(length(x) == 3 for x in chains) @test all(haskey(x, :lp) for x in chains) + @test all(haskey(x, :logprior) for x in chains) + @test all(haskey(x, :loglikelihood) for x in chains) @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end - @testset "#2169" begin - # Not exactly the same as the issue, but similar. - @model function issue2169_model() - if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext - x ~ Normal(0, 1) - else - x ~ Normal(1000, 1) - end + @testset "accumulators are set correctly" begin + # Prior() uses `reevaluate=false` when constructing a + # `Turing.Inference.Transition`, so we had better make sure that it + # does capture colon-eq statements, as we can't rely on the default + # `Transition` constructor to do this for us. + @model function coloneq() + x ~ Normal() + 10.0 ~ Normal(x) + z := 1.0 + return nothing end - - model = issue2169_model() - chain = sample(StableRNG(seed), model, Prior(), 10) - @test all(mean(chain[:x]) .< 5) + chain = sample(coloneq(), Prior(), N) + @test chain isa MCMCChains.Chains + @test all(x -> x == 1.0, chain[:z]) + # And for the same reason we should also make sure that the logp + # components are correctly calculated. + @test isapprox(chain[:logprior], logpdf.(Normal(), chain[:x])) + @test isapprox(chain[:loglikelihood], logpdf.(Normal.(chain[:x]), 10.0)) + @test isapprox(chain[:lp], chain[:logprior] .+ chain[:loglikelihood]) + # And that the outcome is not influenced by the likelihood + @test mean(chain, :x) ≈ 0.0 atol = 0.1 end end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index e918b3a51..1e1be9b45 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -54,7 +54,7 @@ using Turing @testset "gdemo with CSMC + ESS" begin alg = Gibbs(:s => CSMC(15), :m => ESS()) - chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 2000) + chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index bc682c09b..38b9b0660 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -86,12 +86,10 @@ using Turing.Inference: AdvancedHMC @test chn isa MCMCChains.Chains @test all(chn[:a] .== a) @test all(chn[:b] .== b) - # TODO: Uncomment this once Turing v0.40 is released. In that version, logpdf - # will be recalculated correctly for external samplers. 
- # expected_logpdf = logpdf(Beta(2, 2), a) + logpdf(Normal(a), b) - # @test all(chn[:lp] .== expected_logpdf) - # @test all(chn[:logprior] .== expected_logpdf) - # @test all(chn[:loglikelihood] .== 0.0) + expected_logpdf = logpdf(Beta(2, 2), a) + logpdf(Normal(a), b) + @test all(chn[:lp] .== expected_logpdf) + @test all(chn[:logprior] .== expected_logpdf) + @test all(chn[:loglikelihood] .== 0.0) end function initialize_nuts(model::DynamicPPL.Model) @@ -100,7 +98,9 @@ function initialize_nuts(model::DynamicPPL.Model) linked_vi = DynamicPPL.link!!(vi, model) # Create a LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, linked_vi; adtype=Turing.DEFAULT_ADTYPE) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, linked_vi; adtype=Turing.DEFAULT_ADTYPE + ) # Choose parameter dimensionality and initial parameter value D = LogDensityProblems.dimension(f) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc..0fd76be3a 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -207,8 +207,8 @@ end val ~ Normal(s, 1) 1.0 ~ Normal(s + m, 1) - n := m + 1 - xs = M(undef, n) + n := m + xs = M(undef, 5) for i in eachindex(xs) xs[i] ~ Beta(0.5, 0.5) end @@ -295,7 +295,7 @@ end vi::T end - Turing.Inference.varinfo(state::VarInfoState) = state.vi + Turing.Inference.get_varinfo(state::VarInfoState) = state.vi function Turing.Inference.setparams_varinfo!!( ::DynamicPPL.Model, ::DynamicPPL.Sampler, @@ -312,8 +312,8 @@ end kwargs..., ) spl.alg.non_warmup_init_count += 1 - return Turing.Inference.Transition(nothing, 0.0), - VarInfoState(DynamicPPL.VarInfo(model)) + vi = DynamicPPL.VarInfo(model) + return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end function AbstractMCMC.step_warmup( @@ -323,30 +323,30 @@ end kwargs..., ) spl.alg.warmup_init_count += 1 - return Turing.Inference.Transition(nothing, 0.0), - VarInfoState(DynamicPPL.VarInfo(model)) + vi = DynamicPPL.VarInfo(model) + return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end function AbstractMCMC.step( ::Random.AbstractRNG, - ::DynamicPPL.Model, + model::DynamicPPL.Model, spl::DynamicPPL.Sampler{<:WarmupCounter}, s::VarInfoState; kwargs..., ) spl.alg.non_warmup_count += 1 - return Turing.Inference.Transition(nothing, 0.0), s + return Turing.Inference.Transition(model, s.vi, nothing), s end function AbstractMCMC.step_warmup( ::Random.AbstractRNG, - ::DynamicPPL.Model, + model::DynamicPPL.Model, spl::DynamicPPL.Sampler{<:WarmupCounter}, s::VarInfoState; kwargs..., ) spl.alg.warmup_count += 1 - return Turing.Inference.Transition(nothing, 0.0), s + return Turing.Inference.Transition(model, s.vi, nothing), s end @model f() = x ~ Normal() @@ -565,40 +565,98 @@ end end end - # The below test used to sample incorrectly before - # https://github.com/TuringLang/Turing.jl/pull/2328 - @testset "dynamic model with ESS" begin - @model function dynamic_model_for_ess() - b ~ Bernoulli() - x_length = b ? 1 : 2 - x = Vector{Float64}(undef, x_length) - for i in 1:x_length - x[i] ~ Normal(i, 1.0) + @testset "PG with variable number of observations" begin + # When sampling from a model with Particle Gibbs, it is mandatory for + # the number of observations to be the same in all particles, since the + # observations trigger particle resampling. + # + # Up until Turing v0.39, `x ~ dist` statements where `x` was the + # responsibility of a different (non-PG) Gibbs subsampler used to not + # count as an observation. 
Instead, the log-probability `logpdf(dist, x)`
+    # would be manually added to the VarInfo's `logp` field and included in the
+    # weighting for the _following_ observation.
+    #
+    # In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!!
+    # which thus triggers resampling. Thus, for example, the following model
+    # does not work any more:
+    #
+    # @model function f()
+    #     a ~ Poisson(2.0)
+    #     x = Vector{Float64}(undef, a)
+    #     for i in eachindex(x)
+    #         x[i] ~ Normal()
+    #     end
+    # end
+    # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
+    #
+    # because the number of observations in each particle depends on the value
+    # of `a`.
+    #
+    # This testset checks ways of working around such a situation.
+
+    function test_dynamic_bernoulli(chain)
+        means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
+        stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
+        for vn in keys(means)
+            @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
+            @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
         end
     end
 
-        m = dynamic_model_for_ess()
-        chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
-        means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
-        stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
-        for vn in keys(means)
-            @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
-            @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
+    # TODO(DPPL0.37/penelopeysm): decide what to do with these tests
+    @testset "Coalescing multiple observations into one" begin
+        # Instead of observing x[1] and x[2] separately, we lump them into a
+        # single distribution.
+        @model function dynamic_bernoulli()
+            b ~ Bernoulli()
+            if b
+                dists = [Normal(1.0)]
+            else
+                dists = [Normal(1.0), Normal(2.0)]
+            end
+            return x ~ product_distribution(dists)
+        end
+        model = dynamic_bernoulli()
+        # This currently fails because if the global varinfo has `x` with length 2,
+        # and the particle sampler has `b = true`, it attempts to calculate the
+        # log-likelihood of a length-2 vector with respect to a length-1
+        # distribution.
+        @test_throws DimensionMismatch chain = sample(
+            StableRNG(468),
+            model,
+            Gibbs(:b => PG(10), :x => ESS()),
+            2000;
+            discard_initial=100,
+        )
+        # test_dynamic_bernoulli(chain)
     end
-    end
 
-    @testset "dynamic model with dot tilde" begin
-        @model function dynamic_model_with_dot_tilde(
-            num_zs=10, (::Type{M})=Vector{Float64}
-        ) where {M}
-            z = Vector{Int}(undef, num_zs)
-            z .~ Poisson(1.0)
-            num_ms = sum(z)
-            m = M(undef, num_ms)
-            return m .~ Normal(1.0, 1.0)
-        end
-        model = dynamic_model_with_dot_tilde()
-        sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), 100)
+    @testset "Inserting @addlogprob!" begin
+        # On top of observing x[i], we also add in extra 'observations'.
+        @model function dynamic_bernoulli_2()
+            b ~ Bernoulli()
+            x_length = b ? 1 : 2
+            x = Vector{Float64}(undef, x_length)
+            for i in 1:x_length
+                x[i] ~ Normal(i, 1.0)
+            end
+            if length(x) == 1
+                # This value is the expectation value of logpdf(Normal(), x) where x ~ Normal().
+ # See discussion in + # https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817 + @addlogprob!(-1.418849) + end + end + model = dynamic_bernoulli_2() + chain = sample( + StableRNG(468), + model, + Gibbs(:b => PG(10), :x => ESS()), + 2000; + discard_initial=100, + ) + test_dynamic_bernoulli(chain) + end end @testset "Demo model" begin @@ -828,7 +886,9 @@ end function check_logp_correct(sampler) @testset "logp is set correctly" begin @model logp_check() = x ~ Normal() - chn = sample(logp_check(), Gibbs(@varname(x) => sampler), 100) + chn = sample( + logp_check(), Gibbs(@varname(x) => sampler), 100; progress=false + ) @test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp]) end end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 8832e5fe7..839dffbbe 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -171,23 +171,6 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end - @testset "prior" begin - # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance - # which means that it's _very_ difficult to find a good tolerance in the test below:) - prior_dist = truncated(Normal(3, 1); lower=0) - - @model function demo_hmc_prior() - s ~ prior_dist - return m ~ Normal(0, sqrt(s)) - end - alg = NUTS(1000, 0.8) - gdemo_default_prior = DynamicPPL.contextualize( - demo_hmc_prior(), DynamicPPL.PriorContext() - ) - chain = sample(gdemo_default_prior, alg, 5_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(prior_dist), 0]; atol=0.2) - end - @testset "warning for difficult init params" begin attempt = 0 @model function demo_warn_initial_params() diff --git a/test/mcmc/is.jl b/test/mcmc/is.jl index 44fbe9201..2811e9c86 100644 --- a/test/mcmc/is.jl +++ b/test/mcmc/is.jl @@ -47,11 +47,11 @@ using Turing Random.seed!(seed) chain = sample(model, alg, n; check_model=false) - sampled = get(chain, [:a, :b, :lp]) + sampled = get(chain, [:a, :b, :loglikelihood]) @test vec(sampled.a) == ref.as @test vec(sampled.b) == ref.bs - @test vec(sampled.lp) == ref.logps + @test vec(sampled.loglikelihood) == ref.logps @test chain.logevidence == ref.logevidence end diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index add2e7404..70810e164 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -262,24 +262,6 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @test !DynamicPPL.islinked(vi) end - @testset "prior" begin - alg = MH() - gdemo_default_prior = DynamicPPL.contextualize( - gdemo_default, DynamicPPL.PriorContext() - ) - burnin = 10_000 - n = 10_000 - chain = sample( - StableRNG(seed), - gdemo_default_prior, - alg, - n; - discard_initial=burnin, - thinning=10, - ) - check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0]; atol=0.3) - end - @testset "`filldist` proposal (issue #2180)" begin @model demo_filldist_issue2180() = x ~ MvNormal(zeros(3), I) chain = sample( diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 2acb7edc5..6b93e7629 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -24,28 +24,7 @@ using Turing hasstats(result) = result.optim_result.stats !== nothing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 - @testset "OptimizationContext" begin - # Used for testing how well it works with nested contexts. 
- struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext - context::C - logprior_weight::T1 - loglikelihood_weight::T2 - end - DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent() - DynamicPPL.childcontext(parent::OverrideContext) = parent.context - DynamicPPL.setchildcontext(parent::OverrideContext, child) = - OverrideContext(child, parent.logprior_weight, parent.loglikelihood_weight) - - # Only implement what we need for the models above. - function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, context.logprior_weight, vi - end - function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return context.loglikelihood_weight, vi - end - + @testset "OptimLogDensity and contexts" begin @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -62,48 +41,36 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + # Doesn't matter if we use getlogjoint or getlogjoint_internal since the + # VarInfo isn't linked. + ld1 = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint_internal) + @test ld1(w) == ld2(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + ld1 = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint_internal) + @test ld1(w) == ld2(w) end - @testset "Weighted" begin - function override(model) - return DynamicPPL.contextualize( - model, OverrideContext(model.context, 100, 1) - ) - end - m1 = override(model1(x)) - m2 = override(model2() | (x=x,)) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) - end - - @testset "Default, Likelihood, Prior Contexts" begin + @testset "Joint, prior, and likelihood" begin m1 = model1(x) - defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext()) a = [0.3] + ld_joint = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld_prior = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior) + ld_likelihood = Turing.Optimisation.OptimLogDensity( + m1, DynamicPPL.getloglikelihood + ) + @test ld_joint(a) == ld_prior(a) + ld_likelihood(a) - @test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) == - Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) + - Turing.Optimisation.OptimLogDensity(m1, prictx)(a) - - # test that PriorContext is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) ≈ + # test that the prior accumulator is calculating the right thing + @test Turing.Optimisation.OptimLogDensity(m1, 
DynamicPPL.getlogprior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -651,8 +618,7 @@ using Turing return nothing end m = saddle_model() - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - optim_ld = Turing.Optimisation.OptimLogDensity(m, ctx) + optim_ld = Turing.Optimisation.OptimLogDensity(m, DynamicPPL.getloglikelihood) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m)
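The same decomposition is surfaced in sampled chains, where `:lp` should equal `:logprior` plus `:loglikelihood` (cf. the `Prior` and external-sampler tests above). A minimal sketch, with an assumed toy model:

```julia
using Turing, Distributions

# `demo_lp` is a hypothetical stand-in model.
@model function demo_lp()
    x ~ Normal()
    return 2.0 ~ Normal(x)
end

chain = sample(demo_lp(), Prior(), 1_000; progress=false)
# :lp should decompose into the prior and likelihood components.
@assert all(isapprox.(chain[:lp], chain[:logprior] .+ chain[:loglikelihood]))
```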