Skip to content
Open
10 changes: 5 additions & 5 deletions ext/TuringDynamicHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
stepsize::S
end

function DynamicPPL.initialstep(
function Turing.Inference.initialstep(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
spl::DynamicNUTS,
vi::DynamicPPL.AbstractVarInfo;
kwargs...,
)
Expand All @@ -59,7 +59,7 @@ function DynamicPPL.initialstep(

# Define log-density function.
ℓ = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
)

# Perform initial step.
Expand All @@ -80,14 +80,14 @@ end
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
spl::DynamicNUTS,
state::DynamicNUTSState;
kwargs...,
)
# Compute next sample.
vi = state.vi
ℓ = state.logdensity
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize)
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Create next sample and state.
Expand Down
4 changes: 3 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ export
maximum_a_posteriori,
maximum_likelihood,
MAP,
MLE
MLE,
# Chain save/resume
loadstate

end
37 changes: 15 additions & 22 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ using DynamicPPL:
getsym,
getdist,
Model,
Sampler,
DefaultContext
using Distributions, Libtask, Bijectors
using DistributionsAD: VectorOfMultivariate
Expand Down Expand Up @@ -50,8 +49,7 @@ import Random
import MCMCChains
import StatsBase: predict

export InferenceAlgorithm,
Hamiltonian,
export Hamiltonian,
StaticHamiltonian,
AdaptiveHamiltonian,
MH,
Expand All @@ -71,15 +69,16 @@ export InferenceAlgorithm,
RepeatSampler,
Prior,
predict,
externalsampler
externalsampler,
init_strategy,
loadstate

###############################################
# Abstract interface for inference algorithms #
###############################################

const TURING_CHAIN_TYPE = MCMCChains.Chains
#########################################
# Generic AbstractMCMC methods dispatch #
#########################################

include("algorithm.jl")
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
include("abstractmcmc.jl")

####################
# Sampler wrappers #
Expand Down Expand Up @@ -312,8 +311,8 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
model::DynamicPPL.Model,
spl::AbstractSampler,
state,
chain_type::Type{MCMCChains.Chains};
save_state=false,
Expand Down Expand Up @@ -374,8 +373,8 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
model::DynamicPPL.Model,
spl::AbstractSampler,
state,
chain_type::Type{Vector{NamedTuple}};
kwargs...,
Expand Down Expand Up @@ -416,7 +415,7 @@ function group_varnames_by_symbol(vns)
return d
end

function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
return setinfo(c, merge(nt, c.info))
end
Expand All @@ -435,18 +434,12 @@ include("sghmc.jl")
include("emcee.jl")
include("prior.jl")

#################################################
# Generic AbstractMCMC methods dispatch #
#################################################

include("abstractmcmc.jl")

################
# Typing tools #
################

function DynamicPPL.get_matching_type(
spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV}
spl::Union{PG,SMC}, vi, ::Type{TV}
) where {T,N,TV<:Array{T,N}}
return Array{T,N}
end
Expand Down
142 changes: 125 additions & 17 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,98 @@ function _check_model(model::DynamicPPL.Model)
new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
end
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
function _check_model(model::DynamicPPL.Model, ::AbstractSampler)
return _check_model(model)
end

"""
Turing.Inference.init_strategy(spl::AbstractSampler)

Get the default initialization strategy for a given sampler `spl`, i.e. how initial
parameters for sampling are chosen if not specified by the user. By default, this is
`InitFromPrior()`, which samples initial parameters from the prior distribution.
"""
init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()

