Skip to content

Commit ba4da83

Browse files
committed
Fix more tests
1 parent 20f9e97 commit ba4da83

File tree

6 files changed

+74
-9
lines changed

6 files changed

+74
-9
lines changed

src/mcmc/algorithm.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ this wrapping occurs automatically.
1212
abstract type InferenceAlgorithm end
1313

1414
DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains
15+
16+
function DynamicPPL.init_strategy(sampler::Sampler{<:InferenceAlgorithm})
17+
return DynamicPPL.InitFromPrior()
18+
end

src/mcmc/emcee.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ _get_n_walkers(e::Emcee) = e.ensemble.n_walkers
3636
_get_n_walkers(spl::Sampler{<:Emcee}) = _get_n_walkers(spl.alg)
3737

3838
# Because Emcee expects n_walkers initialisations, we need to override this
39-
DynamicPPL.init_strategy(spl::Sampler{<:Emcee}) = fill(InitFromPrior(), _get_n_walkers(spl))
39+
function DynamicPPL.init_strategy(spl::Sampler{<:Emcee})
40+
return fill(DynamicPPL.InitFromPrior(), _get_n_walkers(spl))
41+
end
4042

4143
function AbstractMCMC.step(
4244
rng::Random.AbstractRNG,

src/mcmc/hmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ function AbstractMCMC.sample(
9090
N::Integer;
9191
chain_type=DynamicPPL.default_chain_type(sampler),
9292
resume_from=nothing,
93+
initial_params=DynamicPPL.init_strategy(sampler),
9394
initial_state=DynamicPPL.loadstate(resume_from),
9495
progress=PROGRESS[],
9596
nadapts=sampler.alg.n_adapts,

src/mcmc/is.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function AbstractMCMC.step(
3737
)
3838
model = DynamicPPL.setleafcontext(model, ISContext(rng))
3939
_, vi = DynamicPPL.evaluate!!(model, DynamicPPL.VarInfo())
40-
vi = DynamicPPL.typed_varinfo(vi, model)
40+
vi = DynamicPPL.typed_varinfo(vi)
4141
return Transition(model, vi, nothing), nothing
4242
end
4343

src/mcmc/repeat_sampler.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,58 @@ function AbstractMCMC.step_warmup(
8181
end
8282
return transition, state
8383
end
84+
85+
# Need some extra leg work to make RepeatSampler work seamlessly with DynamicPPL models +
86+
# samplers, instead of generic AbstractMCMC samplers.
87+
88+
function DynamicPPL.init_strategy(spl::RepeatSampler{<:Sampler})
89+
return DynamicPPL.init_strategy(spl.sampler)
90+
end
91+
92+
function AbstractMCMC.sample(
93+
rng::AbstractRNG,
94+
model::DynamicPPL.Model,
95+
sampler::RepeatSampler{<:Sampler},
96+
N::Integer;
97+
initial_params=DynamicPPL.init_strategy(sampler),
98+
chain_type=MCMCChains.Chains,
99+
progress=PROGRESS[],
100+
kwargs...,
101+
)
102+
return AbstractMCMC.mcmcsample(
103+
rng,
104+
model,
105+
sampler,
106+
N;
107+
initial_params=initial_params,
108+
chain_type=chain_type,
109+
progress=progress,
110+
kwargs...,
111+
)
112+
end
113+
114+
function AbstractMCMC.sample(
115+
rng::AbstractRNG,
116+
model::DynamicPPL.Model,
117+
sampler::RepeatSampler{<:Sampler},
118+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
119+
N::Integer,
120+
n_chains::Integer;
121+
initial_params=fill(DynamicPPL.init_strategy(sampler), n_chains),
122+
chain_type=MCMCChains.Chains,
123+
progress=PROGRESS[],
124+
kwargs...,
125+
)
126+
return AbstractMCMC.mcmcsample(
127+
rng,
128+
model,
129+
sampler,
130+
ensemble,
131+
N,
132+
n_chains;
133+
initial_params=initial_params,
134+
chain_type=chain_type,
135+
progress=progress,
136+
kwargs...,
137+
)
138+
end

test/mcmc/repeat_sampler.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ module RepeatSamplerTests
22

33
using ..Models: gdemo_default
44
using DynamicPPL: Sampler
5-
using MCMCChains: Chains
6-
using StableRNGs: StableRNG
5+
using MCMCChains: MCMCChains
6+
using Random: Xoshiro
77
using Test: @test, @testset
88
using Turing
99

@@ -14,10 +14,12 @@ using Turing
1414
num_samples = 10
1515
num_chains = 2
1616

17-
rng = StableRNG(0)
17+
# Use Xoshiro instead of StableRNGs as the output should always be
18+
# similar regardless of what kind of random seed is used (as long
19+
# as there is a random seed).
1820
for sampler in [MH(), Sampler(HMC(0.01, 4))]
1921
chn1 = sample(
20-
copy(rng),
22+
Xoshiro(0),
2123
gdemo_default,
2224
sampler,
2325
MCMCThreads(),
@@ -27,15 +29,16 @@ using Turing
2729
)
2830
repeat_sampler = RepeatSampler(sampler, num_repeats)
2931
chn2 = sample(
30-
copy(rng),
32+
Xoshiro(0),
3133
gdemo_default,
3234
repeat_sampler,
3335
MCMCThreads(),
3436
num_samples,
35-
num_chains;
36-
chain_type=Chains,
37+
num_chains,
3738
)
3839
# isequal to avoid comparing `missing`s in chain stats
40+
@test chn1 isa MCMCChains.Chains
41+
@test chn2 isa MCMCChains.Chains
3942
@test isequal(chn1.value, chn2.value)
4043
end
4144
end

0 commit comments

Comments
 (0)