Skip to content

Commit 20de15a

Browse files
committed
Fix initial_params when sampling Turing models
1 parent b891a52 commit 20de15a

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/sampler.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,13 @@ function AbstractMCMC.sample(
5959
sampler::Sampler,
6060
N::Integer;
6161
chain_type=default_chain_type(sampler),
62+
initial_params=init_strategy(sampler),
6263
resume_from=nothing,
6364
initial_state=loadstate(resume_from),
6465
kwargs...,
6566
)
6667
return AbstractMCMC.mcmcsample(
67-
rng, model, sampler, N; chain_type, initial_state, kwargs...
68+
rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs...
6869
)
6970
end
7071

@@ -75,21 +76,31 @@ function AbstractMCMC.sample(
7576
parallel::AbstractMCMC.AbstractMCMCEnsemble,
7677
N::Integer,
7778
nchains::Integer;
79+
initial_params=fill(init_strategy(sampler), nchains),
7880
chain_type=default_chain_type(sampler),
7981
resume_from=nothing,
8082
initial_state=loadstate(resume_from),
8183
kwargs...,
8284
)
8385
return AbstractMCMC.mcmcsample(
84-
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
86+
rng,
87+
model,
88+
sampler,
89+
parallel,
90+
N,
91+
nchains;
92+
chain_type,
93+
initial_params,
94+
initial_state,
95+
kwargs...,
8596
)
8697
end
8798

8899
function AbstractMCMC.step(
89100
rng::Random.AbstractRNG,
90101
model::Model,
91102
spl::Sampler;
92-
initial_params::AbstractInitStrategy=init_strategy(spl),
103+
initial_params::AbstractInitStrategy,
93104
kwargs...,
94105
)
95106
# Generate the default varinfo. Note that any parameters inside this varinfo

0 commit comments

Comments
 (0)