diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 8c318b419..9e4c8b6ef 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -44,10 +44,10 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S} stepsize::S end -function DynamicPPL.initialstep( +function Turing.Inference.initialstep( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:DynamicNUTS}, + spl::DynamicNUTS, vi::DynamicPPL.AbstractVarInfo; kwargs..., ) @@ -59,7 +59,7 @@ function DynamicPPL.initialstep( # Define log-density function. ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) # Perform initial step. @@ -80,14 +80,14 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:DynamicNUTS}, + spl::DynamicNUTS, state::DynamicNUTSState; kwargs..., ) # Compute next sample. vi = state.vi ℓ = state.logdensity - steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize) + steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize) Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache) # Create next sample and state. diff --git a/src/Turing.jl b/src/Turing.jl index b3412cf55..58a58eb2a 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -160,6 +160,8 @@ export maximum_a_posteriori, maximum_likelihood, MAP, - MLE + MLE, + # Chain save/resume + loadstate end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 883ba15a5..7d25ecd7e 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -22,7 +22,6 @@ using DynamicPPL: getsym, getdist, Model, - Sampler, DefaultContext using Distributions, Libtask, Bijectors using DistributionsAD: VectorOfMultivariate @@ -50,8 +49,7 @@ import Random import MCMCChains import StatsBase: predict -export InferenceAlgorithm, - Hamiltonian, +export Hamiltonian, StaticHamiltonian, AdaptiveHamiltonian, MH, @@ -71,15 +69,16 @@ export InferenceAlgorithm, RepeatSampler, Prior, predict, - externalsampler + externalsampler, + init_strategy, + loadstate -############################################### -# Abstract interface for inference algorithms # -############################################### - -const TURING_CHAIN_TYPE = MCMCChains.Chains +######################################### +# Generic AbstractMCMC methods dispatch # +######################################### -include("algorithm.jl") +const DEFAULT_CHAIN_TYPE = MCMCChains.Chains +include("abstractmcmc.jl") #################### # Sampler wrappers # @@ -312,8 +311,8 @@ getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. 
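# Editor's note: a minimal sketch of the save/resume workflow enabled by the newly exported
# `loadstate`, mirroring the tests added later in this diff. The `coin` model is a
# hypothetical stand-in; `save_state=true`, `loadstate`, and `initial_state` are the pieces
# this PR exposes or relies on.
using Turing

@model function coin(y)
    p ~ Beta(1, 1)
    y ~ Binomial(10, p)
end

chn1 = sample(coin(3), MH(), 100; save_state=true)      # store the final sampler state in chn1.info
state = loadstate(chn1)                                  # retrieve it; errors if save_state was not set
chn2 = sample(coin(3), MH(), 100; initial_state=state)   # resume sampling from that state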
function AbstractMCMC.bundle_samples( ts::Vector{<:Transition}, - model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, + model::DynamicPPL.Model, + spl::AbstractSampler, state, chain_type::Type{MCMCChains.Chains}; save_state=false, @@ -374,8 +373,8 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:Transition}, - model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, + model::DynamicPPL.Model, + spl::AbstractSampler, state, chain_type::Type{Vector{NamedTuple}}; kwargs..., @@ -416,7 +415,7 @@ function group_varnames_by_symbol(vns) return d end -function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples) +function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples) nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples)) return setinfo(c, merge(nt, c.info)) end @@ -435,18 +434,12 @@ include("sghmc.jl") include("emcee.jl") include("prior.jl") -################################################# -# Generic AbstractMCMC methods dispatch # -################################################# - -include("abstractmcmc.jl") - ################ # Typing tools # ################ function DynamicPPL.get_matching_type( - spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV} + spl::Union{PG,SMC}, vi, ::Type{TV} ) where {T,N,TV<:Array{T,N}} return Array{T,N} end diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 00363cde6..ba9553200 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -4,38 +4,98 @@ function _check_model(model::DynamicPPL.Model) new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true) end -function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm) +function _check_model(model::DynamicPPL.Model, ::AbstractSampler) return _check_model(model) end +""" + Turing.Inference.init_strategy(spl::AbstractSampler) + +Get the default initialization strategy for a given sampler `spl`, i.e. how initial +parameters for sampling are chosen if not specified by the user. By default, this is +`InitFromPrior()`, which samples initial parameters from the prior distribution. +""" +init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior() + +""" + _convert_initial_params(initial_params) + +Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or +throw a useful error message. +""" +_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params +function _convert_initial_params(nt::NamedTuple) + @info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead." + return DynamicPPL.InitFromParams(nt) +end +function _convert_initial_params(d::AbstractDict{<:VarName}) + @info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead." + return DynamicPPL.InitFromParams(d) +end +function _convert_initial_params(::AbstractVector{<:Real}) + errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code." 
+ throw(ArgumentError(errmsg)) +end +function _convert_initial_params(@nospecialize(_::Any)) + errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`." + throw(ArgumentError(errmsg)) +end + +""" + default_varinfo(rng, model, sampler) + +Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns a NTVarInfo (i.e. 'typed varinfo'). +""" +function default_varinfo( + rng::Random.AbstractRNG, model::DynamicPPL.Model, ::AbstractSampler +) + # Note that in `AbstractMCMC.step`, the values in the varinfo returned here are + # immediately overwritten by a subsequent call to `init!!`. The reason why we + # _do_ create a varinfo with parameters here (as opposed to simply returning + # an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty + # typed VarInfo would fail. This can happen if two VarNames have different types + # but share the same symbol (e.g. `x.a` and `x.b`). + # TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments + # and return an empty VarInfo instead. + return DynamicPPL.typed_varinfo(VarInfo(rng, model)) +end + ######################################### # Default definitions for the interface # ######################################### -const DEFAULT_CHAIN_TYPE = MCMCChains.Chains - function AbstractMCMC.sample( - model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs... + model::DynamicPPL.Model, spl::AbstractSampler, N::Integer; kwargs... ) - return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...) + return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...) end function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, - alg::InferenceAlgorithm, + model::DynamicPPL.Model, + spl::AbstractSampler, N::Integer; + initial_params=init_strategy(spl), check_model::Bool=true, chain_type=DEFAULT_CHAIN_TYPE, kwargs..., ) - check_model && _check_model(model, alg) - return AbstractMCMC.sample(rng, model, Sampler(alg), N; chain_type, kwargs...) + check_model && _check_model(model, spl) + return AbstractMCMC.mcmcsample( + rng, + model, + spl, + N; + initial_params=_convert_initial_params(initial_params), + chain_type, + kwargs..., + ) end function AbstractMCMC.sample( - model::AbstractModel, - alg::InferenceAlgorithm, + model::DynamicPPL.Model, + alg::AbstractSampler, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; @@ -47,18 +107,66 @@ function AbstractMCMC.sample( end function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - alg::InferenceAlgorithm, + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::AbstractSampler, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; chain_type=DEFAULT_CHAIN_TYPE, check_model::Bool=true, + initial_params=fill(init_strategy(spl), n_chains), kwargs..., ) - check_model && _check_model(model, alg) - return AbstractMCMC.sample( - rng, model, Sampler(alg), ensemble, N, n_chains; chain_type, kwargs... 
+ check_model && _check_model(model, spl) + if !(initial_params isa AbstractVector) || length(initial_params) != n_chains + errmsg = "`initial_params` must be an AbstractVector of length `n_chains`; one element per chain" + throw(ArgumentError(errmsg)) + end + return AbstractMCMC.mcmcsample( + rng, + model, + spl, + ensemble, + N, + n_chains; + chain_type, + initial_params=map(_convert_initial_params, initial_params), + kwargs..., ) end + +function loadstate(chain::MCMCChains.Chains) + if !haskey(chain.info, :samplerstate) + throw( + ArgumentError( + "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`", + ), + ) + end + return chain.info[:samplerstate] +end + +# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures +function initialstep end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::AbstractSampler; + initial_params, + kwargs..., +) + # Generate the default varinfo. Note that any parameters inside this varinfo + # will be immediately overwritten by the next call to `init!!`. + vi = default_varinfo(rng, model, spl) + + # Fill it with initial parameters. Note that, if `InitFromParams` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = DynamicPPL.init!!(rng, model, vi, initial_params) + + # Call the actual function that does the first step. + return initialstep(rng, model, spl, vi; initial_params, kwargs...) +end diff --git a/src/mcmc/algorithm.jl b/src/mcmc/algorithm.jl deleted file mode 100644 index 725b6afbf..000000000 --- a/src/mcmc/algorithm.jl +++ /dev/null @@ -1,16 +0,0 @@ -""" - InferenceAlgorithm - -Abstract type representing an inference algorithm in Turing. Note that this is -not the same as an `AbstractSampler`: the latter is what defines the necessary -methods for actually sampling. - -To create an `AbstractSampler`, the `InferenceAlgorithm` needs to be wrapped in -`DynamicPPL.Sampler`. If `sample()` is called with an `InferenceAlgorithm`, -this wrapping occurs automatically. -""" -abstract type InferenceAlgorithm end - -function DynamicPPL.init_strategy(sampler::Sampler{<:InferenceAlgorithm}) - return DynamicPPL.InitFromPrior() -end diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index de20bc6d3..226536aca 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -13,7 +13,7 @@ Foreman-Mackey, D., Hogg, D. W., Lang, D., & Goodman, J. (2013). emcee: The MCMC Hammer. Publications of the Astronomical Society of the Pacific, 125 (925), 306. https://doi.org/10.1086/670067 """ -struct Emcee{E<:AMH.Ensemble} <: InferenceAlgorithm +struct Emcee{E<:AMH.Ensemble} <: AbstractSampler ensemble::E end @@ -33,23 +33,20 @@ end # Utility function to tetrieve the number of walkers _get_n_walkers(e::Emcee) = e.ensemble.n_walkers -_get_n_walkers(spl::Sampler{<:Emcee}) = _get_n_walkers(spl.alg) # Because Emcee expects n_walkers initialisations, we need to override this -function DynamicPPL.init_strategy(spl::Sampler{<:Emcee}) +function Turing.Inference.init_strategy(spl::Emcee) return fill(DynamicPPL.InitFromPrior(), _get_n_walkers(spl)) end -# TODO(penelopeysm / DPPL 0.38) This is type piracy (!!!) The function -# `_convert_initial_params` will be moved to Turing soon, and this piracy SHOULD be removed -# in https://github.com/TuringLang/Turing.jl/pull/2689, PLEASE make sure it is! 
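# Editor's note: a sketch of the minimal surface a Turing-native sampler now implements,
# given the generic `AbstractMCMC.step`/`initialstep` machinery defined above. It mirrors
# the `StaticSampler` used in the tests later in this diff: a do-nothing sampler whose every
# step just returns the current varinfo.
using AbstractMCMC, DynamicPPL, Random, Turing

struct StaticSampler <: AbstractMCMC.AbstractSampler end

# First step: receives the varinfo that the generic `AbstractMCMC.step` above has already
# created via `default_varinfo` and filled via `DynamicPPL.init!!`.
function Turing.Inference.initialstep(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    ::StaticSampler,
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    return Turing.Inference.Transition(model, vi, nothing), vi
end

# Subsequent steps: the second return value is the state threaded through the chain.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    ::StaticSampler,
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    return Turing.Inference.Transition(model, vi, nothing), vi
end

# With these two methods, `sample(model, StaticSampler(), N)` works directly, with
# `InitFromPrior()` as the default initialisation strategy from `init_strategy`.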
-function DynamicPPL._convert_initial_params( +# We also have to explicitly allow this or else it will error... +function Turing.Inference._convert_initial_params( x::AbstractVector{<:DynamicPPL.AbstractInitStrategy} ) return x end function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler{<:Emcee}; initial_params, kwargs... + rng::Random.AbstractRNG, model::Model, spl::Emcee; initial_params, kwargs... ) # Sample from the prior n = _get_n_walkers(spl) @@ -83,7 +80,7 @@ function AbstractMCMC.step( end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:Emcee}, state::EmceeState; kwargs... + rng::AbstractRNG, model::Model, spl::Emcee, state::EmceeState; kwargs... ) # Generate a log joint function. vi = state.vi @@ -95,7 +92,7 @@ function AbstractMCMC.step( ) # Compute the next states. - t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states) + t, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states) # Compute the next transition and state. transition = map(states) do _state @@ -110,7 +107,7 @@ end function AbstractMCMC.bundle_samples( samples::Vector{<:Vector}, model::AbstractModel, - spl::Sampler{<:Emcee}, + spl::Emcee, state::EmceeState, chain_type::Type{MCMCChains.Chains}; save_state=false, diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index d89d25cf9..18dbfa417 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -20,11 +20,11 @@ Mean │ 1 │ m │ 0.824853 │ ``` """ -struct ESS <: InferenceAlgorithm end +struct ESS <: AbstractSampler end # always accept in the first step -function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... +function Turing.Inference.initialstep( + rng::AbstractRNG, model::DynamicPPL.Model, ::ESS, vi::AbstractVarInfo; kwargs... ) for vn in keys(vi) dist = getdist(vi, vn) @@ -35,7 +35,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::DynamicPPL.Model, ::ESS, vi::AbstractVarInfo; kwargs... ) # obtain previous sample f = vi[:] @@ -103,3 +103,18 @@ struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} end (ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f) + +# Needed for method ambiguity resolution, even though this method is never going to be +# called in practice. This just shuts Aqua up. +# TODO(penelopeysm): Remove this when the default `step(rng, ::DynamicPPL.Model, +# ::AbstractSampler) method in `src/mcmc/abstractmcmc.jl` is removed. +function AbstractMCMC.step( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::EllipticalSliceSampling.ESS; + kwargs..., +) + return error( + "This method is not implemented! If you want to use the ESS sampler in Turing.jl, please use `Turing.ESS()` instead. If you want the default behaviour in EllipticalSliceSampling.jl, wrap your model in a different subtype of `AbstractMCMC.AbstractModel`, and then implement the necessary EllipticalSliceSampling.jl methods on it.", + ) +end diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 0755e4160..58a336be4 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -1,7 +1,8 @@ """ ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} -Represents a sampler that is not an implementation of `InferenceAlgorithm`. 
+Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng, +::DynamicPPL.Model, spl)`. The `Unconstrained` type-parameter is to indicate whether the sampler requires unconstrained space. @@ -10,25 +11,49 @@ $(TYPEDFIELDS) # Turing.jl's interface for external samplers -When implementing a new `MySampler <: AbstractSampler`, -`MySampler` must first and foremost conform to the `AbstractMCMC` interface to work with Turing.jl's `externalsampler` function. -In particular, it must implement: +If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl +models, there are two options: -- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is documented in AbstractMCMC.jl) -- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the parameters from the transition returned by your sampler (i.e., the first return value of `step`). - There is a default implementation for this method, which is to return `external_transition.θ`. +1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the + most powerful option and is what Turing.jl's in-house samplers do. Implementing this + means that you can directly call `sample(model, MySampler(), N)`. + +2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This + struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step` + implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use + this with Turing.jl, you will need to wrap your sampler: `sample(model, + externalsampler(MySampler()), N)`. + +This section describes the latter. + +`MySampler` must implement the following methods: + +- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is + documented in AbstractMCMC.jl) +- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the + parameters from the transition returned by your sampler (i.e., the first return value of + `step`). There is a default implementation for this method, which is to return + `external_transition.θ`. !!! note - In a future breaking release of Turing, this is likely to change to `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method. `Turing.Inference.getparams` is technically an internal method, so the aim here is to unify the interface for samplers at a higher level. + In a future breaking release of Turing, this is likely to change to + `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method. + `Turing.Inference.getparams` is technically an internal method, so the aim here is to + unify the interface for samplers at a higher level. -There are a few more optional functions which you can implement to improve the integration with Turing.jl: +There are a few more optional functions which you can implement to improve the integration +with Turing.jl: -- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`. +- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as + a component in Turing's Gibbs sampler, you should make this evaluate to `true`. -- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. 
This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space. +- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires + unconstrained space, you should return `true`. This tells Turing to perform linking on the + VarInfo before evaluation, and ensures that the parameter values passed to your sampler + will always be in unconstrained (Euclidean) space. """ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <: - InferenceAlgorithm + AbstractSampler "the sampler to wrap" sampler::S "the automatic differentiation (AD) backend to use" @@ -115,19 +140,18 @@ getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.pa function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler_wrapper::Sampler{<:ExternalSampler}; + sampler_wrapper::ExternalSampler; initial_state=nothing, - initial_params=DynamicPPL.init_strategy(sampler_wrapper.alg.sampler), + initial_params=DynamicPPL.init_strategy(sampler_wrapper.sampler), kwargs..., ) - alg = sampler_wrapper.alg - sampler = alg.sampler + sampler = sampler_wrapper.sampler # Initialise varinfo with initial params and link the varinfo if needed. varinfo = DynamicPPL.VarInfo(model) _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params) - if requires_unconstrained_space(alg) + if requires_unconstrained_space(sampler_wrapper) varinfo = DynamicPPL.link(varinfo, model) end @@ -138,7 +162,7 @@ function AbstractMCMC.step( # Construct LogDensityFunction f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype + model, DynamicPPL.getlogjoint_internal, varinfo; adtype=sampler_wrapper.adtype ) # Then just call `AbstractMCMC.step` with the right arguments. @@ -174,11 +198,11 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler_wrapper::Sampler{<:ExternalSampler}, + sampler_wrapper::ExternalSampler, state::TuringState; kwargs..., ) - sampler = sampler_wrapper.alg.sampler + sampler = sampler_wrapper.sampler f = state.ldf # Then just call `AdvancedMCMC.step` with the right arguments. diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index f8e8e1393..7d15829a3 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,12 +1,11 @@ """ - isgibbscomponent(alg::Union{InferenceAlgorithm, AbstractMCMC.AbstractSampler}) + isgibbscomponent(spl::AbstractSampler) -Return a boolean indicating whether `alg` is a valid component for a Gibbs sampler. +Return a boolean indicating whether `spl` is a valid component for a Gibbs sampler. Defaults to `false` if no method has been defined for a particular algorithm type. 
""" -isgibbscomponent(::InferenceAlgorithm) = false -isgibbscomponent(spl::Sampler) = isgibbscomponent(spl.alg) +isgibbscomponent(::AbstractSampler) = false isgibbscomponent(::ESS) = true isgibbscomponent(::HMC) = true @@ -237,9 +236,6 @@ function make_conditional( return DynamicPPL.contextualize(model, gibbs_context), gibbs_context_inner end -wrap_in_sampler(x::AbstractMCMC.AbstractSampler) = x -wrap_in_sampler(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) - to_varname(x::VarName) = x to_varname(x::Symbol) = VarName{x}() to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)] @@ -269,10 +265,8 @@ Gibbs((@varname(x), :y) => NUTS(), :z => MH()) # Fields $(TYPEDFIELDS) """ -struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: - InferenceAlgorithm - # TODO(mhauru) Revisit whether A should have a fixed element type once - # InferenceAlgorithm/Sampler types have been cleaned up. +struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: AbstractSampler + # TODO(mhauru) Revisit whether A should have a fixed element type. "varnames representing variables for each sampler" varnames::V "samplers for each entry in `varnames`" @@ -290,7 +284,7 @@ struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: end end - samplers = tuple(map(wrap_in_sampler, samplers)...) + samplers = tuple(samplers...) varnames = tuple(map(to_varname_list, varnames)...) return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers) end @@ -315,7 +309,7 @@ support calling both step and step_warmup as the initial step. DynamicPPL initia incompatible with step_warmup. """ function initial_varinfo(rng, model, spl, initial_params::DynamicPPL.AbstractInitStrategy) - vi = DynamicPPL.default_varinfo(rng, model, spl) + vi = Turing.Inference.default_varinfo(rng, model, spl) _, vi = DynamicPPL.init!!(rng, model, vi, initial_params) return vi end @@ -323,13 +317,12 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}; - initial_params::DynamicPPL.AbstractInitStrategy=DynamicPPL.init_strategy(spl), + spl::Gibbs; + initial_params=Turing.Inference.init_strategy(spl), kwargs..., ) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers + varnames = spl.varnames + samplers = spl.samplers vi = initial_varinfo(rng, model, spl, initial_params) vi, states = gibbs_initialstep_recursive( @@ -348,13 +341,12 @@ end function AbstractMCMC.step_warmup( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}; - initial_params::DynamicPPL.AbstractInitStrategy=DynamicPPL.init_strategy(spl), + spl::Gibbs; + initial_params=Turing.Inference.init_strategy(spl), kwargs..., ) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers + varnames = spl.varnames + samplers = spl.samplers vi = initial_varinfo(rng, model, spl, initial_params) vi, states = gibbs_initialstep_recursive( @@ -434,14 +426,13 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, + spl::Gibbs, state::GibbsState; kwargs..., ) vi = get_varinfo(state) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers + varnames = spl.varnames + samplers = spl.samplers states = state.states @assert length(samplers) == length(state.states) @@ -454,14 +445,13 @@ end function AbstractMCMC.step_warmup( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, + spl::Gibbs, state::GibbsState; kwargs..., ) vi = 
get_varinfo(state) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers + varnames = spl.varnames + samplers = spl.samplers states = state.states @assert length(samplers) == length(state.states) @@ -472,7 +462,7 @@ function AbstractMCMC.step_warmup( end """ - setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo) + setparams_varinfo!!(model, sampler::AbstractSampler, state, params::AbstractVarInfo) A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an `AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to @@ -481,12 +471,14 @@ A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameter `model` is typically a `DynamicPPL.Model`, but can also be e.g. an `AbstractMCMC.LogDensityModel`. """ -function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) +function setparams_varinfo!!( + model::DynamicPPL.Model, ::AbstractSampler, state, params::AbstractVarInfo +) return AbstractMCMC.setparams!!(model, state, params[:]) end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::MHState, params::AbstractVarInfo + model::DynamicPPL.Model, sampler::MH, state::MHState, params::AbstractVarInfo ) # Re-evaluate to update the logprob. new_vi = last(DynamicPPL.evaluate!!(model, params)) @@ -494,10 +486,7 @@ function setparams_varinfo!!( end function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::Sampler{<:ESS}, - state::AbstractVarInfo, - params::AbstractVarInfo, + model::DynamicPPL.Model, sampler::ESS, state::AbstractVarInfo, params::AbstractVarInfo ) # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. @@ -506,24 +495,21 @@ end function setparams_varinfo!!( model::DynamicPPL.Model, - sampler::Sampler{<:ExternalSampler}, + sampler::ExternalSampler, state::TuringState, params::AbstractVarInfo, ) logdensity = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.alg.adtype + model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.adtype ) - new_inner_state = setparams_varinfo!!( - AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params + new_inner_state = AbstractMCMC.setparams!!( + AbstractMCMC.LogDensityModel(logdensity), state.state, params[:] ) return TuringState(new_inner_state, params, logdensity) end function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::Sampler{<:Hamiltonian}, - state::HMCState, - params::AbstractVarInfo, + model::DynamicPPL.Model, sampler::Hamiltonian, state::HMCState, params::AbstractVarInfo ) θ_new = params[:] hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new)) @@ -537,7 +523,7 @@ function setparams_varinfo!!( end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler{<:PG}, state::PGState, params::AbstractVarInfo + model::DynamicPPL.Model, sampler::PG, state::PGState, params::AbstractVarInfo ) return PGState(params, state.rng) end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 636dc2e84..101847b75 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -1,4 +1,4 @@ -abstract type Hamiltonian <: InferenceAlgorithm end +abstract type Hamiltonian <: AbstractSampler end abstract type StaticHamiltonian <: Hamiltonian end abstract type AdaptiveHamiltonian <: Hamiltonian end @@ -80,23 +80,25 @@ function HMC( return HMC(ϵ, n_leapfrog, metricT; adtype=adtype) end 
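# Editor's note: a hedged sketch of how a custom sampler might opt into the Gibbs machinery
# via the hooks documented above. `MySampler` and `MyState` are hypothetical; the method
# names (`isgibbscomponent`, `get_varinfo`, `setparams_varinfo!!`) are the ones this diff
# dispatches on.
using AbstractMCMC, DynamicPPL, Turing

struct MySampler <: AbstractMCMC.AbstractSampler end
struct MyState
    vi::DynamicPPL.AbstractVarInfo
end

# Declare that MySampler may be used as a component inside `Gibbs(...)`.
Turing.Inference.isgibbscomponent(::MySampler) = true

# Tell Gibbs how to read the current parameters out of this sampler's state...
Turing.Inference.get_varinfo(state::MyState) = state.vi

# ...and how to write parameters updated by the other Gibbs components back into it.
function Turing.Inference.setparams_varinfo!!(
    ::DynamicPPL.Model, ::MySampler, state::MyState, params::DynamicPPL.AbstractVarInfo
)
    return MyState(params)
end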
-DynamicPPL.init_strategy(::Sampler{<:Hamiltonian}) = DynamicPPL.InitFromUniform() +Turing.Inference.init_strategy(::Hamiltonian) = DynamicPPL.InitFromUniform() # Handle setting `nadapts` and `discard_initial` function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, - sampler::Sampler{<:AdaptiveHamiltonian}, + sampler::AdaptiveHamiltonian, N::Integer; - chain_type=TURING_CHAIN_TYPE, - initial_params=DynamicPPL.init_strategy(sampler), + check_model=true, + chain_type=DEFAULT_CHAIN_TYPE, + initial_params=Turing.Inference.init_strategy(sampler), initial_state=nothing, progress=PROGRESS[], - nadapts=sampler.alg.n_adapts, + nadapts=sampler.n_adapts, discard_adapt=true, discard_initial=-1, kwargs..., ) + check_model && _check_model(model, sampler) if initial_state === nothing # If `nadapts` is `-1`, then the user called a convenience # constructor like `NUTS()` or `NUTS(0.65)`, @@ -173,10 +175,10 @@ function find_initial_params( ) end -function DynamicPPL.initialstep( +function Turing.Inference.initialstep( rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:Hamiltonian}, + model::DynamicPPL.Model, + spl::Hamiltonian, vi_original::AbstractVarInfo; # the initial_params kwarg is always passed on from sample(), cf. DynamicPPL # src/sampler.jl, so we don't need to provide a default value here @@ -192,10 +194,10 @@ function DynamicPPL.initialstep( theta = vi[:] # Create a Hamiltonian. - metricT = getmetricT(spl.alg) + metricT = getmetricT(spl) metric = metricT(length(theta)) ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -213,15 +215,15 @@ function DynamicPPL.initialstep( theta = vi[:] # Find good eps if not provided one - if iszero(spl.alg.ϵ) + if iszero(spl.ϵ) ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta) verbose && @info "Found initial step size" ϵ else - ϵ = spl.alg.ϵ + ϵ = spl.ϵ end # Generate a kernel and adaptor. 
- kernel = make_ahmc_kernel(spl.alg, ϵ) - adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ) + kernel = make_ahmc_kernel(spl, ϵ) + adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ) transition = Transition(model, vi, NamedTuple()) state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor) @@ -231,8 +233,8 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:Hamiltonian}, + model::DynamicPPL.Model, + spl::Hamiltonian, state::HMCState; nadapts=0, kwargs..., @@ -247,7 +249,7 @@ function AbstractMCMC.step( # Adaptation i = state.i + 1 - if spl.alg isa AdaptiveHamiltonian + if spl isa AdaptiveHamiltonian hamiltonian, kernel, _ = AHMC.adapt!( hamiltonian, state.kernel, @@ -277,7 +279,7 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -442,17 +444,17 @@ end ##### HMC core functions ##### -getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ -getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor) +getstepsize(sampler::Hamiltonian, state) = sampler.ϵ +getstepsize(sampler::AdaptiveHamiltonian, state) = AHMC.getϵ(state.adaptor) function getstepsize( - sampler::Sampler{<:AdaptiveHamiltonian}, + sampler::AdaptiveHamiltonian, state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation}, ) where {TV,TKernel,THam,PhType} return state.kernel.τ.integrator.ϵ end -gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim) -function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state) +gen_metric(dim::Int, spl::Hamiltonian, state) = AHMC.UnitEuclideanMetric(dim) +function gen_metric(dim::Int, spl::AdaptiveHamiltonian, state) return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc)) end diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 932e6e0f4..88f915d1f 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -24,16 +24,16 @@ end sample(gdemo([1.5, 2]), IS(), 1000) ``` """ -struct IS <: InferenceAlgorithm end +struct IS <: AbstractSampler end -function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs... +function Turing.Inference.initialstep( + rng::AbstractRNG, model::Model, spl::IS, vi::AbstractVarInfo; kwargs... ) return Transition(model, vi, nothing), nothing end function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs... + rng::Random.AbstractRNG, model::Model, spl::IS, ::Nothing; kwargs... ) model = DynamicPPL.setleafcontext(model, ISContext(rng)) _, vi = DynamicPPL.evaluate!!(model, DynamicPPL.VarInfo()) @@ -42,7 +42,7 @@ function AbstractMCMC.step( end # Calculate evidence. -function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state) +function getlogevidence(samples::Vector{<:Transition}, ::IS, state) return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples)) end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 2ccceb3d7..833303b86 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -104,7 +104,7 @@ mean(chain) ``` """ -struct MH{P} <: InferenceAlgorithm +struct MH{P} <: AbstractSampler proposals::P function MH(proposals...) 
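# Editor's note: a sketch of the user-facing effect of the `initial_params` handling
# introduced earlier in this diff, using a hypothetical two-variable model. Starting values
# are now passed as an init strategy; NamedTuples and Dicts are still converted (with an
# `@info` notice), while raw vectors of values are rejected.
using Turing

@model function demo_init()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# HMC-family samplers default to `InitFromUniform()`; most others default to `InitFromPrior()`.
chn = sample(demo_init(), NUTS(), 1_000)

# Explicit starting values go through `InitFromParams`:
chn = sample(demo_init(), NUTS(), 1_000; initial_params=InitFromParams((; s=1.0, m=0.0)))

# A plain NamedTuple still works, converted via `_convert_initial_params`:
chn = sample(demo_init(), NUTS(), 1_000; initial_params=(; s=1.0, m=0.0))

# A raw vector now throws an ArgumentError:
# sample(demo_init(), NUTS(), 1_000; initial_params=[1.0, 0.0])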
@@ -247,16 +247,16 @@ function reconstruct(dist::AbstractVector{<:MultivariateDistribution}, val::Abst end """ - dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo) + dist_val_tuple(spl::MH, vi::VarInfo) Return two `NamedTuples`. The first `NamedTuple` has symbols as keys and distributions as values. The second `NamedTuple` has model symbols as keys and their stored values as values. """ -function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo) +function dist_val_tuple(spl::MH, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo) vns = all_varnames_grouped_by_symbol(vi) - dt = _dist_tuple(spl.alg.proposals, vi, vns) + dt = _dist_tuple(spl.proposals, vi, vns) vt = _val_tuple(vi, vns) return dt, vt end @@ -324,9 +324,7 @@ function maybe_link!!(varinfo, sampler, proposal, model) end # Make a proposal if we don't have a covariance proposal matrix (the default). -function propose!!( - rng::AbstractRNG, prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal -) +function propose!!(rng::AbstractRNG, prev_state::MHState, model::Model, spl::MH, proposal) vi = prev_state.varinfo # Retrieve distribution and value NamedTuples. dt, vt = dist_val_tuple(spl, vi) @@ -358,7 +356,7 @@ function propose!!( rng::AbstractRNG, prev_state::MHState, model::Model, - spl::Sampler{<:MH}, + spl::MH, proposal::AdvancedMH.RandomWalkProposal, ) vi = prev_state.varinfo @@ -367,7 +365,7 @@ function propose!!( vals = vi[:] # Create a sampler and the previous transition. - mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) + mh_sampler = AMH.MetropolisHastings(spl.proposals) prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false) # Make a new transition. @@ -388,27 +386,23 @@ function propose!!( return MHState(vi, trans.lp) end -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:MH}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::AbstractRNG, model::DynamicPPL.Model, spl::MH, vi::AbstractVarInfo; kwargs... ) # If we're doing random walk with a covariance matrix, # just link everything before sampling. - vi = maybe_link!!(vi, spl, spl.alg.proposals, model) + vi = maybe_link!!(vi, spl, spl.proposals, model) return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi)) end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, state::MHState; kwargs... + rng::AbstractRNG, model::DynamicPPL.Model, spl::MH, state::MHState; kwargs... ) # Cases: # 1. A covariance proposal matrix # 2. A bunch of NamedTuples that specify the proposal space - new_state = propose!!(rng, state, model, spl, spl.alg.proposals) + new_state = propose!!(rng, state, model, spl, spl.proposals) return Transition(model, new_state.varinfo, nothing), new_state end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index e72d86e62..7aadef09e 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -56,7 +56,7 @@ function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) ) end -abstract type ParticleInference <: InferenceAlgorithm end +abstract type ParticleInference <: AbstractSampler end #### #### Generic Sequential Monte Carlo sampler. 
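# Editor's note: with the `InferenceAlgorithm` layer removed, Gibbs components are plain
# samplers (no `DynamicPPL.Sampler` wrapping). A sketch using the pair syntax shown in the
# Gibbs docstring above; the model and the choice of components are hypothetical.
using Turing

@model function gdemo_like(y)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    y ~ Normal(m, sqrt(s))
end

# ESS updates `m` (which has a Gaussian prior) and MH updates `s`, alternating each sweep.
chn = sample(gdemo_like(2.0), Gibbs(:m => ESS(), :s => MH()), 1_000)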
@@ -101,20 +101,22 @@ struct SMCState{P,F<:AbstractFloat} average_logevidence::F end -function getlogevidence(samples, sampler::Sampler{<:SMC}, state::SMCState) +function getlogevidence(samples, ::SMC, state::SMCState) return state.average_logevidence end function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, - sampler::Sampler{<:SMC}, + sampler::SMC, N::Integer; - chain_type=TURING_CHAIN_TYPE, - initial_params=DynamicPPL.init_strategy(sampler), + check_model=true, + chain_type=DEFAULT_CHAIN_TYPE, + initial_params=Turing.Inference.init_strategy(sampler), progress=PROGRESS[], kwargs..., ) + check_model && _check_model(model, sampler) # need to add on the `nparticles` keyword argument for `initialstep` to make use of return AbstractMCMC.mcmcsample( rng, @@ -129,10 +131,10 @@ function AbstractMCMC.sample( ) end -function DynamicPPL.initialstep( +function Turing.Inference.initialstep( rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:SMC}, + model::DynamicPPL.Model, + spl::SMC, vi::AbstractVarInfo; nparticles::Int, kwargs..., @@ -149,7 +151,7 @@ function DynamicPPL.initialstep( ) # Perform particle sweep. - logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl) + logevidence = AdvancedPS.sweep!(rng, particles, spl.resampler, spl) # Extract the first particle and its weight. particle = particles.vals[1] @@ -164,7 +166,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - ::AbstractRNG, model::AbstractModel, spl::Sampler{<:SMC}, state::SMCState; kwargs... + ::AbstractRNG, model::DynamicPPL.Model, spl::SMC, state::SMCState; kwargs... ) # Extract the index of the current particle. index = state.particleindex @@ -238,9 +240,7 @@ end get_varinfo(state::PGState) = state.vi function getlogevidence( - transitions::AbstractVector{<:Turing.Inference.Transition}, - sampler::Sampler{<:PG}, - state::PGState, + transitions::AbstractVector{<:Turing.Inference.Transition}, ::PG, ::PGState ) logevidences = map(transitions) do t if haskey(t.stat, :logevidence) @@ -254,17 +254,13 @@ function getlogevidence( return mean(logevidences) end -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:PG}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, vi::AbstractVarInfo; kwargs... ) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Create a new set of particles - num_particles = spl.alg.nparticles + num_particles = spl.nparticles particles = AdvancedPS.ParticleContainer( [ AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for @@ -275,7 +271,7 @@ function DynamicPPL.initialstep( ) # Perform a particle sweep. - logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl) + logevidence = AdvancedPS.sweep!(rng, particles, spl.resampler, spl) # Pick a particle to be retained. Ws = AdvancedPS.getweights(particles) @@ -290,7 +286,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::AbstractRNG, model::AbstractModel, spl::Sampler{<:PG}, state::PGState; kwargs... + rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, state::PGState; kwargs... ) # Reset the VarInfo before new sweep. vi = state.vi @@ -300,7 +296,7 @@ function AbstractMCMC.step( reference = AdvancedPS.forkr(AdvancedPS.Trace(model, vi, state.rng, false)) # Create a new set of particles. 
- num_particles = spl.alg.nparticles + num_particles = spl.nparticles x = map(1:num_particles) do i if i != num_particles return AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) @@ -311,7 +307,7 @@ function AbstractMCMC.step( particles = AdvancedPS.ParticleContainer(x, AdvancedPS.TracedRNG(), rng) # Perform a particle sweep. - logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl, reference) + logevidence = AdvancedPS.sweep!(rng, particles, spl.resampler, spl, reference) # Pick a particle to be retained. Ws = AdvancedPS.getweights(particles) diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index c5228d8fc..c4ec6c6f3 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -3,12 +3,12 @@ Algorithm for sampling from the prior. """ -struct Prior <: InferenceAlgorithm end +struct Prior <: AbstractSampler end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:Prior}, + sampler::Prior, state=nothing; kwargs..., ) diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index 5669a27b5..133517494 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -24,11 +24,12 @@ struct RepeatSampler{S<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSa end end -function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int) - return RepeatSampler(Sampler(alg), num_repeat) -end - -function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params) +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::RepeatSampler, + state, + params::DynamicPPL.AbstractVarInfo, +) return setparams_varinfo!!(model, sampler.sampler, state, params) end @@ -40,6 +41,14 @@ function AbstractMCMC.step( ) return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...) end +# The following method needed for method ambiguity resolution. +# TODO(penelopeysm): Remove this method once the default `AbstractMCMC.step(rng, +# ::DynamicPPL.Model, ::AbstractSampler)` method in `src/mcmc/abstractmcmc.jl` is removed. +function AbstractMCMC.step( + rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::RepeatSampler; kwargs... +) + return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...) +end function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -85,26 +94,28 @@ end # Need some extra leg work to make RepeatSampler work seamlessly with DynamicPPL models + # samplers, instead of generic AbstractMCMC samplers. 
-function DynamicPPL.init_strategy(spl::RepeatSampler{<:Sampler}) - return DynamicPPL.init_strategy(spl.sampler) +function Turing.Inference.init_strategy(spl::RepeatSampler) + return Turing.Inference.init_strategy(spl.sampler) end function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, - sampler::RepeatSampler{<:Sampler}, + sampler::RepeatSampler, N::Integer; - initial_params=DynamicPPL.init_strategy(sampler), - chain_type=TURING_CHAIN_TYPE, + check_model=true, + initial_params=Turing.Inference.init_strategy(sampler), + chain_type=DEFAULT_CHAIN_TYPE, progress=PROGRESS[], kwargs..., ) + check_model && _check_model(model, sampler) return AbstractMCMC.mcmcsample( rng, model, sampler, N; - initial_params=initial_params, + initial_params=_convert_initial_params(initial_params), chain_type=chain_type, progress=progress, kwargs..., @@ -114,15 +125,17 @@ end function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, - sampler::RepeatSampler{<:Sampler}, + sampler::RepeatSampler, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; - initial_params=fill(DynamicPPL.init_strategy(sampler), n_chains), - chain_type=TURING_CHAIN_TYPE, + check_model=true, + initial_params=fill(Turing.Inference.init_strategy(sampler), n_chains), + chain_type=DEFAULT_CHAIN_TYPE, progress=PROGRESS[], kwargs..., ) + check_model && _check_model(model, sampler) return AbstractMCMC.mcmcsample( rng, model, @@ -130,7 +143,7 @@ function AbstractMCMC.sample( ensemble, N, n_chains; - initial_params=initial_params, + initial_params=map(_convert_initial_params, initial_params), chain_type=chain_type, progress=progress, kwargs..., diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index a14edbc27..267a21620 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -51,12 +51,8 @@ struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} velocity::T end -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:SGHMC}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::Random.AbstractRNG, model::Model, spl::SGHMC, vi::AbstractVarInfo; kwargs... ) # Transform the samples to unconstrained space. if !DynamicPPL.is_transformed(vi) @@ -66,7 +62,7 @@ function DynamicPPL.initialstep( # Compute initial sample and state. sample = Transition(model, vi, nothing) ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) state = SGHMCState(ℓ, vi, zero(vi[:])) @@ -74,11 +70,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:SGHMC}, - state::SGHMCState; - kwargs..., + rng::Random.AbstractRNG, model::Model, spl::SGHMC, state::SGHMCState; kwargs... ) # Compute gradient of log density. ℓ = state.logdensity @@ -90,8 +82,8 @@ function AbstractMCMC.step( # equation (15) of Chen et al. (2014) v = state.velocity θ .+= v - η = spl.alg.learning_rate - α = spl.alg.momentum_decay + η = spl.learning_rate + α = spl.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) # Save new variables. 
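# Editor's note: relating to the repeat_sampler.jl changes above, a sketch of using
# `RepeatSampler` directly now that it wraps any `AbstractSampler`. The model is a
# hypothetical stand-in; the assumption is that `RepeatSampler(sampler, n)` repeats the
# wrapped sampler `n` times per saved step, keeping only the last transition.
using Turing

@model function demo_rs()
    x ~ Normal()
end

rs = RepeatSampler(MH(), 5)         # five MH transitions per stored sample
chn = sample(demo_rs(), rs, 1_000)  # `check_model` and `initial_params` kwargs work as above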
@@ -190,12 +182,8 @@ struct SGLDState{L,V<:AbstractVarInfo} step::Int end -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:SGLD}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::Random.AbstractRNG, model::Model, spl::SGLD, vi::AbstractVarInfo; kwargs... ) # Transform the samples to unconstrained space. if !DynamicPPL.is_transformed(vi) @@ -203,9 +191,9 @@ function DynamicPPL.initialstep( end # Create first sample and state. - transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0)))) + transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.stepsize(0)))) ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) state = SGLDState(ℓ, vi, 1) @@ -213,7 +201,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler{<:SGLD}, state::SGLDState; kwargs... + rng::Random.AbstractRNG, model::Model, spl::SGLD, state::SGLDState; kwargs... ) # Perform gradient step. ℓ = state.logdensity @@ -221,7 +209,7 @@ function AbstractMCMC.step( θ = vi[:] grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) step = state.step - stepsize = spl.alg.stepsize(step) + stepsize = spl.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) # Save new variables. diff --git a/test/essential/container.jl b/test/essential/container.jl index 100cf0432..19609b6b5 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -2,7 +2,7 @@ module ContainerTests using AdvancedPS: AdvancedPS using Distributions: Bernoulli, Beta, Gamma, Normal -using DynamicPPL: DynamicPPL, @model, Sampler +using DynamicPPL: DynamicPPL, @model using Test: @test, @testset using Turing @@ -20,7 +20,7 @@ using Turing @testset "constructor" begin vi = DynamicPPL.VarInfo() vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) - sampler = Sampler(PG(10)) + sampler = PG(10) model = test() trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) @@ -45,7 +45,7 @@ using Turing end vi = DynamicPPL.VarInfo() vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) - sampler = Sampler(PG(10)) + sampler = PG(10) model = normal() trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) diff --git a/test/ext/dynamichmc.jl b/test/ext/dynamichmc.jl index 3f609504d..004970dd3 100644 --- a/test/ext/dynamichmc.jl +++ b/test/ext/dynamichmc.jl @@ -6,7 +6,6 @@ using Test: @test, @testset using Distributions: sample using DynamicHMC: DynamicHMC using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using Random: Random using Turing diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index ce1d532c3..6918eaddf 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -6,7 +6,6 @@ using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample using AbstractMCMC: AbstractMCMC import DynamicPPL -using DynamicPPL: Sampler import ForwardDiff using LinearAlgebra: I import MCMCChains @@ -70,18 +69,12 @@ using Turing end @testset "save/resume correctly reloads state" begin - struct StaticSampler <: Turing.Inference.InferenceAlgorithm end - function DynamicPPL.initialstep( - rng, model, ::DynamicPPL.Sampler{<:StaticSampler}, vi; kwargs... 
- ) + struct StaticSampler <: AbstractMCMC.AbstractSampler end + function Turing.Inference.initialstep(rng, model, ::StaticSampler, vi; kwargs...) return Turing.Inference.Transition(model, vi, nothing), vi end function AbstractMCMC.step( - rng, - model, - ::DynamicPPL.Sampler{<:StaticSampler}, - vi::DynamicPPL.AbstractVarInfo; - kwargs..., + rng, model, ::StaticSampler, vi::DynamicPPL.AbstractVarInfo; kwargs... ) return Turing.Inference.Transition(model, vi, nothing), vi end @@ -91,7 +84,7 @@ using Turing @testset "single-chain" begin chn1 = sample(demo(), StaticSampler(), 10; save_state=true) @test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo - chn2 = sample(demo(), StaticSampler(), 10; initial_state=chn1.info.samplerstate) + chn2 = sample(demo(), StaticSampler(), 10; initial_state=loadstate(chn1)) xval = chn1[:x][1] @test all(chn2[:x] .== xval) end @@ -108,7 +101,7 @@ using Turing MCMCThreads(), 10, nchains; - initial_state=chn1.info.samplerstate, + initial_state=loadstate(chn1), ) xval = chn1[:x][1, :] @test all(i -> chn2[:x][i, :] == xval, 1:10) @@ -124,20 +117,12 @@ using Turing check_gdemo(chn1) chn1_contd = sample( - StableRNG(seed), - gdemo_default, - alg1, - 2_000; - initial_state=chn1.info.samplerstate, + StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=loadstate(chn1) ) check_gdemo(chn1_contd) chn1_contd2 = sample( - StableRNG(seed), - gdemo_default, - alg1, - 2_000; - initial_state=chn1.info.samplerstate, + StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=loadstate(chn1) ) check_gdemo(chn1_contd2) @@ -152,11 +137,7 @@ using Turing check_gdemo(chn2) chn2_contd = sample( - StableRNG(seed), - gdemo_default, - alg2, - 2_000; - initial_state=chn2.info.samplerstate, + StableRNG(seed), gdemo_default, alg2, 2_000; initial_state=loadstate(chn2) ) check_gdemo(chn2_contd) @@ -171,11 +152,7 @@ using Turing check_gdemo(chn3) chn3_contd = sample( - StableRNG(seed), - gdemo_default, - alg3, - 5_000; - initial_state=chn3.info.samplerstate, + StableRNG(seed), gdemo_default, alg3, 5_000; initial_state=loadstate(chn3) ) check_gdemo(chn3_contd) end diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl new file mode 100644 index 000000000..6f4b47613 --- /dev/null +++ b/test/mcmc/abstractmcmc.jl @@ -0,0 +1,136 @@ +module TuringAbstractMCMCTests + +using AbstractMCMC: AbstractMCMC +using DynamicPPL: DynamicPPL +using Random: AbstractRNG +using Test: @test, @testset, @test_throws +using Turing + +@testset "Initial parameters" begin + # Dummy algorithm that just returns initial value and does not perform any sampling + abstract type OnlyInit <: AbstractMCMC.AbstractSampler end + struct OnlyInitDefault <: OnlyInit end + struct OnlyInitUniform <: OnlyInit end + Turing.Inference.init_strategy(::OnlyInitUniform) = InitFromUniform() + function Turing.Inference.initialstep( + rng::AbstractRNG, + model::DynamicPPL.Model, + ::OnlyInit, + vi::DynamicPPL.VarInfo=DynamicPPL.VarInfo(rng, model); + kwargs..., + ) + return vi, nothing + end + + @testset "init_strategy" begin + # check that the default init strategy is prior + @test Turing.Inference.init_strategy(OnlyInitDefault()) == InitFromPrior() + @test Turing.Inference.init_strategy(OnlyInitUniform()) == InitFromUniform() + end + + for spl in (OnlyInitDefault(), OnlyInitUniform()) + # model with one variable: initialization p = 0.2 + @model function coinflip() + p ~ Beta(1, 1) + return 10 ~ Binomial(25, p) + end + model = coinflip() + lptrue = logpdf(Binomial(25, 0.2), 10) + let inits = InitFromParams((; p=0.2)) + chain 
= sample(model, spl, 1; initial_params=inits, progress=false) + @test chain[1].metadata.p.vals == [0.2] + @test DynamicPPL.getlogjoint(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, + spl, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test c[1].metadata.p.vals == [0.2] + @test DynamicPPL.getlogjoint(c[1]) == lptrue + end + end + + # check that Vector no longer works + @test_throws ArgumentError sample( + model, spl, 1; initial_params=[4, -1], progress=false + ) + @test_throws ArgumentError sample( + model, spl, 1; initial_params=[missing, -1], progress=false + ) + + # model with two variables: initialization s = 4, m = -1 + @model function twovars() + s ~ InverseGamma(2, 3) + return m ~ Normal(0, sqrt(s)) + end + model = twovars() + lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) + for inits in ( + InitFromParams((s=4, m=-1)), + (s=4, m=-1), + InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)), + Dict(@varname(s) => 4, @varname(m) => -1), + ) + chain = sample(model, spl, 1; initial_params=inits, progress=false) + @test chain[1].metadata.s.vals == [4] + @test chain[1].metadata.m.vals == [-1] + @test DynamicPPL.getlogjoint(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, + spl, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test c[1].metadata.s.vals == [4] + @test c[1].metadata.m.vals == [-1] + @test DynamicPPL.getlogjoint(c[1]) == lptrue + end + end + + # set only m = -1 + for inits in ( + InitFromParams((; s=missing, m=-1)), + InitFromParams(Dict(@varname(s) => missing, @varname(m) => -1)), + (; s=missing, m=-1), + Dict(@varname(s) => missing, @varname(m) => -1), + InitFromParams((; m=-1)), + InitFromParams(Dict(@varname(m) => -1)), + (; m=-1), + Dict(@varname(m) => -1), + ) + chain = sample(model, spl, 1; initial_params=inits, progress=false) + @test !ismissing(chain[1].metadata.s.vals[1]) + @test chain[1].metadata.m.vals == [-1] + + # parallel sampling + chains = sample( + model, + spl, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test !ismissing(c[1].metadata.s.vals[1]) + @test c[1].metadata.m.vals == [-1] + end + end + end +end + +end # module diff --git a/test/mcmc/emcee.jl b/test/mcmc/emcee.jl index 03861f17e..44bf75858 100644 --- a/test/mcmc/emcee.jl +++ b/test/mcmc/emcee.jl @@ -4,7 +4,6 @@ using ..Models: gdemo_default using ..NumericalTests: check_gdemo using Distributions: sample using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using Random: Random using Test: @test, @test_throws, @testset using Turing diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 8c2d38b35..e497fdde3 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -2,10 +2,9 @@ module ESSTests using ..Models: MoGtest, MoGtest_default, gdemo, gdemo_default using ..NumericalTests: check_MoGtest_default, check_numerical -using ..SamplerTestUtils: test_rng_respected +using ..SamplerTestUtils: test_rng_respected, test_sampler_analytical using Distributions: Normal, sample using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using Random: Random using StableRNGs: StableRNG using Test: @test, @testset @@ -85,9 +84,9 @@ using Turing model | (s=DynamicPPL.TestUtils.posterior_mean(model).s,) end - DynamicPPL.TestUtils.test_sampler( + test_sampler_analytical( models_conditioned, - DynamicPPL.Sampler(ESS()), + ESS(), 2000; # Filter out the varnames we've conditioned 
on. varnames_filter=vn -> DynamicPPL.getsym(vn) != :s, diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index d1e72a94e..56c03c87a 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -1,6 +1,7 @@ module ExternalSamplerTests using ..Models: gdemo_default +using ..SamplerTestUtils: test_sampler_analytical using AbstractMCMC: AbstractMCMC using AdvancedMH: AdvancedMH using Distributions: sample @@ -205,9 +206,7 @@ end # Need some functionality to initialize the sampler. # TODO: Remove this once the constructors in the respective packages become "lazy". sampler = initialize_nuts(model) - sampler_ext = DynamicPPL.Sampler( - externalsampler(sampler; adtype, unconstrained=true) - ) + sampler_ext = externalsampler(sampler; adtype, unconstrained=true) # TODO: AdvancedHMC samplers do not return the initial parameters as the first # step, so `test_initial_params` will fail. This should be fixed upstream in @@ -223,7 +222,7 @@ end ) @testset "inference" begin - DynamicPPL.TestUtils.test_sampler( + test_sampler_analytical( [model], sampler_ext, 2_000; @@ -252,14 +251,12 @@ end # Need some functionality to initialize the sampler. # TODO: Remove this once the constructors in the respective packages become "lazy". sampler = initialize_mh_rw(model) - sampler_ext = DynamicPPL.Sampler( - externalsampler(sampler; unconstrained=true) - ) + sampler_ext = externalsampler(sampler; unconstrained=true) @testset "initial_params" begin test_initial_params(model, sampler_ext) end @testset "inference" begin - DynamicPPL.TestUtils.test_sampler( + test_sampler_analytical( [model], sampler_ext, 2_000; @@ -286,12 +283,12 @@ end # @testset "MH with prior proposal" begin # @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # sampler = initialize_mh_with_prior_proposal(model); - # sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; unconstrained=false)) + # sampler_ext = externalsampler(sampler; unconstrained=false) # @testset "initial_params" begin # test_initial_params(model, sampler_ext) # end # @testset "inference" begin - # DynamicPPL.TestUtils.test_sampler( + # test_sampler_analytical( # [model], # sampler_ext, # 10_000; diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 15bf2afea..1e3d5856c 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -134,26 +134,24 @@ end # Test that the samplers are being called in the correct order, on the correct target # variables. +# @testset "Sampler call order" begin # A wrapper around inference algorithms to allow intercepting the dispatch cascade to # collect testing information. - struct AlgWrapper{Alg<:Inference.InferenceAlgorithm} <: Inference.InferenceAlgorithm + struct AlgWrapper{Alg<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSampler inner::Alg end - unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) = - DynamicPPL.Sampler(sampler.alg.inner) - # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm. # They all just propagate the call to the inner algorithm. 
Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner) function Inference.setparams_varinfo!!( model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:AlgWrapper}, + sampler::AlgWrapper, state, params::DynamicPPL.AbstractVarInfo, ) - return Inference.setparams_varinfo!!(model, unwrap_sampler(sampler), state, params) + return Inference.setparams_varinfo!!(model, sampler.inner, state, params) end # targets_and_algs will be a list of tuples, where the first element is the target_vns @@ -175,25 +173,23 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:AlgWrapper}, + sampler::AlgWrapper, args...; kwargs..., ) - capture_targets_and_algs(sampler.alg.inner, model.context) - return AbstractMCMC.step(rng, model, unwrap_sampler(sampler), args...; kwargs...) + capture_targets_and_algs(sampler.inner, model.context) + return AbstractMCMC.step(rng, model, sampler.inner, args...; kwargs...) end - function DynamicPPL.initialstep( + function Turing.Inference.initialstep( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:AlgWrapper}, + sampler::AlgWrapper, args...; kwargs..., ) - capture_targets_and_algs(sampler.alg.inner, model.context) - return DynamicPPL.initialstep( - rng, model, unwrap_sampler(sampler), args...; kwargs... - ) + capture_targets_and_algs(sampler.inner, model.context) + return Turing.Inference.initialstep(rng, model, sampler.inner, args...; kwargs...) end struct Wrapper{T<:Real} @@ -279,7 +275,7 @@ end @testset "Gibbs warmup" begin # An inference algorithm, for testing purposes, that records how many warm-up steps # and how many non-warm-up steps haven been taken. - mutable struct WarmupCounter <: Inference.InferenceAlgorithm + mutable struct WarmupCounter <: AbstractMCMC.AbstractSampler warmup_init_count::Int non_warmup_init_count::Int warmup_count::Int @@ -298,7 +294,7 @@ end Turing.Inference.get_varinfo(state::VarInfoState) = state.vi function Turing.Inference.setparams_varinfo!!( ::DynamicPPL.Model, - ::DynamicPPL.Sampler, + ::WarmupCounter, ::VarInfoState, params::DynamicPPL.AbstractVarInfo, ) @@ -306,23 +302,17 @@ end end function AbstractMCMC.step( - ::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}; - kwargs..., + ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter; kwargs... ) - spl.alg.non_warmup_init_count += 1 + spl.non_warmup_init_count += 1 vi = DynamicPPL.VarInfo(model) return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end function AbstractMCMC.step_warmup( - ::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}; - kwargs..., + ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter; kwargs... 
) - spl.alg.warmup_init_count += 1 + spl.warmup_init_count += 1 vi = DynamicPPL.VarInfo(model) return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end @@ -330,22 +320,22 @@ end function AbstractMCMC.step( ::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}, + spl::WarmupCounter, s::VarInfoState; kwargs..., ) - spl.alg.non_warmup_count += 1 + spl.non_warmup_count += 1 return Turing.Inference.Transition(model, s.vi, nothing), s end function AbstractMCMC.step_warmup( ::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}, + spl::WarmupCounter, s::VarInfoState; kwargs..., ) - spl.alg.warmup_count += 1 + spl.warmup_count += 1 return Turing.Inference.Transition(model, s.vi, nothing), s end @@ -486,7 +476,7 @@ end @nospecialize function AbstractMCMC.bundle_samples( samples::Vector, ::typeof(model), - ::DynamicPPL.Sampler{<:Gibbs}, + ::Gibbs, state, ::Type{MCMCChains.Chains}; kwargs..., @@ -670,14 +660,10 @@ end @testset "$sampler" for sampler in samplers # Check that taking steps performs as expected. rng = Random.default_rng() - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(sampler) - ) + transition, state = AbstractMCMC.step(rng, model, sampler) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(sampler), state - ) + transition, state = AbstractMCMC.step(rng, model, sampler, state) check_transition_varnames(transition, vns) end end @@ -747,36 +733,32 @@ end @testset "with both `s` and `m` as random" begin model = gdemo(1.5, 2.0) vns = (@varname(s), @varname(m)) - alg = Gibbs(vns => MH()) + spl = Gibbs(vns => MH()) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end # `sample` Random.seed!(42) - chain = sample(model, alg, 1_000; progress=false) + chain = sample(model, spl, 1_000; progress=false) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) end @testset "without `m` as random" begin model = gdemo(1.5, 2.0) | (m=7 / 6,) vns = (@varname(s),) - alg = Gibbs(vns => MH()) + spl = Gibbs(vns => MH()) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end end @@ -818,7 +800,7 @@ end @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default - alg = Gibbs( + spl = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS(), @@ -832,25 +814,23 @@ end @varname(mu2) ) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) 
check_transition_varnames(transition, vns) end # Sample! Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) + chain = sample(MoGtest_default, spl, 1000; progress=false) check_MoGtest_default(chain; atol=0.2) end @testset "CSMC + ESS (usage of implicit varname)" begin rng = Random.default_rng() model = MoGtest_default_z_vector - alg = Gibbs(@varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS()) + spl = Gibbs(@varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS()) vns = ( @varname(z[1]), @varname(z[2]), @@ -860,18 +840,16 @@ end @varname(mu2) ) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end # Sample! Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) + chain = sample(model, spl, 1000; progress=false) check_MoGtest_default_z_vector(chain; atol=0.2) end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index bc801b27c..c6b5af216 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -4,7 +4,7 @@ using ..Models: gdemo_default using ..NumericalTests: check_gdemo, check_numerical using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample -using DynamicPPL: DynamicPPL, Sampler +using DynamicPPL: DynamicPPL import ForwardDiff using HypothesisTests: ApproximateTwoSampleKSTest, pvalue import ReverseDiff @@ -236,7 +236,7 @@ using Turing 10; nadapts=0, discard_adapt=false, - initial_state=chn1.info.samplerstate, + initial_state=loadstate(chn1), ) # if chn2 uses initial_state, its first sample should be somewhere around 5. if # initial_state isn't used, it will be sampled from [-2, 2] so this test should fail @@ -295,11 +295,10 @@ using Turing end @testset "getstepsize: Turing.jl#2400" begin - algs = [HMC(0.1, 10), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)] - @testset "$(alg)" for alg in algs + spls = [HMC(0.1, 10), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)] + @testset "$(spl)" for spl in spls # Construct a HMC state by taking a single step - spl = Sampler(alg) - hmc_state = DynamicPPL.initialstep( + hmc_state = Turing.Inference.initialstep( Random.default_rng(), gdemo_default, spl, diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 7fb39b966..7c19f022b 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -4,7 +4,6 @@ using AdvancedMH: AdvancedMH using Distributions: Bernoulli, Dirichlet, Exponential, InverseGamma, LogNormal, MvNormal, Normal, sample using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using LinearAlgebra: I using Random: Random using StableRNGs: StableRNG @@ -116,7 +115,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) end model = M(zeros(2), I, 1) - sampler = Inference.Sampler(MH()) + sampler = MH() dt, vt = Inference.dist_val_tuple(sampler, DynamicPPL.VarInfo(model)) @@ -231,24 +230,21 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Don't link when no proposals are given since we're using priors # as proposals. 
vi = deepcopy(vi_base) - alg = MH() - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) + spl = MH() + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) @test !DynamicPPL.is_transformed(vi) # Link if proposal is `AdvancedHM.RandomWalkProposal` vi = deepcopy(vi_base) d = length(vi_base[:]) - alg = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I))) - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) + spl = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I))) + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) @test DynamicPPL.is_transformed(vi) # Link if ALL proposals are `AdvancedHM.RandomWalkProposal`. vi = deepcopy(vi_base) - alg = MH(:s => AdvancedMH.RandomWalkProposal(Normal())) - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) + spl = MH(:s => AdvancedMH.RandomWalkProposal(Normal())) + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) @test DynamicPPL.is_transformed(vi) # Don't link if at least one proposal is NOT `RandomWalkProposal`. @@ -256,12 +252,11 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # are linked! I.e. resolve https://github.com/TuringLang/Turing.jl/issues/1583. # https://github.com/TuringLang/Turing.jl/pull/1582#issuecomment-817148192 vi = deepcopy(vi_base) - alg = MH( + spl = MH( :m => AdvancedMH.StaticProposal(Normal()), :s => AdvancedMH.RandomWalkProposal(Normal()), ) - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) @test !DynamicPPL.is_transformed(vi) end diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index 38b22219c..1a2288402 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -1,7 +1,6 @@ module RepeatSamplerTests using ..Models: gdemo_default -using DynamicPPL: Sampler using MCMCChains: MCMCChains using Random: Xoshiro using Test: @test, @testset @@ -17,7 +16,7 @@ using Turing # Use Xoshiro instead of StableRNGs as the output should always be # similar regardless of what kind of random seed is used (as long # as there is a random seed). 
- for sampler in [MH(), Sampler(HMC(0.01, 4))] + for sampler in [MH(), HMC(0.01, 4)] chn1 = sample( Xoshiro(0), gdemo_default, diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 66ad03212..e08137109 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -18,13 +18,6 @@ using Turing @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) @test alg isa SGHMC - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGHMC} - - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) - @test alg isa SGHMC - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGHMC} end @testset "sghmc inference" begin @@ -43,13 +36,6 @@ end @testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25)) @test alg isa SGLD - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGLD} - - alg = SGLD(; stepsize=PolynomialStepsize(0.25)) - @test alg isa SGLD - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGLD} end @testset "sgld inference" begin diff --git a/test/runtests.jl b/test/runtests.jl index 5fb6b2141..81b4bdde2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,6 +43,7 @@ end end @testset "samplers (without AD)" verbose = true begin + @timeit_include("mcmc/abstractmcmc.jl") @timeit_include("mcmc/particle_mcmc.jl") @timeit_include("mcmc/emcee.jl") @timeit_include("mcmc/ess.jl") diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl index c7371bc00..a2ca123b1 100644 --- a/test/test_utils/sampler.jl +++ b/test/test_utils/sampler.jl @@ -1,5 +1,8 @@ module SamplerTestUtils +using AbstractMCMC +using AbstractPPL +using DynamicPPL using Random using Turing using Test @@ -42,4 +45,54 @@ function test_rng_respected(spl) @test isapprox(chn1[:y], chn2[:y]) end +""" + test_sampler_analytical(models, sampler, args...; kwargs...) + +Test that `sampler` produces correct marginal posterior means on each model in `models`. + +In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the `model` +and `sampler` to produce a `chain`, and then checks the chain's mean for every (leaf) +varname `vn` against the corresponding value returned by +`DynamicPPL.TestUtils.posterior_mean` for each model. + +For this to work, each model in `models` must have a known analytical posterior mean +that can be computed by `DynamicPPL.TestUtils.posterior_mean`. + +# Arguments +- `models`: A collection of instances of `DynamicPPL.Model` to test on. +- `sampler`: The `AbstractMCMC.AbstractSampler` to test. +- `args...`: Arguments forwarded to `sample`. + +# Keyword arguments +- `varnames_filter`: A filter to apply to `varnames(model)`, allowing comparison for only + a subset of the varnames. +- `atol=1e-1`: Absolute tolerance used in `@test`. +- `rtol=1e-3`: Relative tolerance used in `@test`. +- `kwargs...`: Keyword arguments forwarded to `sample`. +""" +function test_sampler_analytical( + models, + sampler::AbstractMCMC.AbstractSampler, + args...; + varnames_filter=Returns(true), + atol=1e-1, + rtol=1e-3, + sampler_name=typeof(sampler), + kwargs..., +) + @testset "$(sampler_name) on $(nameof(model))" for model in models + chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) + target_values = DynamicPPL.TestUtils.posterior_mean(model) + for vn in filter(varnames_filter, DynamicPPL.TestUtils.varnames(model)) + # We want to compare elementwise which can be achieved by + # extracting the leaves of the `VarName` and the corresponding value. 
+ for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn)) + target_value = get(target_values, vn_leaf) + chain_mean_value = mean(chain[Symbol(vn_leaf)]) + @test chain_mean_value ≈ target_value atol = atol rtol = rtol + end + end + end +end + end
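For reference, a minimal sketch of how the `test_sampler_analytical` helper defined above can be called from a test module. This snippet is illustrative and not part of the patch: it assumes the usual Turing test-suite layout (so that `..SamplerTestUtils` is available as a sibling module), uses `DynamicPPL.TestUtils.DEMO_MODELS` because their analytical posterior means are known to `DynamicPPL.TestUtils.posterior_mean`, and picks `NUTS()` purely as an example of a sampler passed directly, without the removed `DynamicPPL.Sampler` wrapper.

    using Turing
    using DynamicPPL: DynamicPPL
    using ..SamplerTestUtils: test_sampler_analytical

    # Compare the chain means of 2_000 NUTS draws against the known posterior
    # means of every demo model; `progress=false` is forwarded to `sample`
    # through the trailing keyword arguments.
    test_sampler_analytical(
        DynamicPPL.TestUtils.DEMO_MODELS,
        NUTS(),
        2_000;
        atol=0.2,
        progress=false,
    )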
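The initial-parameter tests added in test/mcmc/abstractmcmc.jl above boil down to the following usage pattern. This is an illustrative sketch only; the model and the choice of `NUTS()` are placeholders.

    using Turing

    @model function demo()
        s ~ InverseGamma(2, 3)
        m ~ Normal(0, sqrt(s))
        return 1.5 ~ Normal(m, sqrt(s))
    end

    # Initial parameters are supplied as an init strategy. A NamedTuple or a
    # Dict keyed by VarNames is also accepted, whereas a flat vector of reals
    # raises an ArgumentError. Variables omitted from the strategy are drawn
    # according to the sampler's `init_strategy`.
    chn = sample(demo(), NUTS(), 100; initial_params=InitFromParams((; s=4.0, m=-1.0)))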
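Likewise, the `save_state` / `initial_state` round trip exercised in the save/resume tests earlier in this patch reduces to the pattern below, reusing the `demo` model from the previous sketch; `MH()` is again only an example sampler.

    # First run: keep the final sampler state in the chain's metadata.
    chn1 = sample(demo(), MH(), 500; save_state=true)

    # Later run: recover that state with the newly exported `loadstate` and
    # continue sampling from where the first run stopped.
    chn2 = sample(demo(), MH(), 500; initial_state=loadstate(chn1))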