"""
_convert_initial_params(initial_params)

Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or
throw a useful error message.
"""
_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
function _convert_initial_params(nt::NamedTuple)
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
return DynamicPPL.InitFromParams(nt)
end
function _convert_initial_params(d::AbstractDict{<:VarName})
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
return DynamicPPL.InitFromParams(d)
end
function _convert_initial_params(::AbstractVector{<:Real})
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
throw(ArgumentError(errmsg))
end
function _convert_initial_params(@nospecialize(_::Any))
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`."
throw(ArgumentError(errmsg))
end

"""
default_varinfo(rng, model, sampler)

Return a default varinfo object for the given `model` and `sampler`.
The default method for this returns a NTVarInfo (i.e. 'typed varinfo').
"""
function default_varinfo(
rng::Random.AbstractRNG, model::DynamicPPL.Model, ::AbstractSampler
)
# Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
# immediately overwritten by a subsequent call to `init!!`. The reason why we
# _do_ create a varinfo with parameters here (as opposed to simply returning
# an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
# typed VarInfo would fail. This can happen if two VarNames have different types
# but share the same symbol (e.g. `x.a` and `x.b`).
# TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments
# and return an empty VarInfo instead.
return DynamicPPL.typed_varinfo(VarInfo(rng, model))
end

#########################################
# Default definitions for the interface #
#########################################

const DEFAULT_CHAIN_TYPE = MCMCChains.Chains

function AbstractMCMC.sample(
model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs...
model::DynamicPPL.Model, spl::AbstractSampler, N::Integer; kwargs...
)
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::InferenceAlgorithm,
model::DynamicPPL.Model,
spl::AbstractSampler,
N::Integer;
initial_params=init_strategy(spl),
check_model::Bool=true,
chain_type=DEFAULT_CHAIN_TYPE,
kwargs...,
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg), N; chain_type, kwargs...)
check_model && _check_model(model, spl)
return AbstractMCMC.mcmcsample(
rng,
model,
spl,
N;
initial_params=_convert_initial_params(initial_params),
chain_type,
kwargs...,
)
end

function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
model::DynamicPPL.Model,
alg::AbstractSampler,
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
Expand All @@ -47,18 +107,66 @@ function AbstractMCMC.sample(
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::InferenceAlgorithm,
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::AbstractSampler,
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
chain_type=DEFAULT_CHAIN_TYPE,
check_model::Bool=true,
initial_params=fill(init_strategy(spl), n_chains),
kwargs...,
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(
rng, model, Sampler(alg), ensemble, N, n_chains; chain_type, kwargs...
check_model && _check_model(model, spl)
if !(initial_params isa AbstractVector) || length(initial_params) != n_chains
errmsg = "`initial_params` must be an AbstractVector of length `n_chains`; one element per chain"
throw(ArgumentError(errmsg))
end
return AbstractMCMC.mcmcsample(
rng,
model,
spl,
ensemble,
N,
n_chains;
chain_type,
initial_params=map(_convert_initial_params, initial_params),
kwargs...,
)
end

"""
    loadstate(chain::MCMCChains.Chains)

Return the final sampler state stored in `chain.info[:samplerstate]`. Such a
state is only present if the chain was sampled with `save_state=true`; throw an
`ArgumentError` otherwise.
"""
function loadstate(chain::MCMCChains.Chains)
    haskey(chain.info, :samplerstate) || throw(
        ArgumentError(
            "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
        ),
    )
    return chain.info[:samplerstate]
end

# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures
"""
    initialstep(rng, model, sampler, varinfo; kwargs...)

Perform the first MCMC step for `sampler`, starting from an initialised
`varinfo`. Called by the zero-state `AbstractMCMC.step` method below after the
varinfo has been filled via `init!!`; samplers provide methods for this
function rather than overloading that `step` method directly.
"""
function initialstep end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::AbstractSampler;
    initial_params,
    kwargs...,
)
    # Build the starting varinfo; whatever parameter values it holds are
    # discarded by the `init!!` call just below.
    varinfo = default_varinfo(rng, model, spl)

    # Populate it according to the requested initialisation strategy. Note that,
    # when `InitFromParams` is used, the supplied parameters must be in unlinked
    # space; on insertion into the varinfo they are adjusted to match its
    # linking status.
    _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)

    # Hand over to the sampler-specific implementation of the first step.
    return initialstep(rng, model, spl, varinfo; initial_params, kwargs...)
end
Comment on lines +150 to +172
Copy link
Member Author

@penelopeysm penelopeysm Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method of step is actually a little bit evil. It used to be less bad because it only applied to Sampler{<:InferenceAlgorithm}, but now it applies to all AbstractSampler, which actually does cause some method ambiguities (which I've pointed out in my other comments).

On top of that, this is just generally a bit inflexible when it comes to warmup steps since it's only defined as a method for step and not step_warmup.

I think that in the next version of Turing this method should be removed. However, I've opted to preserve it for now because I don't want to make too many conceptual changes in this PR (the diff is already too large).

16 changes: 0 additions & 16 deletions src/mcmc/algorithm.jl

This file was deleted.

Loading
Loading