Skip to content

Commit f9142a6

Browse files
committed
Added testing of warmup steps
1 parent 3b4f6db commit f9142a6

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

test/sample.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,45 @@
575575
@test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
576576
end
577577

578+
@testset "Warm-up steps" begin
579+
# Create a chain and discard initial samples.
580+
Random.seed!(1234)
581+
N = 100
582+
num_warmup = 50
583+
584+
# Everything should be discarded here.
585+
chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup)
586+
@test length(chain) == N
587+
@test !ismissing(chain[1].a)
588+
589+
# Repeat sampling without discarding initial samples.
590+
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
591+
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
592+
Random.seed!(1234)
593+
ref_chain = sample(
594+
MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6"
595+
)
596+
@test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N)
597+
@test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N)
598+
599+
# Some other stuff.
600+
Random.seed!(1234)
601+
discard_initial = 10
602+
chain_warmup = sample(
603+
MyModel(),
604+
MySampler(),
605+
N;
606+
num_warmup=num_warmup,
607+
discard_initial=discard_initial,
608+
)
609+
@test length(chain_warmup) == N
610+
@test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N)
611+
# Check that the first `num_warmup - discard_initial` samples are warmup samples.
612+
@test all(
613+
chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N
614+
)
615+
end
616+
578617
@testset "Thin chain by a factor of `thinning`" begin
579618
# Run a thinned chain with `N` samples thinned by factor of `thinning`.
580619
Random.seed!(100)

test/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end
33
struct MySample{A,B}
44
a::A
55
b::B
6+
is_warmup::Bool
67
end
78

9+
MySample(a, b) = MySample(a, b, false)
10+
811
struct MySampler <: AbstractMCMC.AbstractSampler end
912
struct AnotherSampler <: AbstractMCMC.AbstractSampler end
1013

@@ -16,6 +19,21 @@ end
1619

1720
MyChain(a, b) = MyChain(a, b, NamedTuple())
1821

22+
function AbstractMCMC.step_warmup(
23+
rng::AbstractRNG,
24+
model::MyModel,
25+
sampler::MySampler,
26+
state::Union{Nothing,Integer}=nothing;
27+
loggers=false,
28+
initial_params=nothing,
29+
kwargs...,
30+
)
31+
transition, state = AbstractMCMC.step(
32+
rng, model, sampler, state; loggers, initial_params, kwargs...
33+
)
34+
return MySample(transition.a, transition.b, true), state
35+
end
36+
1937
function AbstractMCMC.step(
2038
rng::AbstractRNG,
2139
model::MyModel,

0 commit comments

Comments
 (0)