Skip to content

Commit c3ba043

Browse files
committed
Enable NamedTuple/Dict initialisation
1 parent 908d402 commit c3ba043

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

src/contexts/init.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,11 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
9494
params::P
9595
fallback::S
9696
function InitFromParams(
97-
params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing}
97+
params::AbstractDict{<:VarName},
98+
fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(),
9899
)
99100
return new{typeof(params),typeof(fallback)}(params, fallback)
100101
end
101-
function InitFromParams(params::AbstractDict{<:VarName})
102-
return InitFromParams(params, InitFromPrior())
103-
end
104102
function InitFromParams(
105103
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
106104
)

src/sampler.jl

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,26 @@ sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
5353
"""
5454
init_strategy(::AbstractSampler) = InitFromPrior()
5555

56+
"""
57+
_convert_initial_params(initial_params)
58+
59+
Convert `initial_params` to an `AbstractInitStrategy` if it is not already one.
60+
"""
61+
_convert_initial_params(initial_params::AbstractInitStrategy) = initial_params
62+
function _convert_initial_params(nt::NamedTuple)
63+
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
64+
return InitFromParams(nt)
65+
end
66+
function _convert_initial_params(d::AbstractDict{<:VarName})
67+
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
68+
return InitFromParams(d)
69+
end
70+
function _convert_initial_params(::AbstractVector)
71+
return error(
72+
"`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally an `AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code.",
73+
)
74+
end
75+
5676
function AbstractMCMC.sample(
5777
rng::Random.AbstractRNG,
5878
model::Model,
@@ -63,7 +83,13 @@ function AbstractMCMC.sample(
6383
kwargs...,
6484
)
6585
return AbstractMCMC.mcmcsample(
66-
rng, model, sampler, N; initial_params, initial_state, kwargs...
86+
rng,
87+
model,
88+
sampler,
89+
N;
90+
initial_params=_convert_initial_params(initial_params),
91+
initial_state,
92+
kwargs...,
6793
)
6894
end
6995

@@ -79,7 +105,15 @@ function AbstractMCMC.sample(
79105
kwargs...,
80106
)
81107
return AbstractMCMC.mcmcsample(
82-
rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs...
108+
rng,
109+
model,
110+
sampler,
111+
parallel,
112+
N,
113+
nchains;
114+
initial_params=map(_convert_initial_params, initial_params),
115+
initial_state,
116+
kwargs...,
83117
)
84118
end
85119

test/sampler.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,12 @@
145145
end
146146
model = twovars()
147147
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
148-
let inits = InitFromParams((; s=4, m=-1))
148+
for inits in (
149+
InitFromParams((s=4, m=-1)),
150+
(s=4, m=-1),
151+
InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)),
152+
Dict(@varname(s) => 4, @varname(m) => -1),
153+
)
149154
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
150155
@test chain[1].metadata.s.vals == [4]
151156
@test chain[1].metadata.m.vals == [-1]
@@ -169,7 +174,15 @@
169174
end
170175

171176
# set only m = -1
172-
for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1)))
177+
for inits in (
178+
InitFromParams((; s=missing, m=-1)),
179+
InitFromParams(Dict(@varname(s) => missing, @varname(m) => -1)),
180+
(; s=missing, m=-1),
181+
Dict(@varname(s) => missing, @varname(m) => -1),
182+
InitFromParams((; m=-1)),
183+
InitFromParams(Dict(@varname(m) => -1)),
184+
(; m=-1)Dict(@varname(m) => -1),
185+
)
173186
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
174187
@test !ismissing(chain[1].metadata.s.vals[1])
175188
@test chain[1].metadata.m.vals == [-1]

0 commit comments

Comments
 (0)