Skip to content

Commit 7b0d12b

Browse files
committed
Introduce RepeatSampler
1 parent f9ed562 commit 7b0d12b

File tree

8 files changed

+146
-8
lines changed

8 files changed

+146
-8
lines changed

HISTORY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ may be accidental breakage that we haven't anticipated. Please report any you fi
1212

1313
The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.
1414

15-
Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))` has been deprecated. The new way to achieve this effect is to list the same sampler multiple times, e.g. as `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`.
15+
Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.
1616

1717
# Release 0.35.0
1818

src/Turing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ export @model, # modelling
9595
SMC,
9696
CSMC,
9797
PG,
98+
RepeatSampler,
9899
vi, # variational inference
99100
ADVI,
100101
sample, # inference

src/mcmc/Inference.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ export InferenceAlgorithm,
7474
SMC,
7575
CSMC,
7676
PG,
77+
RepeatSampler,
7778
Prior,
7879
assume,
7980
dot_assume,
@@ -100,6 +101,12 @@ Return an `InferenceAlgorithm` like `alg`, but with all space information remove
100101
"""
101102
function drop_space end
102103

104+
function drop_space(sampler::Sampler)
105+
return Sampler(drop_space(sampler.alg), sampler.selector)
106+
end
107+
108+
include("repeat_sampler.jl")
109+
103110
"""
104111
ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
105112
@@ -348,7 +355,7 @@ end
348355
function AbstractMCMC.sample(
349356
rng::AbstractRNG,
350357
model::AbstractModel,
351-
sampler::Sampler{<:InferenceAlgorithm},
358+
sampler::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
352359
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
353360
N::Integer,
354361
n_chains::Integer;
@@ -460,7 +467,7 @@ getlogevidence(transitions, sampler, state) = missing
460467
function AbstractMCMC.bundle_samples(
461468
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
462469
model::AbstractModel,
463-
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
470+
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
464471
state,
465472
chain_type::Type{MCMCChains.Chains};
466473
save_state=false,
@@ -523,7 +530,7 @@ end
523530
function AbstractMCMC.bundle_samples(
524531
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
525532
model::AbstractModel,
526-
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
533+
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
527534
state,
528535
chain_type::Type{Vector{NamedTuple}};
529536
kwargs...,

src/mcmc/gibbs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ isgibbscomponent(::NUTS) = true
1515
isgibbscomponent(::MH) = true
1616
isgibbscomponent(::PG) = true
1717

18+
isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler)
19+
1820
isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
1921
isgibbscomponent(::AdvancedHMC.HMC) = true
2022
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
@@ -364,7 +366,7 @@ function Gibbs(algs::InferenceAlgorithm...)
364366
"`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " *
365367
"Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " *
366368
"counts for different subsamplers, use e.g. " *
367-
"`Gibbs(@varname(x) => NUTS(), @varname(x) => NUTS(), @varname(y) => MH())`"
369+
"`Gibbs(@varname(x) => RepeatSampler(NUTS(), 2), @varname(y) => MH())`"
368370
)
369371
Base.depwarn(msg, :Gibbs)
370372
return Gibbs(varnames, map(wrap_algorithm_maybe drop_space, algs))

