diff --git a/HISTORY.md b/HISTORY.md index 23a686a73..5c3d1da41 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,30 @@ # 0.41.0 +## DynamicPPL 0.38 + +Turing.jl v0.41 brings with it all the underlying changes in DynamicPPL 0.38. + +The only user-facing difference is that initial parameters for MCMC sampling must now be specified in a different form. +You still need to use the `initial_params` keyword argument to `sample`, but the allowed values are different. +For almost all samplers in Turing.jl (except `Emcee`) this should now be a `DynamicPPL.AbstractInitStrategy`. + +TODO LINK TO DPPL DOCS WHEN THIS IS LIVE + +There are three kinds of initialisation strategies provided out of the box with Turing.jl (they are exported so you can use these directly with `using Turing`): + + - `InitFromPrior()`: Sample from the prior distribution. This is the default for most samplers in Turing.jl (if you don't specify `initial_params`). + - `InitFromUniform(a, b)`: Sample uniformly from `[a, b]` in linked space. This is the default for Hamiltonian samplers. If `a` and `b` are not specified it defaults to `[-2, 2]`, which preserves the behaviour in previous versions (and mimics that of Stan). + - `InitFromParams(p)`: Explicitly provide a set of initial parameters. **Note: `p` must be either a `NamedTuple` or a `Dict{<:VarName}`; it can no longer be a `Vector`.** Parameters must be provided in unlinked space, even if the sampler later performs linking. + +This change is made because Vectors are semantically ambiguous. +It is not clear which element of the vector corresponds to which variable in the model, nor is it clear whether the parameters are in linked or unlinked space. +Previously, both of these would depend on the internal structure of the VarInfo, which is an implementation detail. +In contrast, the behaviour of `Dict`s and `NamedTuple`s is invariant to the ordering of variables and it is also easier for readers to understand which variable is being set to which value. + +If you were previously using `varinfo[:]` to extract a vector of initial parameters, you can now use `Dict(k => varinfo[k] for k in keys(varinfo)` to extract a Dict of initial parameters. + +## Initial step in MCMC sampling + HMC and NUTS samplers no longer take an extra single step before starting the chain. This means that if you do not discard any samples at the start, the first sample will be the initial parameters (which may be user-provided). diff --git a/Project.toml b/Project.toml index e679949d4..b867f4771 100644 --- a/Project.toml +++ b/Project.toml @@ -45,7 +45,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb" [extensions] TuringDynamicHMCExt = "DynamicHMC" -TuringOptimExt = "Optim" +TuringOptimExt = ["Optim", "AbstractPPL"] [compat] ADTypes = "1.9" @@ -64,7 +64,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.37.2" +DynamicPPL = "0.38" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.9.3" @@ -90,3 +90,6 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" Optim = "429524aa-4258-5aef-a3af-852621145aeb" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/docs/src/api.md b/docs/src/api.md index 0b8351eb3..62c8d41c2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -75,6 +75,16 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu | `RepeatSampler` | [`Turing.Inference.RepeatSampler`](@ref) | A sampler that runs multiple times on the same variable | | `externalsampler` | [`Turing.Inference.externalsampler`](@ref) | Wrap an external sampler for use in Turing | +### Initialisation strategies + +Turing.jl provides several strategies to initialise parameters for models. + +| Exported symbol | Documentation | Description | +|:----------------- |:--------------------------------------- |:--------------------------------------------------------------- | +| `InitFromPrior` | [`DynamicPPL.InitFromPrior`](@extref) | Obtain initial parameters from the prior distribution | +| `InitFromUniform` | [`DynamicPPL.InitFromUniform`](@extref) | Obtain initial parameters by sampling uniformly in linked space | +| `InitFromParams` | [`DynamicPPL.InitFromParams`](@extref) | Manually specify (possibly a subset of) initial parameters | + ### Variational inference See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough. diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 2c4bd0898..dac11ff5a 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -44,10 +44,6 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S} stepsize::S end -function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS}) - return DynamicPPL.SampleFromUniform() -end - function DynamicPPL.initialstep( rng::Random.AbstractRNG, model::DynamicPPL.Model, diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index 0f755988e..21aecafbe 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -1,6 +1,7 @@ module TuringOptimExt using Turing: Turing +using AbstractPPL: AbstractPPL import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation using Optim: Optim @@ -186,7 +187,7 @@ function _optimize( f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype ) vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) vns_vals_iter = mapreduce(collect, vcat, iters) varnames = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) diff --git a/src/Turing.jl b/src/Turing.jl index 0cdbe2458..b3412cf55 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -73,7 +73,10 @@ using DynamicPPL: conditioned, to_submodel, LogDensityFunction, - @addlogprob! + @addlogprob!, + InitFromPrior, + InitFromUniform, + InitFromParams using StatsBase: predict using OrderedCollections: OrderedDict @@ -148,6 +151,10 @@ export fix, unfix, OrderedDict, # OrderedCollections + # Initialisation strategies for models + InitFromPrior, + InitFromUniform, + InitFromParams, # Point estimates - Turing.Optimisation # The MAP and MLE exports are only needed for the Optim.jl interface. maximum_a_posteriori, diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 53bf6dbc0..7e1456696 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -24,8 +24,6 @@ using DynamicPPL: getdist, Model, Sampler, - SampleFromPrior, - SampleFromUniform, DefaultContext, set_flag!, unset_flag! @@ -59,8 +57,6 @@ export InferenceAlgorithm, Hamiltonian, StaticHamiltonian, AdaptiveHamiltonian, - SampleFromUniform, - SampleFromPrior, MH, ESS, Emcee, @@ -84,6 +80,8 @@ export InferenceAlgorithm, # Abstract interface for inference algorithms # ############################################### +const TURING_CHAIN_TYPE = MCMCChains.Chains + include("algorithm.jl") #################### @@ -262,13 +260,13 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) dicts = map(ts) do t # In general getparams returns a dict of VarName => values. We need to also # split it up into constituent elements using - # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl + # `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl # won't understand it. vals = getparams(model, t) nms_and_vs = if isempty(vals) Tuple{VarName,Any}[] else - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) mapreduce(collect, vcat, iters) end nms = map(first, nms_and_vs) @@ -315,11 +313,10 @@ end getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. -# This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{Transition,AbstractVarInfo}}, + ts::Vector{<:Transition}, model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, + spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, state, chain_type::Type{MCMCChains.Chains}; save_state=false, @@ -378,11 +375,10 @@ function AbstractMCMC.bundle_samples( return sort_chain ? sort(chain) : chain end -# This is type piracy (for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{Transition,AbstractVarInfo}}, + ts::Vector{<:Transition}, model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, + spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, state, chain_type::Type{Vector{NamedTuple}}; kwargs..., diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index edd563885..63cff1243 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -1,9 +1,9 @@ # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. function _check_model(model::DynamicPPL.Model) - # TODO(DPPL0.38/penelopeysm): use InitContext - spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context)) - return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) + new_context = DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()) + new_model = DynamicPPL.contextualize(model, new_context) + return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true) end function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm) return _check_model(model) diff --git a/src/mcmc/algorithm.jl b/src/mcmc/algorithm.jl index d45ae0d4a..725b6afbf 100644 --- a/src/mcmc/algorithm.jl +++ b/src/mcmc/algorithm.jl @@ -11,4 +11,6 @@ this wrapping occurs automatically. """ abstract type InferenceAlgorithm end -DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains +function DynamicPPL.init_strategy(sampler::Sampler{<:InferenceAlgorithm}) + return DynamicPPL.InitFromPrior() +end diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 98ed20b40..48caffc6f 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -31,12 +31,21 @@ struct EmceeState{V<:AbstractVarInfo,S} states::S end +# Utility function to tetrieve the number of walkers +_get_n_walkers(e::Emcee) = e.ensemble.n_walkers +_get_n_walkers(spl::Sampler{<:Emcee}) = _get_n_walkers(spl.alg) + +# Because Emcee expects n_walkers initialisations, we need to override this +function DynamicPPL.init_strategy(spl::Sampler{<:Emcee}) + return fill(DynamicPPL.InitFromPrior(), _get_n_walkers(spl)) +end + function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler{<:Emcee}; resume_from=nothing, - initial_params=nothing, + initial_params, kwargs..., ) if resume_from !== nothing @@ -45,23 +54,19 @@ function AbstractMCMC.step( end # Sample from the prior - n = spl.alg.ensemble.n_walkers - vis = [VarInfo(rng, model, SampleFromPrior()) for _ in 1:n] + n = _get_n_walkers(spl) + vis = [VarInfo(rng, model) for _ in 1:n] # Update the parameters if provided. - if initial_params !== nothing - length(initial_params) == n || - throw(ArgumentError("initial parameters have to be specified for each walker")) - vis = map(vis, initial_params) do vi, init - # TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!! - vi = DynamicPPL.initialize_parameters!!(vi, init, model) - - # Update log joint probability. - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context) - ) - last(DynamicPPL.evaluate!!(spl_model, vi)) - end + if !( + initial_params isa AbstractVector{<:DynamicPPL.AbstractInitStrategy} && + length(initial_params) == n + ) + err_msg = "initial_params for `Emcee` must be a vector of `DynamicPPL.AbstractInitStrategy`, with length equal to the number of walkers ($n)" + throw(ArgumentError(err_msg)) + end + vis = map(vis, initial_params) do vi, strategy + last(DynamicPPL.init!!(rng, model, vi, strategy)) end # Compute initial transition and states. diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 3afd91607..d89d25cf9 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -82,23 +82,8 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true # Only define out-of-place sampling function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) - varinfo = p.varinfo - # TODO: Surely there's a better way of doing this now that we have `SamplingContext`? - # TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model, - # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason - # why we had to use the 'del' flag before this was because - # SampleFromPrior() wouldn't overwrite existing variables. - # The main problem I'm rather unsure about is ESS-within-Gibbs. The - # current implementation I think makes sure to only resample the variables - # that 'belong' to the current ESS sampler. InitContext on the other hand - # would resample all variables in the model (??) Need to think about this - # carefully. - vns = keys(varinfo) - for vn in vns - set_flag!(varinfo, vn, "del") - end - p.model(rng, varinfo) - return varinfo[:] + _, vi = DynamicPPL.init!!(rng, p.model, p.varinfo, DynamicPPL.InitFromPrior()) + return vi[:] end # Mean of prior distribution diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index af31e0243..0755e4160 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -117,7 +117,7 @@ function AbstractMCMC.step( model::DynamicPPL.Model, sampler_wrapper::Sampler{<:ExternalSampler}; initial_state=nothing, - initial_params=nothing, + initial_params=DynamicPPL.init_strategy(sampler_wrapper.alg.sampler), kwargs..., ) alg = sampler_wrapper.alg @@ -125,17 +125,17 @@ function AbstractMCMC.step( # Initialise varinfo with initial params and link the varinfo if needed. varinfo = DynamicPPL.VarInfo(model) + _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params) + if requires_unconstrained_space(alg) - if initial_params !== nothing - # If we have initial parameters, we need to set the varinfo before linking. - varinfo = DynamicPPL.link(DynamicPPL.unflatten(varinfo, initial_params), model) - # Extract initial parameters in unconstrained space. - initial_params = varinfo[:] - else - varinfo = DynamicPPL.link(varinfo, model) - end + varinfo = DynamicPPL.link(varinfo, model) end + # We need to extract the vectorised initial_params, because the later call to + # AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params` + # to be a vector. + initial_params_vector = varinfo[:] + # Construct LogDensityFunction f = DynamicPPL.LogDensityFunction( model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype @@ -144,7 +144,11 @@ function AbstractMCMC.step( # Then just call `AbstractMCMC.step` with the right arguments. if initial_state === nothing transition_inner, state_inner = AbstractMCMC.step( - rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs... + rng, + AbstractMCMC.LogDensityModel(f), + sampler; + initial_params=initial_params_vector, + kwargs..., ) else transition_inner, state_inner = AbstractMCMC.step( @@ -152,7 +156,7 @@ function AbstractMCMC.step( AbstractMCMC.LogDensityModel(f), sampler, initial_state; - initial_params, + initial_params=initial_params_vector, kwargs..., ) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 17bc88153..e8837ec0b 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -47,7 +47,7 @@ A context used in the implementation of the Turing.jl Gibbs sampler. There will be one `GibbsContext` for each iteration of a component sampler. `target_varnames` is a a tuple of `VarName`s that the current component sampler -is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume` +is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume!!` calls to its child context. For other variables, their values will be fixed to the values they have in `global_varinfo`. @@ -140,7 +140,9 @@ function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) end # Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) +function DynamicPPL.tilde_assume!!( + context::GibbsContext, right::Distribution, vn::VarName, vi::DynamicPPL.AbstractVarInfo +) child_context = DynamicPPL.childcontext(context) # Note that `child_context` may contain `PrefixContext`s -- in which case @@ -175,7 +177,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) return if is_target_varname(context, vn) # Fall back to the default behavior. - DynamicPPL.tilde_assume(child_context, right, vn, vi) + DynamicPPL.tilde_assume!!(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) # This branch means that a different sampler is supposed to handle this # variable. From the perspective of this sampler, this variable is @@ -191,50 +193,10 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - value, new_global_vi = DynamicPPL.tilde_assume( - child_context, - DynamicPPL.SampleFromPrior(), - right, - vn, - get_global_varinfo(context), - ) - set_global_varinfo!(context, new_global_vi) - value, vi - end -end - -# As above but with an RNG. -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi -) - # See comment in the above, rng-less version of this method for an explanation. - child_context = DynamicPPL.childcontext(context) - vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) - - return if is_target_varname(context, vn) - # This branch means that that `sampler` is supposed to handle - # this variable. We can thus use its default behaviour, with - # the 'local' sampler-specific VarInfo. - DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) - elseif has_conditioned_gibbs(context, vn) - # This branch means that a different sampler is supposed to handle this - # variable. From the perspective of this sampler, this variable is - # conditioned on, so we can just treat it as an observation. - # The only catch is that the value that we need is to be obtained from - # the global VarInfo (since the local VarInfo has no knowledge of it). - # Note that tilde_observe!! will trigger resampling in particle methods - # for variables that are handled by other Gibbs component samplers. - val = get_conditioned_gibbs(context, vn) - DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) - else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. - value, new_global_vi = DynamicPPL.tilde_assume( - rng, - child_context, - DynamicPPL.SampleFromPrior(), + value, new_global_vi = DynamicPPL.tilde_assume!!( + # child_context might be a PrefixContext so we have to be careful to not + # overwrite it. + DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext()), right, vn, get_global_varinfo(context), @@ -352,19 +314,9 @@ This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated h support calling both step and step_warmup as the initial step. DynamicPPL initialstep is incompatible with step_warmup. """ -function initial_varinfo(rng, model, spl, initial_params) +function initial_varinfo(rng, model, spl, initial_params::DynamicPPL.AbstractInitStrategy) vi = DynamicPPL.default_varinfo(rng, model, spl) - - # Update the parameters if provided. - if initial_params !== nothing - vi = DynamicPPL.initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(DynamicPPL.evaluate!!(model, vi)) - end + _, vi = DynamicPPL.init!!(rng, model, vi, initial_params) return vi end @@ -372,7 +324,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, spl::DynamicPPL.Sampler{<:Gibbs}; - initial_params=nothing, + initial_params::DynamicPPL.AbstractInitStrategy=DynamicPPL.init_strategy(spl), kwargs..., ) alg = spl.alg @@ -397,7 +349,7 @@ function AbstractMCMC.step_warmup( rng::Random.AbstractRNG, model::DynamicPPL.Model, spl::DynamicPPL.Sampler{<:Gibbs}; - initial_params=nothing, + initial_params::DynamicPPL.AbstractInitStrategy=DynamicPPL.init_strategy(spl), kwargs..., ) alg = spl.alg @@ -434,7 +386,7 @@ function gibbs_initialstep_recursive( samplers, vi, states=(); - initial_params=nothing, + initial_params, kwargs..., ) # End recursion @@ -445,13 +397,6 @@ function gibbs_initialstep_recursive( varnames, varname_vecs_tail... = varname_vecs sampler, samplers_tail... = samplers - # Get the initial values for this component sampler. - initial_params_local = if initial_params === nothing - nothing - else - DynamicPPL.subset(vi, varnames)[:] - end - # Construct the conditioned model. conditioned_model, context = make_conditional(model, varnames, vi) @@ -462,7 +407,7 @@ function gibbs_initialstep_recursive( sampler; # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. - initial_params=initial_params_local, + initial_params=initial_params, kwargs..., ) new_vi_local = get_varinfo(new_state) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 363508e70..e13019db0 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -80,7 +80,7 @@ function HMC( return HMC(ϵ, n_leapfrog, metricT; adtype=adtype) end -DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform() +DynamicPPL.init_strategy(::Sampler{<:Hamiltonian}) = DynamicPPL.InitFromUniform() # Handle setting `nadapts` and `discard_initial` function AbstractMCMC.sample( @@ -88,8 +88,9 @@ function AbstractMCMC.sample( model::DynamicPPL.Model, sampler::Sampler{<:AdaptiveHamiltonian}, N::Integer; - chain_type=DynamicPPL.default_chain_type(sampler), + chain_type=TURING_CHAIN_TYPE, resume_from=nothing, + initial_params=DynamicPPL.init_strategy(sampler), initial_state=DynamicPPL.loadstate(resume_from), progress=PROGRESS[], nadapts=sampler.alg.n_adapts, @@ -123,6 +124,7 @@ function AbstractMCMC.sample( progress=progress, nadapts=_nadapts, discard_initial=_discard_initial, + initial_params=initial_params, kwargs..., ) else @@ -137,6 +139,7 @@ function AbstractMCMC.sample( nadapts=0, discard_adapt=false, discard_initial=0, + initial_params=initial_params, kwargs..., ) end @@ -146,7 +149,8 @@ function find_initial_params( rng::Random.AbstractRNG, model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo, - hamiltonian::AHMC.Hamiltonian; + hamiltonian::AHMC.Hamiltonian, + init_strategy::DynamicPPL.AbstractInitStrategy; max_attempts::Int=1000, ) varinfo = deepcopy(varinfo) # Don't mutate @@ -157,15 +161,10 @@ function find_initial_params( isfinite(z) && return varinfo, z attempts == 10 && - @warn "failed to find valid initial parameters in $(attempts) tries; consider providing explicit initial parameters using the `initial_params` keyword" + @warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword" # Resample and try again. - # NOTE: varinfo has to be linked to make sure this samples in unconstrained space - varinfo = last( - DynamicPPL.evaluate_and_sample!!( - rng, model, varinfo, DynamicPPL.SampleFromUniform() - ), - ) + _, varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy) end # if we failed to find valid initial parameters, error @@ -179,7 +178,9 @@ function DynamicPPL.initialstep( model::AbstractModel, spl::Sampler{<:Hamiltonian}, vi_original::AbstractVarInfo; - initial_params=nothing, + # the initial_params kwarg is always passed on from sample(), cf. DynamicPPL + # src/sampler.jl, so we don't need to provide a default value here + initial_params::DynamicPPL.AbstractInitStrategy, nadapts=0, verbose::Bool=true, kwargs..., @@ -200,13 +201,15 @@ function DynamicPPL.initialstep( lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func) - # If no initial parameters are provided, resample until the log probability - # and its gradient are finite. Otherwise, just use the existing parameters. - vi, z = if initial_params === nothing - find_initial_params(rng, model, vi, hamiltonian) - else - vi, AHMC.phasepoint(rng, theta, hamiltonian) - end + # Note that there is already one round of 'initialisation' before we reach this step, + # inside DynamicPPL's `AbstractMCMC.step` implementation. That leads to a possible issue + # that this `find_initial_params` function might override the parameters set by the + # user. + # Luckily for us, `find_initial_params` always checks if the logp and its gradient are + # finite. If it is already finite with the params inside the current `vi`, it doesn't + # attempt to find new ones. This means that the parameters passed to `sample()` will be + # respected instead of being overridden here. + vi, z = find_initial_params(rng, model, vi, hamiltonian, initial_params) theta = vi[:] # Find good eps if not provided one @@ -471,15 +474,6 @@ function make_ahmc_kernel(alg::NUTS, ϵ) ) end -#### -#### Compiler interface, i.e. tilde operators. -#### -function DynamicPPL.assume( - rng, ::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi -) - return DynamicPPL.assume(dist, vn, vi) -end - #### #### Default HMC stepsize and mass matrix adaptor #### diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 319e424fc..932e6e0f4 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -26,8 +26,6 @@ sample(gdemo([1.5, 2]), IS(), 1000) """ struct IS <: InferenceAlgorithm end -DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler - function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs... ) @@ -37,7 +35,9 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs... ) - vi = VarInfo(rng, model, spl) + model = DynamicPPL.setleafcontext(model, ISContext(rng)) + _, vi = DynamicPPL.evaluate!!(model, DynamicPPL.VarInfo()) + vi = DynamicPPL.typed_varinfo(vi) return Transition(model, vi, nothing), nothing end @@ -46,13 +46,25 @@ function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state) return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples)) end -function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName, vi) +struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext + rng::R +end +DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf() + +function DynamicPPL.tilde_assume!!( + ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) if haskey(vi, vn) r = vi[vn] else - r = rand(rng, dist) + r = rand(ctx.rng, dist) vi = push!!(vi, vn, r, dist) end vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist) return r, vi end +function DynamicPPL.tilde_observe!!( + ::ISContext, right::Distribution, left, vn::Union{VarName,Nothing}, vi::AbstractVarInfo +) + return DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 863db559c..2ccceb3d7 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -178,8 +178,6 @@ get_varinfo(s::MHState) = s.varinfo # Utility functions # ##################### -# TODO(DPPL0.38/penelopeysm): This function should no longer be needed -# once InitContext is merged. """ set_namedtuple!(vi::VarInfo, nt::NamedTuple) @@ -207,15 +205,24 @@ end # NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems # interface in that it gets evaluated with a NamedTuple. Hence we need this # method just to deal with MH. -# TODO(DPPL0.38/penelopeysm): Check the extent to which this method is actually -# needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, -# ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). -# In general, we should much prefer to either (1) conform to the -# LogDensityProblems interface or (2) use VarNames anyway. function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) + # Note that the NamedTuple `x` does NOT conform to the structure required for + # `InitFromParams`. In particular, for models that look like this: + # + # @model function f() + # v = Vector{Vector{Float64}} + # v[1] ~ MvNormal(zeros(2), I) + # end + # + # `InitFromParams` will expect Dict(@varname(v[1]) => [x1, x2]), but `x` will have the + # format `(v = [x1, x2])`. Hence we still need this `set_namedtuple!` function. + # + # In general `init!!(f.model, vi, InitFromParams(x))` will work iff the model only + # contains 'basic' varnames. set_namedtuple!(vi, x) - vi_new = last(DynamicPPL.evaluate!!(f.model, vi)) + # Update log probability. + _, vi_new = DynamicPPL.evaluate!!(f.model, vi) lj = f.getlogdensity(vi_new) return lj end @@ -329,13 +336,11 @@ function propose!!( prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false) # Make a new transition. - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, spl, model.context) - ) + model = DynamicPPL.setleafcontext(model, MHContext(rng)) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) @@ -366,13 +371,11 @@ function propose!!( prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false) # Make a new transition. - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, spl, model.context) - ) + model = DynamicPPL.setleafcontext(model, MHContext(rng)) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) @@ -410,13 +413,25 @@ function AbstractMCMC.step( return Transition(model, new_state.varinfo, nothing), new_state end -#### -#### Compiler interface, i.e. tilde operators. -#### -function DynamicPPL.assume( - rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi +struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext + rng::R +end +DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf() + +function DynamicPPL.tilde_assume!!( + context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + # Allow MH to sample new variables from the prior if it's not already present in the + # VarInfo. + dispatch_ctx = if haskey(vi, vn) + DynamicPPL.DefaultContext() + else + DynamicPPL.InitContext(context.rng, DynamicPPL.InitFromPrior()) + end + return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi) +end +function DynamicPPL.tilde_observe!!( + ::MHContext, right::Distribution, left, vn::Union{VarName,Nothing}, vi::AbstractVarInfo ) - # Just defer to `SampleFromPrior`. - retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) - return retval + return DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index e80ec527b..e792ba930 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -36,30 +36,28 @@ function unset_all_del!(vi::AbstractVarInfo) return nothing end -struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: - AdvancedPS.AbstractGenericModel +struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext + rng::R +end +DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf() + +struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel model::M - sampler::S varinfo::V evaluator::E + resample::Bool end function TracedModel( - model::Model, - sampler::AbstractSampler, - varinfo::AbstractVarInfo, - rng::Random.AbstractRNG, + model::Model, varinfo::AbstractVarInfo, rng::Random.AbstractRNG, resample::Bool ) - spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) - spl_model = DynamicPPL.contextualize(model, spl_context) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) - if kwargs !== nothing && !isempty(kwargs) - error( - "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", - ) - end - evaluator = (spl_model.f, args...) - return TracedModel(spl_model, sampler, varinfo, evaluator) + model = DynamicPPL.setleafcontext(model, ParticleMCMCContext(rng)) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) + isempty(kwargs) || error( + "Particle sampling methods do not currently support models with keyword arguments.", + ) + evaluator = (model.f, args...) + return TracedModel(model, varinfo, evaluator, resample) end function AdvancedPS.advance!( @@ -75,16 +73,9 @@ function AdvancedPS.delete_retained!(trace::TracedModel) # This method is called if, during a CSMC update, we perform a resampling # and choose the reference particle as the trajectory to carry on from. # In such a case, we need to ensure that when we continue sampling (i.e. - # the next time we hit tilde_assume), we don't use the values in the + # the next time we hit tilde_assume!!), we don't use the values in the # reference particle but rather sample new values. - # - # Here, we indiscriminately set the 'del' flag for all variables in the - # VarInfo. This is slightly overkill: it is not necessary to set the 'del' - # flag for variables that were already sampled. However, it allows us to - # avoid keeping track of which variables were sampled, which leads to many - # simplifications in the VarInfo data structure. - set_all_del!(trace.varinfo) - return trace + return TracedModel(trace.model, trace.varinfo, trace.evaluator, true) end function AdvancedPS.reset_model(trace::TracedModel) @@ -151,8 +142,9 @@ function AbstractMCMC.sample( model::DynamicPPL.Model, sampler::Sampler{<:SMC}, N::Integer; - chain_type=DynamicPPL.default_chain_type(sampler), + chain_type=TURING_CHAIN_TYPE, resume_from=nothing, + initial_params=DynamicPPL.init_strategy(sampler), initial_state=DynamicPPL.loadstate(resume_from), progress=PROGRESS[], kwargs..., @@ -164,6 +156,7 @@ function AbstractMCMC.sample( sampler, N; chain_type=chain_type, + initial_params=initial_params, progress=progress, nparticles=N, kwargs..., @@ -175,6 +168,7 @@ function AbstractMCMC.sample( sampler, N; chain_type, + initial_params=initial_params, initial_state, progress=progress, nparticles=N, @@ -198,7 +192,7 @@ function DynamicPPL.initialstep( # Create a new set of particles. particles = AdvancedPS.ParticleContainer( - [AdvancedPS.Trace(model, spl, vi, AdvancedPS.TracedRNG()) for _ in 1:nparticles], + [AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for _ in 1:nparticles], AdvancedPS.TracedRNG(), rng, ) @@ -317,13 +311,14 @@ function DynamicPPL.initialstep( kwargs..., ) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) - # Reset the VarInfo before new sweep - set_all_del!(vi) # Create a new set of particles num_particles = spl.alg.nparticles particles = AdvancedPS.ParticleContainer( - [AdvancedPS.Trace(model, spl, vi, AdvancedPS.TracedRNG()) for _ in 1:num_particles], + [ + AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for + _ in 1:num_particles + ], AdvancedPS.TracedRNG(), rng, ) @@ -351,17 +346,13 @@ function AbstractMCMC.step( vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Create reference particle for which the samples will be retained. - unset_all_del!(vi) - reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) - - # For all other particles, do not retain the variables but resample them. - set_all_del!(vi) + reference = AdvancedPS.forkr(AdvancedPS.Trace(model, vi, state.rng, false)) # Create a new set of particles. num_particles = spl.alg.nparticles x = map(1:num_particles) do i if i != num_particles - return AdvancedPS.Trace(model, spl, vi, AdvancedPS.TracedRNG()) + return AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) else return reference end @@ -383,11 +374,7 @@ function AbstractMCMC.step( return transition, PGState(_vi, newreference.rng) end -function DynamicPPL.use_threadsafe_eval( - ::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo -) - return false -end +DynamicPPL.use_threadsafe_eval(::ParticleMCMCContext, ::AbstractVarInfo) = false """ get_trace_local_varinfo_maybe(vi::AbstractVarInfo) @@ -407,7 +394,24 @@ function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo) end """ - get_trace_local_varinfo_maybe(rng::Random.AbstractRNG) + get_trace_local_resampled_maybe(fallback_resampled::Bool) + +Get the `Trace` local `resampled` if one exists. + +If executed within a `TapedTask`, return the `resampled` stored in the "taped globals" of +the task, otherwise return `fallback_resampled`. +""" +function get_trace_local_resampled_maybe(fallback_resampled::Bool) + trace = try + Libtask.get_taped_globals(Any).other + catch e + e == KeyError(:task_variable) ? nothing : rethrow(e) + end + return (trace === nothing ? fallback_resampled : trace.model.f.resample)::Bool +end + +""" + get_trace_local_rng_maybe(rng::Random.AbstractRNG) Get the `Trace` local rng if one exists. @@ -446,30 +450,22 @@ function set_trace_local_varinfo_maybe(vi::AbstractVarInfo) return nothing end -function DynamicPPL.assume( - rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo +function DynamicPPL.tilde_assume!!( + ctx::ParticleMCMCContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) arg_vi_id = objectid(vi) vi = get_trace_local_varinfo_maybe(vi) using_local_vi = objectid(vi) == arg_vi_id - trng = get_trace_local_rng_maybe(rng) - - if ~haskey(vi, vn) - r = rand(trng, dist) - vi = push!!(vi, vn, r, dist) - elseif DynamicPPL.is_flagged(vi, vn, "del") - DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent - # TODO(mhauru): - # The below is the only line that differs from assume called on SampleFromPrior. - # Could we just call assume on SampleFromPrior with a specific rng? - r = rand(trng, dist) - vi[vn] = DynamicPPL.tovec(r) + trng = get_trace_local_rng_maybe(ctx.rng) + resample = get_trace_local_resampled_maybe(true) + + dispatch_ctx = if ~haskey(vi, vn) || resample + DynamicPPL.InitContext(trng, DynamicPPL.InitFromPrior()) else - r = vi[vn] + DynamicPPL.DefaultContext() end - - vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist) + x, vi = DynamicPPL.tilde_assume!!(dispatch_ctx, dist, vn, vi) # TODO(mhauru) Rather than this if-block, we should use try-catch within # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, @@ -477,17 +473,21 @@ function DynamicPPL.assume( if !using_local_vi set_trace_local_varinfo_maybe(vi) end - return r, vi + return x, vi end function DynamicPPL.tilde_observe!!( - ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi + ::ParticleMCMCContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, ) arg_vi_id = objectid(vi) vi = get_trace_local_varinfo_maybe(vi) using_local_vi = objectid(vi) == arg_vi_id - left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi) + left, vi = DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi) # TODO(mhauru) Rather than this if-block, we should use try-catch within # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, @@ -500,13 +500,10 @@ end # Convenient constructor function AdvancedPS.Trace( - model::Model, - sampler::Sampler{<:Union{SMC,PG}}, - varinfo::AbstractVarInfo, - rng::AdvancedPS.TracedRNG, + model::Model, varinfo::AbstractVarInfo, rng::AdvancedPS.TracedRNG, resample::Bool ) newvarinfo = deepcopy(varinfo) - tmodel = TracedModel(model, sampler, newvarinfo, rng) + tmodel = TracedModel(model, newvarinfo, rng, resample) newtrace = AdvancedPS.Trace(tmodel, rng) return newtrace end @@ -573,7 +570,6 @@ Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} # Could the next two could have tighter type bounds on the arguments, namely a GibbsContext? # That's the only thing that makes tilde_assume calls result in tilde_observe calls. Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index 2ead40ced..c5228d8fc 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,19 +12,14 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - # TODO(DPPL0.38/penelopeysm): replace with init!! - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) - ) - vi = VarInfo() vi = DynamicPPL.setaccs!!( - vi, + DynamicPPL.VarInfo(), ( DynamicPPL.ValuesAsInModelAccumulator(true), DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator(), ), ) - _, vi = DynamicPPL.evaluate!!(sampling_model, vi) + _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior()) return Transition(model, vi, nothing; reevaluate=false), nothing end diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index fa2eca96d..5669a27b5 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -81,3 +81,58 @@ function AbstractMCMC.step_warmup( end return transition, state end + +# Need some extra leg work to make RepeatSampler work seamlessly with DynamicPPL models + +# samplers, instead of generic AbstractMCMC samplers. + +function DynamicPPL.init_strategy(spl::RepeatSampler{<:Sampler}) + return DynamicPPL.init_strategy(spl.sampler) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::RepeatSampler{<:Sampler}, + N::Integer; + initial_params=DynamicPPL.init_strategy(sampler), + chain_type=TURING_CHAIN_TYPE, + progress=PROGRESS[], + kwargs..., +) + return AbstractMCMC.mcmcsample( + rng, + model, + sampler, + N; + initial_params=initial_params, + chain_type=chain_type, + progress=progress, + kwargs..., + ) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::RepeatSampler{<:Sampler}, + ensemble::AbstractMCMC.AbstractMCMCEnsemble, + N::Integer, + n_chains::Integer; + initial_params=fill(DynamicPPL.init_strategy(sampler), n_chains), + chain_type=TURING_CHAIN_TYPE, + progress=PROGRESS[], + kwargs..., +) + return AbstractMCMC.mcmcsample( + rng, + model, + sampler, + ensemble, + N, + n_chains; + initial_params=initial_params, + chain_type=chain_type, + progress=progress, + kwargs..., + ) +end diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 19c52c381..c073a4597 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -2,6 +2,7 @@ module Optimisation using ..Turing using NamedArrays: NamedArrays +using AbstractPPL: AbstractPPL using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Optimization: Optimization @@ -320,7 +321,7 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) # m.values, but they are more convenient to filter when they are VarNames rather than # Symbols. vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) vns_and_vals = mapreduce(collect, vcat, iters) varnames = collect(map(first, vns_and_vals)) # For each symbol s in var_symbols, pick all the values from m.values for which the @@ -351,7 +352,7 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u) # `getparams` performs invlinking if needed vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) vns_vals_iter = mapreduce(collect, vcat, iters) syms = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) @@ -507,10 +508,8 @@ function estimate_mode( kwargs..., ) if check_model - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - DynamicPPL.check_model(spl_model, DynamicPPL.VarInfo(); error_on_failure=true) + new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) + DynamicPPL.check_model(new_model, DynamicPPL.VarInfo(); error_on_failure=true) end constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons) diff --git a/test/Project.toml b/test/Project.toml index ba7a83be1..9671918e9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,6 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.37.2" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" @@ -77,3 +76,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/ad.jl b/test/ad.jl index dcfe4ef46..9524199dc 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -154,31 +154,23 @@ end # context, and then call check_adtype on the result before returning the results from the # child context. -function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) - check_adtype(context, vi) - return value, vi -end - -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi +function DynamicPPL.tilde_assume!!( + context::ADTypeCheckContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) - value, vi = DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume!!(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) return value, vi end -function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) - left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) - check_adtype(context, vi) - return left, vi -end - -function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe!!( + context::ADTypeCheckContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) left, vi = DynamicPPL.tilde_observe!!( - DynamicPPL.childcontext(context), sampler, right, left, vi + DynamicPPL.childcontext(context), right, left, vn, vi ) check_adtype(context, vi) return left, vi diff --git a/test/essential/container.jl b/test/essential/container.jl index 124637aab..100cf0432 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -22,7 +22,7 @@ using Turing vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = test() - trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) + trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) # Make sure the backreference from taped_globals to the trace is in place. @test trace.model.ctask.taped_globals.other === trace @@ -48,7 +48,7 @@ using Turing sampler = Sampler(PG(10)) model = normal() - trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) + trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) newtrace = AdvancedPS.forkr(trace) # Catch broken replay mechanism diff --git a/test/ext/OptimInterface.jl b/test/ext/OptimInterface.jl index 8fb9e2b1a..721e255f3 100644 --- a/test/ext/OptimInterface.jl +++ b/test/ext/OptimInterface.jl @@ -2,6 +2,7 @@ module OptimInterfaceTests using ..Models: gdemo_default using Distributions.FillArrays: Zeros +using AbstractPPL: AbstractPPL using DynamicPPL: DynamicPPL using LinearAlgebra: I using Optim: Optim @@ -124,7 +125,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol = 0.05 end end @@ -159,7 +160,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) if model.f in allowed_incorrect_mle @test isfinite(get(result_true, vn_leaf)) else diff --git a/test/mcmc/emcee.jl b/test/mcmc/emcee.jl index b9a041d78..03861f17e 100644 --- a/test/mcmc/emcee.jl +++ b/test/mcmc/emcee.jl @@ -34,18 +34,21 @@ using Turing nwalkers = 250 spl = Emcee(nwalkers, 2.0) - # No initial parameters, with im- and explicit `initial_params=nothing` Random.seed!(1234) chain1 = sample(gdemo_default, spl, 1) Random.seed!(1234) - chain2 = sample(gdemo_default, spl, 1; initial_params=nothing) + chain2 = sample(gdemo_default, spl, 1) @test Array(chain1) == Array(chain2) + initial_nt = DynamicPPL.InitFromParams((s=2.0, m=1.0)) # Initial parameters have to be specified for every walker - @test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=[2.0, 1.0]) + @test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=initial_nt) + @test_throws r"must be a vector of" sample( + gdemo_default, spl, 1; initial_params=initial_nt + ) # Initial parameters - chain = sample(gdemo_default, spl, 1; initial_params=fill([2.0, 1.0], nwalkers)) + chain = sample(gdemo_default, spl, 1; initial_params=fill(initial_nt, nwalkers)) @test chain[:s] == fill(2.0, 1, nwalkers) @test chain[:m] == fill(1.0, 1, nwalkers) end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 1e1be9b45..ad1ca4ba2 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -2,6 +2,7 @@ module ESSTests using ..Models: MoGtest, MoGtest_default, gdemo, gdemo_default using ..NumericalTests: check_MoGtest_default, check_numerical +using ..SamplerTestUtils: test_rng_respected using Distributions: Normal, sample using DynamicPPL: DynamicPPL using DynamicPPL: Sampler @@ -38,6 +39,12 @@ using Turing c3 = sample(gdemo_default, s2, N) end + @testset "RNG is respected" begin + test_rng_respected(ESS()) + test_rng_respected(Gibbs(:x => ESS(), :y => MH())) + test_rng_respected(Gibbs(:x => ESS(), :y => ESS())) + end + @testset "ESS inference" begin @info "Starting ESS inference tests" seed = 23 @@ -108,8 +115,12 @@ using Turing spl_x = Gibbs(@varname(z) => NUTS(), @varname(x) => ESS()) spl_xy = Gibbs(@varname(z) => NUTS(), (@varname(x), @varname(y)) => ESS()) - @test sample(StableRNG(23), xy(), spl_xy, num_samples).value ≈ - sample(StableRNG(23), x12(), spl_x, num_samples).value + chn1 = sample(StableRNG(23), xy(), spl_xy, num_samples) + chn2 = sample(StableRNG(23), x12(), spl_x, num_samples) + + @test mean(chn1[:z]) ≈ mean(chn2[:z]) atol = 0.05 + @test mean(chn1[:x]) ≈ mean(chn2["x[1]"]) atol = 0.05 + @test mean(chn1[:y]) ≈ mean(chn2["x[2]"]) atol = 0.05 end end diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index 38b9b0660..d1e72a94e 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -45,6 +45,8 @@ using Turing.Inference: AdvancedHMC rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, sampler::MySampler; + # This initial_params should be an AbstractVector because the model is just a + # LogDensityModel, not a DynamicPPL.Model initial_params::AbstractVector, kwargs..., ) @@ -82,7 +84,10 @@ using Turing.Inference: AdvancedHMC model = test_external_sampler() a, b = 0.5, 0.0 - chn = sample(model, externalsampler(MySampler()), 10; initial_params=[a, b]) + # This `initial_params` should be an InitStrategy + chn = sample( + model, externalsampler(MySampler()), 10; initial_params=InitFromParams((a=a, b=b)) + ) @test chn isa MCMCChains.Chains @test all(chn[:a] .== a) @test all(chn[:b] .== b) @@ -156,10 +161,7 @@ function Distributions._rand!( ) model = d.model varinfo = deepcopy(d.varinfo) - for vn in keys(varinfo) - DynamicPPL.set_flag!(varinfo, vn, "del") - end - DynamicPPL.evaluate!!(model, varinfo, DynamicPPL.SamplingContext(rng)) + _, varinfo = DynamicPPL.init!!(rng, model, varinfo, DynamicPPL.InitFromPrior()) x .= varinfo[:] return x end @@ -170,16 +172,24 @@ function initialize_mh_with_prior_proposal(model) ) end -function test_initial_params( - model, sampler, initial_params=DynamicPPL.VarInfo(model)[:]; kwargs... -) +function test_initial_params(model, sampler; kwargs...) + # Generate some parameters. + dict = DynamicPPL.values_as(DynamicPPL.VarInfo(model), Dict) + init_strategy = DynamicPPL.InitFromParams(dict) + # Execute the transition with two different RNGs and check that the resulting - # parameter values are the same. + # parameter values are the same. This ensures that the `initial_params` are + # respected (i.e., regardless of the RNG, the first step should always return + # the same parameters). rng1 = Random.MersenneTwister(42) rng2 = Random.MersenneTwister(43) - transition1, _ = AbstractMCMC.step(rng1, model, sampler; initial_params, kwargs...) - transition2, _ = AbstractMCMC.step(rng2, model, sampler; initial_params, kwargs...) + transition1, _ = AbstractMCMC.step( + rng1, model, sampler; initial_params=init_strategy, kwargs... + ) + transition2, _ = AbstractMCMC.step( + rng2, model, sampler; initial_params=init_strategy, kwargs... + ) vn_to_val1 = DynamicPPL.OrderedDict(transition1.θ) vn_to_val2 = DynamicPPL.OrderedDict(transition2.θ) for vn in union(keys(vn_to_val1), keys(vn_to_val2)) @@ -198,16 +208,18 @@ end sampler_ext = DynamicPPL.Sampler( externalsampler(sampler; adtype, unconstrained=true) ) - # FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment. + + # TODO: AdvancedHMC samplers do not return the initial parameters as the first + # step, so `test_initial_params` will fail. This should be fixed upstream in + # AdvancedHMC.jl. For reasons that are beyond my current understanding, this was + # done in https://github.com/TuringLang/AdvancedHMC.jl/pull/366, but the PR + # was then reverted and never looked at again. # @testset "initial_params" begin # test_initial_params(model, sampler_ext; n_adapts=0) # end sample_kwargs = ( - n_adapts=1_000, - discard_initial=1_000, - # FIXME: Remove this once we can run `test_initial_params` above. - initial_params=DynamicPPL.VarInfo(model)[:], + n_adapts=1_000, discard_initial=1_000, initial_params=InitFromUniform() ) @testset "inference" begin diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 634fcc98d..2c5774773 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -693,13 +693,9 @@ end num_chains = 4 # Determine initial parameters to make comparison as fair as possible. + # posterior_mean returns a NamedTuple so we can plug it in directly. posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) + initial_params = fill(InitFromParams(posterior_mean), num_chains) # Sampler to use for Gibbs components. hmc = HMC(0.1, 32) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 3328838a9..e6341d4b6 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -177,7 +177,11 @@ using Turing @testset "$spl_name" for (spl_name, spl) in (("HMC", HMC(0.1, 10)), ("NUTS", NUTS())) chain = sample( - demo_norm(), spl, 5; discard_adapt=false, initial_params=(x=init_x,) + demo_norm(), + spl, + 5; + discard_adapt=false, + initial_params=InitFromParams((x=init_x,)), ) @test chain[:x][1] == init_x chain = sample( @@ -187,7 +191,7 @@ using Turing 5, 5; discard_adapt=false, - initial_params=(fill((x=init_x,), 5)), + initial_params=(fill(InitFromParams((x=init_x,)), 5)), ) @test all(chain[:x][1, :] .== init_x) end @@ -202,12 +206,11 @@ using Turing end end - @test_logs ( - :warn, - "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", - ) (:info,) match_mode = :any begin - sample(demo_warn_initial_params(), NUTS(), 5) - end + # verbose=false to suppress the initial step size notification, which messes with + # the test + @test_logs (:warn, r"consider providing a different initialisation strategy") sample( + demo_warn_initial_params(), NUTS(), 5; verbose=false + ) end @testset "error for impossible model" begin @@ -253,7 +256,8 @@ using Turing model = buggy_model() num_samples = 1_000 - chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0]) + initial_params = InitFromParams((lb=0.5, ub=1.75, x=1.0)) + chain = sample(model, NUTS(), num_samples; initial_params=initial_params) chain_prior = sample(model, Prior(), num_samples) # Extract the `x` like this because running `generated_quantities` was how @@ -275,7 +279,11 @@ using Turing # Construct a HMC state by taking a single step spl = Sampler(alg) hmc_state = DynamicPPL.initialstep( - Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default) + Random.default_rng(), + gdemo_default, + spl, + DynamicPPL.VarInfo(gdemo_default); + initial_params=InitFromUniform(), )[2] # Check that we can obtain the current step size @test Turing.Inference.getstepsize(spl, hmc_state) isa Float64 diff --git a/test/mcmc/is.jl b/test/mcmc/is.jl index 2811e9c86..3d557c022 100644 --- a/test/mcmc/is.jl +++ b/test/mcmc/is.jl @@ -1,63 +1,52 @@ module ISTests -using Distributions: Normal, sample using DynamicPPL: logpdf using Random: Random +using StableRNGs: StableRNG using StatsFuns: logsumexp using Test: @test, @testset using Turing @testset "is.jl" begin - function reference(n) - as = Vector{Float64}(undef, n) - bs = Vector{Float64}(undef, n) - logps = Vector{Float64}(undef, n) + @testset "numerical accuracy" begin + function reference(n) + rng = StableRNG(468) + as = Vector{Float64}(undef, n) + bs = Vector{Float64}(undef, n) - for i in 1:n - as[i], bs[i], logps[i] = reference() + for i in 1:n + as[i] = rand(rng, Normal(4, 5)) + bs[i] = rand(rng, Normal(as[i], 1)) + end + # logevidence = logsumexp(logps) - log(n) + return (as=as, bs=bs) end - logevidence = logsumexp(logps) - log(n) - return (as=as, bs=bs, logps=logps, logevidence=logevidence) - end - - function reference() - x = rand(Normal(4, 5)) - y = rand(Normal(x, 1)) - loglik = logpdf(Normal(x, 2), 3) + logpdf(Normal(y, 2), 1.5) - return x, y, loglik - end - - @model function normal() - a ~ Normal(4, 5) - 3 ~ Normal(a, 2) - b ~ Normal(a, 1) - 1.5 ~ Normal(b, 2) - return a, b - end - - alg = IS() - seed = 0 - n = 10 + @model function normal() + a ~ Normal(4, 5) + 3 ~ Normal(a, 2) + b ~ Normal(a, 1) + 1.5 ~ Normal(b, 2) + return a, b + end - model = normal() - for i in 1:100 - Random.seed!(seed) - ref = reference(n) + function expected_loglikelihoods(as, bs) + return logpdf.(Normal.(as, 2), 3) .+ logpdf.(Normal.(bs, 2), 1.5) + end - Random.seed!(seed) - chain = sample(model, alg, n; check_model=false) - sampled = get(chain, [:a, :b, :loglikelihood]) + alg = IS() + N = 1000 + model = normal() + chain = sample(StableRNG(468), model, alg, N) + ref = reference(N) - @test vec(sampled.a) == ref.as - @test vec(sampled.b) == ref.bs - @test vec(sampled.loglikelihood) == ref.logps - @test chain.logevidence == ref.logevidence + @test isapprox(mean(chain[:a]), mean(ref.as); atol=0.1) + @test isapprox(mean(chain[:b]), mean(ref.bs); atol=0.1) + @test isapprox(chain[:loglikelihood], expected_loglikelihoods(chain[:a], chain[:b])) + @test isapprox(chain.logevidence, logsumexp(chain[:loglikelihood]) - log(N)) end @testset "logevidence" begin - Random.seed!(100) - @model function test() a ~ Normal(0, 1) x ~ Bernoulli(1) diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 70810e164..e0e5d51a6 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -49,7 +49,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Set the initial parameters, because if we get unlucky with the initial state, # these chains are too short to converge to reasonable numbers. discard_initial = 1_000 - initial_params = [1.0, 1.0] + initial_params = InitFromParams((s=1.0, m=1.0)) @testset "gdemo_default" begin alg = MH() @@ -72,7 +72,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) chain = sample( StableRNG(seed), gdemo_default, alg, 10_000; discard_initial, initial_params ) - check_gdemo(chain; atol=0.1) + check_gdemo(chain; atol=0.15) end @testset "MoGtest_default with Gibbs" begin @@ -81,13 +81,16 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @varname(mu1) => MH((:mu1, GKernel(1))), @varname(mu2) => MH((:mu2, GKernel(1))), ) + initial_params = InitFromParams(( + mu1=1.0, mu2=1.0, z1=0.0, z2=0.0, z3=1.0, z4=1.0 + )) chain = sample( StableRNG(seed), MoGtest_default, gibbs, 500; discard_initial=100, - initial_params=[1.0, 1.0, 0.0, 0.0, 1.0, 4.0], + initial_params=initial_params, ) check_MoGtest_default(chain; atol=0.2) end @@ -184,7 +187,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Test that the small variance version is actually smaller. variance_small = var(diff(Array(chn_small["μ[1]"]); dims=1)) variance_big = var(diff(Array(chn_big["μ[1]"]); dims=1)) - @test variance_small < variance_big / 1_000.0 + @test variance_small < variance_big / 100.0 end @testset "vector of multivariate distributions" begin diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index d848627d7..38b22219c 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -2,8 +2,8 @@ module RepeatSamplerTests using ..Models: gdemo_default using DynamicPPL: Sampler -using MCMCChains: Chains -using StableRNGs: StableRNG +using MCMCChains: MCMCChains +using Random: Xoshiro using Test: @test, @testset using Turing @@ -14,10 +14,12 @@ using Turing num_samples = 10 num_chains = 2 - rng = StableRNG(0) + # Use Xoshiro instead of StableRNGs as the output should always be + # similar regardless of what kind of random seed is used (as long + # as there is a random seed). for sampler in [MH(), Sampler(HMC(0.01, 4))] chn1 = sample( - copy(rng), + Xoshiro(0), gdemo_default, sampler, MCMCThreads(), @@ -27,15 +29,16 @@ using Turing ) repeat_sampler = RepeatSampler(sampler, num_repeats) chn2 = sample( - copy(rng), + Xoshiro(0), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, - num_chains; - chain_type=Chains, + num_chains, ) # isequal to avoid comparing `missing`s in chain stats + @test chn1 isa MCMCChains.Chains + @test chn2 isa MCMCChains.Chains @test isequal(chn1.value, chn2.value) end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index ee943270c..66ad03212 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -56,7 +56,7 @@ end rng = StableRNG(1) chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000) - check_gdemo(chain; atol=0.2) + check_gdemo(chain; atol=0.25) # Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh) v = get(chain, [:SGLD_stepsize, :s, :m]) diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 269a71acb..d93895e28 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -1,6 +1,7 @@ module OptimisationTests using ..Models: gdemo, gdemo_default +using AbstractPPL: AbstractPPL using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL @@ -495,7 +496,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol = 0.05 end end @@ -534,7 +535,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) if model.f in allowed_incorrect_mle @test isfinite(get(result_true, vn_leaf)) else diff --git a/test/stdlib/distributions.jl b/test/stdlib/distributions.jl index e6ce5794d..56c2e59b1 100644 --- a/test/stdlib/distributions.jl +++ b/test/stdlib/distributions.jl @@ -130,7 +130,14 @@ using Turing @model m() = x ~ dist - chn = sample(StableRNG(468), m(), HMC(0.05, 20), n_samples) + seed = if dist isa GeneralizedExtremeValue + # GEV is prone to giving really wacky results that are quite + # seed-dependent. + StableRNG(469) + else + StableRNG(468) + end + chn = sample(seed, m(), HMC(0.05, 20), n_samples) # Numerical tests. check_dist_numerical( diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl index 32a3647f9..c7371bc00 100644 --- a/test/test_utils/sampler.jl +++ b/test/test_utils/sampler.jl @@ -1,5 +1,6 @@ module SamplerTestUtils +using Random using Turing using Test @@ -24,4 +25,21 @@ function test_chain_logp_metadata(spl) @test chn[:lp] ≈ chn[:logprior] + chn[:loglikelihood] end +""" +Check that sampling is deterministic when using the same RNG seed. +""" +function test_rng_respected(spl) + @model function f(z) + # put at least two variables here so that we can meaningfully test Gibbs + x ~ Normal() + y ~ Normal() + return z ~ Normal(x + y) + end + model = f(2.0) + chn1 = sample(Xoshiro(468), model, spl, 100) + chn2 = sample(Xoshiro(468), model, spl, 100) + @test isapprox(chn1[:x], chn2[:x]) + @test isapprox(chn1[:y], chn2[:y]) +end + end