Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,11 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
params::P
fallback::S
function InitFromParams(
params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing}
params::AbstractDict{<:VarName},
fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(),
)
return new{typeof(params),typeof(fallback)}(params, fallback)
end
function InitFromParams(params::AbstractDict{<:VarName})
return InitFromParams(params, InitFromPrior())
end
function InitFromParams(
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
)
Expand Down
37 changes: 35 additions & 2 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
"""
init_strategy(::AbstractSampler) = InitFromPrior()

"""
_convert_initial_params(initial_params)

Convert `initial_params` to an `AbstractInitStrategy` if it is not already one.
"""
_convert_initial_params(initial_params::AbstractInitStrategy) = initial_params
function _convert_initial_params(nt::NamedTuple)
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
return InitFromParams(nt)
end
function _convert_initial_params(d::AbstractDict{<:VarName})
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
return InitFromParams(d)
end
function _convert_initial_params(::AbstractVector)
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally an `AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
throw(ArgumentError(errmsg))
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
Expand All @@ -63,7 +82,13 @@ function AbstractMCMC.sample(
kwargs...,
)
return AbstractMCMC.mcmcsample(
rng, model, sampler, N; initial_params, initial_state, kwargs...
rng,
model,
sampler,
N;
initial_params=_convert_initial_params(initial_params),
initial_state,
kwargs...,
)
end

Expand All @@ -79,7 +104,15 @@ function AbstractMCMC.sample(
kwargs...,
)
return AbstractMCMC.mcmcsample(
rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs...
rng,
model,
sampler,
parallel,
N,
nchains;
initial_params=map(_convert_initial_params, initial_params),
initial_state,
kwargs...,
)
end

Expand Down
26 changes: 24 additions & 2 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,27 @@
end
end

# check that Vector no longer works
@test_throws ArgumentError sample(
model, sampler, 1; initial_params=[4, -1], progress=false
)
@test_throws ArgumentError sample(
model, sampler, 1; initial_params=[missing, -1], progress=false
)

# model with two variables: initialization s = 4, m = -1
@model function twovars()
s ~ InverseGamma(2, 3)
return m ~ Normal(0, sqrt(s))
end
model = twovars()
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
let inits = InitFromParams((; s=4, m=-1))
for inits in (
InitFromParams((s=4, m=-1)),
(s=4, m=-1),
InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)),
Dict(@varname(s) => 4, @varname(m) => -1),
)
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
Expand All @@ -169,7 +182,16 @@
end

# set only m = -1
for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1)))
for inits in (
InitFromParams((; s=missing, m=-1)),
InitFromParams(Dict(@varname(s) => missing, @varname(m) => -1)),
(; s=missing, m=-1),
Dict(@varname(s) => missing, @varname(m) => -1),
InitFromParams((; m=-1)),
InitFromParams(Dict(@varname(m) => -1)),
(; m=-1),
Dict(@varname(m) => -1),
)
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]
Expand Down
Loading