src/mcmc/repeat_sampler.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
RepeatSampler <: AbstractMCMC.AbstractSampler
3+
4+
A `RepeatSampler` is a container for a sampler and a number of times to repeat it.
5+
6+
# Fields
7+
$(FIELDS)
8+
9+
# Examples
10+
```julia
11+
repeated_sampler = RepeatSampler(sampler, 10)
12+
AbstractMCMC.step(rng, model, repeated_sampler) # take 10 steps of `sampler`
13+
```
14+
"""
15+
struct RepeatSampler{S<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSampler
16+
"The sampler to repeat"
17+
sampler::S
18+
"The number of times to repeat the sampler"
19+
num_repeat::Int
20+
21+
function RepeatSampler(sampler::S, num_repeat::Int) where {S}
22+
@assert num_repeat > 0
23+
return new{S}(sampler, num_repeat)
24+
end
25+
end
26+
27+
function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int)
28+
return RepeatSampler(Sampler(alg), num_repeat)
29+
end
30+
31+
drop_space(rs::RepeatSampler) = RepeatSampler(drop_space(rs.sampler), rs.num_repeat)
32+
getADType(spl::RepeatSampler) = getADType(spl.sampler)
33+
DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler)
34+
DynamicPPL.getspace(spl::RepeatSampler) = getspace(spl.sampler)
35+
DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler)
36+
37+
function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params)
38+
return setparams_varinfo!!(model, sampler.sampler, state, params)
39+
end
40+
41+
function AbstractMCMC.step(
42+
rng::Random.AbstractRNG,
43+
model::AbstractMCMC.AbstractModel,
44+
sampler::RepeatSampler;
45+
kwargs...,
46+
)
47+
return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...)
48+
end
49+
50+
function AbstractMCMC.step(
51+
rng::Random.AbstractRNG,
52+
model::AbstractMCMC.AbstractModel,
53+
sampler::RepeatSampler,
54+
state;
55+
kwargs...,
56+
)
57+
transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...)
58+
for _ in 2:(sampler.num_repeat)
59+
transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...)
60+
end
61+
return transition, state
62+
end

test/mcmc/gibbs.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,18 @@ end
251251
)
252252
end
253253

254+
@testset "Equivalence of RepeatSampler and repeating Sampler" begin
255+
sampler1 = Gibbs(@varname(s) => RepeatSampler(MH(), 3), @varname(m) => ESS())
256+
sampler2 = Gibbs(
257+
@varname(s) => MH(), @varname(s) => MH(), @varname(s) => MH(), @varname(m) => ESS()
258+
)
259+
Random.seed!(23)
260+
chain1 = sample(gdemo_default, sampler1, 10)
261+
Random.seed!(23)
262+
chain2 = sample(gdemo_default, sampler1, 10)
263+
@test chain1.value == chain2.value
264+
end
265+
254266
@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
255267
@testset "Deprecated Gibbs constructors" begin
256268
N = 10
@@ -302,7 +314,12 @@ end
302314
vnm = @varname(m)
303315
Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg)
304316
end
305-
for s in (s1, s2, s3, s4, s5, s6, s7, s8)
317+
# Same thing but using RepeatSampler.
318+
s9 = Gibbs(
319+
@varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3),
320+
@varname(m) => RepeatSampler(PG(10), 2),
321+
)
322+
for s in (s1, s2, s3, s4, s5, s6, s7, s8, s9)
306323
@test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs"
307324
end
308325

@@ -314,6 +331,7 @@ end
314331
sample(gdemo_default, s6, N)
315332
sample(gdemo_default, s7, N)
316333
sample(gdemo_default, s8, N)
334+
sample(gdemo_default, s9, N)
317335

318336
g = Turing.Sampler(s3, gdemo_default)
319337
@test sample(gdemo_default, g, N) isa MCMCChains.Chains
@@ -355,7 +373,7 @@ end
355373
@varname(s) => MH(),
356374
(@varname(s), @varname(m)) => MH(),
357375
@varname(m) => ESS(),
358-
@varname(s) => MH(),
376+
@varname(s) => RepeatSampler(MH(), 3),
359377
@varname(m) => HMC(0.2, 4; adtype=adbackend),
360378
(@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend),
361379
)
@@ -367,7 +385,7 @@ end
367385
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
368386
(@varname(z1), @varname(z2)) => PG(15),
369387
(@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
370-
(@varname(z3), @varname(z4)) => PG(15),
388+
(@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2),
371389
(@varname(mu1)) => ESS(),
372390
(@varname(mu2)) => ESS(),
373391
(@varname(z1), @varname(z2)) => PG(15),

test/mcmc/repeat_sampler.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
module HMCTests
2+
3+
using ..Models: gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
5+
using ..NumericalTests: check_gdemo, check_numerical
6+
import ..ADUtils
7+
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
8+
import DynamicPPL
9+
using DynamicPPL: Sampler
10+
import ForwardDiff
11+
using HypothesisTests: ApproximateTwoSampleKSTest, pvalue
12+
import ReverseDiff
13+
using LinearAlgebra: I, dot, vec
14+
import Random
15+
using StableRNGs: StableRNG
16+
using StatsFuns: logistic
17+
import Mooncake
18+
using Test: @test, @test_logs, @testset, @test_throws
19+
using Turing
20+
21+
# RepeatedSampler only really makes sense as a component sampler of Gibbs.
22+
# Here we just check that running it by itself is equivalent to thinning.
23+
@testset "RepeatedSampler" begin
24+
num_repeats = 17
25+
num_samples = 10
26+
num_chains = 2
27+
28+
rng = StableRNG(0)
29+
for sampler in [MH(), Sampler(HMC(0.01, 4))]
30+
chn1 = sample(
31+
copy(rng),
32+
gdemo_default,
33+
sampler,
34+
MCMCThreads(),
35+
num_samples,
36+
num_chains;
37+
thinning=num_repeats,
38+
)
39+
repeat_sampler = RepeatSampler(sampler, num_repeats)
40+
chn2 = sample(
41+
copy(rng), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, num_chains
42+
)
43+
@test chn1.value == chn2.value
44+
end
45+
end
46+
47+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ end
6060
@timeit_include("mcmc/abstractmcmc.jl")
6161
@timeit_include("mcmc/mh.jl")
6262
@timeit_include("ext/dynamichmc.jl")
63+
@timeit_include("mcmc/repeat_sampler.jl")
6364
end
6465

6566
@testset "variational algorithms" begin

0 commit comments

Comments
 (0)