Skip to content

Commit bf32ed4

Browse files
authored
Merge branch 'main' into breaking
2 parents b3cc1e4 + 1397d69 commit bf32ed4

File tree

7 files changed

+231
-16
lines changed

7 files changed

+231
-16
lines changed

.github/workflows/Docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Documentation
33
on:
44
push:
55
branches:
6-
- master
6+
- main
77
tags: '*'
88
pull_request:
99

.github/workflows/Format.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ name: Format
33
on:
44
push:
55
branches:
6-
- master
6+
- main
77
pull_request:
8-
branches:
9-
- master
108
merge_group:
119
types: [checks_requested]
1210

.github/workflows/Tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Tests
33
on:
44
push:
55
branches:
6-
- master
6+
- main
77
pull_request:
88

99
# Cancel existing tests on the same PR if a new commit is added to a pull request

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Turing.jl
22

3-
[![Build Status](https://github.com/TuringLang/Turing.jl/workflows/Turing-CI/badge.svg)](https://github.com/TuringLang/Turing.jl/actions?query=workflow%3ATuring-CI+branch%3Amaster)
4-
[![Coverage Status](https://coveralls.io/repos/github/TuringLang/Turing.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/Turing.jl?branch=master)
5-
[![codecov](https://codecov.io/gh/TuringLang/Turing.jl/branch/master/graph/badge.svg?token=OiUBsnDQqf)](https://codecov.io/gh/TuringLang/Turing.jl)
3+
[![Build Status](https://github.com/TuringLang/Turing.jl/workflows/Turing-CI/badge.svg)](https://github.com/TuringLang/Turing.jl/actions?query=workflow%3ATuring-CI+branch%3Amain)
4+
[![Coverage Status](https://coveralls.io/repos/github/TuringLang/Turing.jl/badge.svg?branch=main)](https://coveralls.io/github/TuringLang/Turing.jl?branch=main)
5+
[![codecov](https://codecov.io/gh/TuringLang/Turing.jl/branch/main/graph/badge.svg?token=OiUBsnDQqf)](https://codecov.io/gh/TuringLang/Turing.jl)
66
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
77

88
## Getting Started

src/mcmc/gibbs.jl

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,20 +300,75 @@ end
300300

301301
varinfo(state::GibbsState) = state.vi
302302

303-
function DynamicPPL.initialstep(
303+
"""
304+
Initialise a VarInfo for the Gibbs sampler.
305+
306+
This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated here to
307+
support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
308+
incompatible with step_warmup.
309+
"""
310+
function initial_varinfo(rng, model, spl, initial_params)
311+
vi = DynamicPPL.default_varinfo(rng, model, spl)
312+
313+
# Update the parameters if provided.
314+
if initial_params !== nothing
315+
vi = DynamicPPL.initialize_parameters!!(vi, initial_params, spl, model)
316+
317+
# Update joint log probability.
318+
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
319+
# and https://github.com/TuringLang/Turing.jl/issues/1563
320+
# to avoid that existing variables are resampled
321+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext()))
322+
end
323+
return vi
324+
end
325+
326+
function AbstractMCMC.step(
304327
rng::Random.AbstractRNG,
305328
model::DynamicPPL.Model,
306-
spl::DynamicPPL.Sampler{<:Gibbs},
307-
vi::DynamicPPL.AbstractVarInfo;
329+
spl::DynamicPPL.Sampler{<:Gibbs};
308330
initial_params=nothing,
309331
kwargs...,
310332
)
311333
alg = spl.alg
312334
varnames = alg.varnames
313335
samplers = alg.samplers
336+
vi = initial_varinfo(rng, model, spl, initial_params)
314337

315338
vi, states = gibbs_initialstep_recursive(
316-
rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
339+
rng,
340+
model,
341+
AbstractMCMC.step,
342+
varnames,
343+
samplers,
344+
vi;
345+
initial_params=initial_params,
346+
kwargs...,
347+
)
348+
return Transition(model, vi), GibbsState(vi, states)
349+
end
350+
351+
function AbstractMCMC.step_warmup(
352+
rng::Random.AbstractRNG,
353+
model::DynamicPPL.Model,
354+
spl::DynamicPPL.Sampler{<:Gibbs};
355+
initial_params=nothing,
356+
kwargs...,
357+
)
358+
alg = spl.alg
359+
varnames = alg.varnames
360+
samplers = alg.samplers
361+
vi = initial_varinfo(rng, model, spl, initial_params)
362+
363+
vi, states = gibbs_initialstep_recursive(
364+
rng,
365+
model,
366+
AbstractMCMC.step_warmup,
367+
varnames,
368+
samplers,
369+
vi;
370+
initial_params=initial_params,
371+
kwargs...,
317372
)
318373
return Transition(model, vi), GibbsState(vi, states)
319374
end
@@ -322,9 +377,20 @@ end
322377
Take the first step of MCMC for the first component sampler, and call the same function
323378
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
324379
and a tuple of initial states for all component samplers.
380+
381+
The `step_function` argument should always be either AbstractMCMC.step or
382+
AbstractMCMC.step_warmup.
325383
"""
326384
function gibbs_initialstep_recursive(
327-
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
385+
rng,
386+
model,
387+
step_function::Function,
388+
varname_vecs,
389+
samplers,
390+
vi,
391+
states=();
392+
initial_params=nothing,
393+
kwargs...,
328394
)
329395
# End recursion
330396
if isempty(varname_vecs) && isempty(samplers)
@@ -345,7 +411,7 @@ function gibbs_initialstep_recursive(
345411
conditioned_model, context = make_conditional(model, varnames, vi)
346412

347413
# Take initial step with the current sampler.
348-
_, new_state = AbstractMCMC.step(
414+
_, new_state = step_function(
349415
rng,
350416
conditioned_model,
351417
sampler;
@@ -365,6 +431,7 @@ function gibbs_initialstep_recursive(
365431
return gibbs_initialstep_recursive(
366432
rng,
367433
model,
434+
step_function,
368435
varname_vecs_tail,
369436
samplers_tail,
370437
vi,
@@ -388,7 +455,29 @@ function AbstractMCMC.step(
388455
states = state.states
389456
@assert length(samplers) == length(state.states)
390457

391-
vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
458+
vi, states = gibbs_step_recursive(
459+
rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
460+
)
461+
return Transition(model, vi), GibbsState(vi, states)
462+
end
463+
464+
function AbstractMCMC.step_warmup(
465+
rng::Random.AbstractRNG,
466+
model::DynamicPPL.Model,
467+
spl::DynamicPPL.Sampler{<:Gibbs},
468+
state::GibbsState;
469+
kwargs...,
470+
)
471+
vi = varinfo(state)
472+
alg = spl.alg
473+
varnames = alg.varnames
474+
samplers = alg.samplers
475+
states = state.states
476+
@assert length(samplers) == length(state.states)
477+
478+
vi, states = gibbs_step_recursive(
479+
rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
480+
)
392481
return Transition(model, vi), GibbsState(vi, states)
393482
end
394483

@@ -517,10 +606,14 @@ end
517606
"""
518607
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
519608
function on the tail, until there are no more samplers left.
609+
610+
The `step_function` argument should always be either AbstractMCMC.step or
611+
AbstractMCMC.step_warmup.
520612
"""
521613
function gibbs_step_recursive(
522614
rng::Random.AbstractRNG,
523615
model::DynamicPPL.Model,
616+
step_function::Function,
524617
varname_vecs,
525618
samplers,
526619
states,
@@ -554,7 +647,7 @@ function gibbs_step_recursive(
554647
state = setparams_varinfo!!(conditioned_model, sampler, state, vi)
555648

556649
# Take a step with the local sampler.
557-
new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))
650+
new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...))
558651

559652
new_vi_local = varinfo(new_state)
560653
# Merge the latest values for all the variables in the current sampler.
@@ -565,6 +658,7 @@ function gibbs_step_recursive(
565658
return gibbs_step_recursive(
566659
rng,
567660
model,
661+
step_function,
568662
varname_vecs_tail,
569663
samplers_tail,
570664
states_tail,

src/mcmc/repeat_sampler.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,30 @@ function AbstractMCMC.step(
5959
end
6060
return transition, state
6161
end
62+
63+
function AbstractMCMC.step_warmup(
64+
rng::Random.AbstractRNG,
65+
model::AbstractMCMC.AbstractModel,
66+
sampler::RepeatSampler;
67+
kwargs...,
68+
)
69+
return AbstractMCMC.step_warmup(rng, model, sampler.sampler; kwargs...)
70+
end
71+
72+
function AbstractMCMC.step_warmup(
73+
rng::Random.AbstractRNG,
74+
model::AbstractMCMC.AbstractModel,
75+
sampler::RepeatSampler,
76+
state;
77+
kwargs...,
78+
)
79+
transition, state = AbstractMCMC.step_warmup(
80+
rng, model, sampler.sampler, state; kwargs...
81+
)
82+
for _ in 2:(sampler.num_repeat)
83+
transition, state = AbstractMCMC.step_warmup(
84+
rng, model, sampler.sampler, state; kwargs...
85+
)
86+
end
87+
return transition, state
88+
end

test/mcmc/gibbs.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,102 @@ end
267267
@test chain1.value == chain2.value
268268
end
269269

270+
@testset "Gibbs warmup" begin
271+
# An inference algorithm, for testing purposes, that records how many warm-up steps
272+
# and how many non-warm-up steps haven been taken.
273+
mutable struct WarmupCounter <: Inference.InferenceAlgorithm
274+
warmup_init_count::Int
275+
non_warmup_init_count::Int
276+
warmup_count::Int
277+
non_warmup_count::Int
278+
279+
WarmupCounter() = new(0, 0, 0, 0)
280+
end
281+
282+
Turing.Inference.drop_space(wuc::WarmupCounter) = wuc
283+
Turing.Inference.getspace(::WarmupCounter) = ()
284+
Turing.Inference.isgibbscomponent(::WarmupCounter) = true
285+
286+
# A trivial state that holds nothing but a VarInfo, to be used with WarmupCounter.
287+
struct VarInfoState{T}
288+
vi::T
289+
end
290+
291+
Turing.Inference.varinfo(state::VarInfoState) = state.vi
292+
function Turing.Inference.setparams_varinfo!!(
293+
::DynamicPPL.Model,
294+
::DynamicPPL.Sampler,
295+
::VarInfoState,
296+
params::DynamicPPL.AbstractVarInfo,
297+
)
298+
return VarInfoState(params)
299+
end
300+
301+
function AbstractMCMC.step(
302+
::Random.AbstractRNG,
303+
model::DynamicPPL.Model,
304+
spl::DynamicPPL.Sampler{<:WarmupCounter};
305+
kwargs...,
306+
)
307+
spl.alg.non_warmup_init_count += 1
308+
return Turing.Inference.Transition(nothing, 0.0),
309+
VarInfoState(DynamicPPL.VarInfo(model))
310+
end
311+
312+
function AbstractMCMC.step_warmup(
313+
::Random.AbstractRNG,
314+
model::DynamicPPL.Model,
315+
spl::DynamicPPL.Sampler{<:WarmupCounter};
316+
kwargs...,
317+
)
318+
spl.alg.warmup_init_count += 1
319+
return Turing.Inference.Transition(nothing, 0.0),
320+
VarInfoState(DynamicPPL.VarInfo(model))
321+
end
322+
323+
function AbstractMCMC.step(
324+
::Random.AbstractRNG,
325+
::DynamicPPL.Model,
326+
spl::DynamicPPL.Sampler{<:WarmupCounter},
327+
s::VarInfoState;
328+
kwargs...,
329+
)
330+
spl.alg.non_warmup_count += 1
331+
return Turing.Inference.Transition(nothing, 0.0), s
332+
end
333+
334+
function AbstractMCMC.step_warmup(
335+
::Random.AbstractRNG,
336+
::DynamicPPL.Model,
337+
spl::DynamicPPL.Sampler{<:WarmupCounter},
338+
s::VarInfoState;
339+
kwargs...,
340+
)
341+
spl.alg.warmup_count += 1
342+
return Turing.Inference.Transition(nothing, 0.0), s
343+
end
344+
345+
@model f() = x ~ Normal()
346+
m = f()
347+
348+
num_samples = 10
349+
num_warmup = 3
350+
wuc = WarmupCounter()
351+
sample(m, Gibbs(:x => wuc), num_samples; num_warmup=num_warmup)
352+
@test wuc.warmup_init_count == 1
353+
@test wuc.non_warmup_init_count == 0
354+
@test wuc.warmup_count == num_warmup
355+
@test wuc.non_warmup_count == num_samples - 1
356+
357+
num_reps = 2
358+
wuc = WarmupCounter()
359+
sample(m, Gibbs(:x => RepeatSampler(wuc, num_reps)), num_samples; num_warmup=num_warmup)
360+
@test wuc.warmup_init_count == 1
361+
@test wuc.non_warmup_init_count == 0
362+
@test wuc.warmup_count == num_warmup * num_reps
363+
@test wuc.non_warmup_count == (num_samples - 1) * num_reps
364+
end
365+
270366
@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
271367
@info "Starting Gibbs tests with $adbackend"
272368

0 commit comments

Comments
 (0)