diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 4baca1b57..aee561896 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -94,13 +94,11 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS params::P fallback::S function InitFromParams( - params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} + params::AbstractDict{<:VarName}, + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), ) return new{typeof(params),typeof(fallback)}(params, fallback) end - function InitFromParams(params::AbstractDict{<:VarName}) - return InitFromParams(params, InitFromPrior()) - end function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) diff --git a/src/sampler.jl b/src/sampler.jl index ed1b86321..4cc56d8ff 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -53,6 +53,25 @@ sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden. """ init_strategy(::AbstractSampler) = InitFromPrior() +""" + _convert_initial_params(initial_params) + +Convert `initial_params` to an `AbstractInitStrategy` if it is not already one. +""" +_convert_initial_params(initial_params::AbstractInitStrategy) = initial_params +function _convert_initial_params(nt::NamedTuple) + @info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead." + return InitFromParams(nt) +end +function _convert_initial_params(d::AbstractDict{<:VarName}) + @info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead." + return InitFromParams(d) +end +function _convert_initial_params(::AbstractVector) + errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally an `AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code." + throw(ArgumentError(errmsg)) +end + function AbstractMCMC.sample( rng::Random.AbstractRNG, model::Model, @@ -63,7 +82,13 @@ function AbstractMCMC.sample( kwargs..., ) return AbstractMCMC.mcmcsample( - rng, model, sampler, N; initial_params, initial_state, kwargs... + rng, + model, + sampler, + N; + initial_params=_convert_initial_params(initial_params), + initial_state, + kwargs..., ) end @@ -79,7 +104,15 @@ function AbstractMCMC.sample( kwargs..., ) return AbstractMCMC.mcmcsample( - rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs... + rng, + model, + sampler, + parallel, + N, + nchains; + initial_params=map(_convert_initial_params, initial_params), + initial_state, + kwargs..., ) end diff --git a/test/sampler.jl b/test/sampler.jl index 8be54901d..3fe7f2b07 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -138,6 +138,14 @@ end end + # check that Vector no longer works + @test_throws ArgumentError sample( + model, sampler, 1; initial_params=[4, -1], progress=false + ) + @test_throws ArgumentError sample( + model, sampler, 1; initial_params=[missing, -1], progress=false + ) + # model with two variables: initialization s = 4, m = -1 @model function twovars() s ~ InverseGamma(2, 3) @@ -145,7 +153,12 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - let inits = InitFromParams((; s=4, m=-1)) + for inits in ( + InitFromParams((s=4, m=-1)), + (s=4, m=-1), + InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)), + Dict(@varname(s) => 4, @varname(m) => -1), + ) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @@ -169,7 +182,16 @@ end # set only m = -1 - for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1))) + for inits in ( + InitFromParams((; s=missing, m=-1)), + InitFromParams(Dict(@varname(s) => missing, @varname(m) => -1)), + (; s=missing, m=-1), + Dict(@varname(s) => missing, @varname(m) => -1), + InitFromParams((; m=-1)), + InitFromParams(Dict(@varname(m) => -1)), + (; m=-1), + Dict(@varname(m) => -1), + ) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1]