Skip to content

Commit bda441b

Browse files
github-actions[bot]CompatHelper Juliadevmotiontorfjeldesunxd3
authored
CompatHelper: bump compat for AbstractMCMC to 5, (keep existing compat) (#551)
* CompatHelper: bump compat for AbstractMCMC to 5, (keep existing compat) * CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat) (#552) Co-authored-by: CompatHelper Julia <[email protected]> * Update to AbstractMCMC 5 * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update sampler.jl * CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat) (#553) * Fix for `rand` + replace overloads of `rand` with `rand_prior_true` for testing models (#541) * preserve context from model in `rand` * replace rand overloads in TestUtils with definitions of rand_prior_true so we can properly test rand * removed NamedTuple from signature of TestUtils.rand_prior_true * updated references to previous overloads of rand to now use rand_prior_true * test rand for DEMO_MODELS * formatting * fixed tests for rand for DEMO_MODELS * fixed linkning tests * added missing impl of rand_prior_true for demo_static_transformation * formatting * fixed rand_prior_true for demo_static_transformation * bump minor version as this will be breaking * bump patch version * fixed old usage of rand * Update test/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed another usage of rand --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Remove `tonamedtuple` (#547) * Remove dependencies to `tonamedtuple` * Remove `tonamedtuple`s * Minor version bump --------- Co-authored-by: Hong Ge <[email protected]> * CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat) --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: CompatHelper Julia <[email protected]> * bump AbstractPPL version to 0.7 * Update AbstractPPL test dependency * add `Random.AbstractRNG` * Update sampler.jl (#557) * Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Xianda Sun <[email protected]>
1 parent 2e940aa commit bda441b

File tree

5 files changed

+63
-37
lines changed

5 files changed

+63
-37
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,21 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2929
DynamicPPLMCMCChainsExt = ["MCMCChains"]
3030

3131
[compat]
32-
AbstractMCMC = "2, 3.0, 4"
33-
AbstractPPL = "0.6"
32+
AbstractMCMC = "5"
33+
AbstractPPL = "0.7"
3434
BangBang = "0.3"
3535
Bijectors = "0.13"
36-
ChainRulesCore = "0.9.7, 0.10, 1"
37-
ConstructionBase = "1.5.4"
36+
ChainRulesCore = "1"
3837
Compat = "4"
39-
Distributions = "0.23.8, 0.24, 0.25"
40-
DocStringExtensions = "0.8, 0.9"
38+
ConstructionBase = "1.5.4"
39+
Distributions = "0.25"
40+
DocStringExtensions = "0.9"
4141
LogDensityProblems = "2"
4242
MCMCChains = "6"
4343
MacroTools = "0.5.6"
4444
OrderedCollections = "1"
4545
Requires = "1"
46-
Setfield = "0.7.1, 0.8, 1"
46+
Setfield = "1"
4747
ZygoteRules = "0.2"
4848
julia = "1.6"
4949

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ function _check_varname_indexing(c::MCMCChains.Chains)
1414
error("Chains do not support indexing using $vn.")
1515
end
1616

17+
# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
18+
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
19+
if !haskey(chain.info, :samplerstate)
20+
throw(
21+
ArgumentError(
22+
"The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing.",
23+
),
24+
)
25+
end
26+
return chain.info[:samplerstate]
27+
end
28+
1729
# A few methods needed.
1830
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
1931
return _has_varname_to_symbol(chain.info)

src/sampler.jl

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,26 +80,31 @@ function default_varinfo(
8080
return VarInfo(rng, model, init_sampler, context)
8181
end
8282

83-
# initial step: general interface for resuming and
84-
function AbstractMCMC.step(
83+
function AbstractMCMC.sample(
8584
rng::Random.AbstractRNG,
8685
model::Model,
87-
spl::Sampler;
86+
sampler::Sampler,
87+
N::Integer;
88+
chain_type=default_chain_type(sampler),
8889
resume_from=nothing,
89-
init_params=nothing,
90+
initial_state=loadstate(resume_from),
9091
kwargs...,
9192
)
92-
if resume_from !== nothing
93-
state = loadstate(resume_from)
94-
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
95-
end
93+
return AbstractMCMC.mcmcsample(
94+
rng, model, sampler, N; chain_type, initial_state, kwargs...
95+
)
96+
end
9697

98+
# initial step: general interface for resuming and
99+
function AbstractMCMC.step(
100+
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
101+
)
97102
# Sample initial values.
98103
vi = default_varinfo(rng, model, spl)
99104

100105
# Update the parameters if provided.
101-
if init_params !== nothing
102-
vi = initialize_parameters!!(vi, init_params, spl, model)
106+
if initial_params !== nothing
107+
vi = initialize_parameters!!(vi, initial_params, spl, model)
103108

104109
# Update joint log probability.
105110
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
@@ -108,15 +113,24 @@ function AbstractMCMC.step(
108113
vi = last(evaluate!!(model, vi, DefaultContext()))
109114
end
110115

111-
return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
116+
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
112117
end
113118

114119
"""
115120
loadstate(data)
116121
117122
Load sampler state from `data`.
123+
124+
By default, `data` is returned.
125+
"""
126+
loadstate(data) = data
127+
128+
"""
129+
default_chaintype(sampler)
130+
131+
Default type of the chain of posterior samples from `sampler`.
118132
"""
119-
function loadstate end
133+
default_chain_type(sampler::Sampler) = Any
120134

121135
"""
122136
initialsampler(sampler::Sampler)
@@ -129,12 +143,12 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
129143
initialsampler(spl::Sampler) = SampleFromPrior()
130144

131145
function initialize_parameters!!(
132-
vi::AbstractVarInfo, init_params, spl::Sampler, model::Model
146+
vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
133147
)
134-
@debug "Using passed-in initial variable values" init_params
148+
@debug "Using passed-in initial variable values" initial_params
135149

136150
# Flatten parameters.
137-
init_theta = mapreduce(vcat, init_params) do x
151+
init_theta = mapreduce(vcat, initial_params) do x
138152
vec([x;])
139153
end
140154

test/Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2222
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2323

2424
[compat]
25-
AbstractMCMC = "2.1, 3.0, 4"
26-
AbstractPPL = "0.6"
25+
AbstractMCMC = "5"
26+
AbstractPPL = "0.7"
2727
Bijectors = "0.13"
2828
Compat = "4.3.0"
2929
Distributions = "0.25"
3030
DistributionsAD = "0.6.3"
31-
Documenter = "0.26.1, 0.27, 1"
31+
Documenter = "1"
3232
ForwardDiff = "0.10.12"
3333
LogDensityProblems = "2"
34-
MCMCChains = "4.0.4, 5, 6"
34+
MCMCChains = "6.0.4"
3535
MacroTools = "0.5.5"
36-
Setfield = "0.7.1, 0.8, 1"
36+
Setfield = "1"
3737
StableRNGs = "1"
3838
Tracker = "0.2.23"
39-
Zygote = "0.5.4, 0.6"
39+
Zygote = "0.6"
4040
julia = "1.6"

test/sampler.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
model = coinflip()
8585
sampler = Sampler(alg)
8686
lptrue = logpdf(Binomial(25, 0.2), 10)
87-
chain = sample(model, sampler, 1; init_params=0.2, progress=false)
87+
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
8888
@test chain[1].metadata.p.vals == [0.2]
8989
@test getlogp(chain[1]) == lptrue
9090

@@ -95,7 +95,7 @@
9595
MCMCThreads(),
9696
1,
9797
10;
98-
init_params=fill(0.2, 10),
98+
initial_params=fill(0.2, 10),
9999
progress=false,
100100
)
101101
for c in chains
@@ -110,7 +110,7 @@
110110
end
111111
model = twovars()
112112
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
113-
chain = sample(model, sampler, 1; init_params=[4, -1], progress=false)
113+
chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
114114
@test chain[1].metadata.s.vals == [4]
115115
@test chain[1].metadata.m.vals == [-1]
116116
@test getlogp(chain[1]) == lptrue
@@ -122,7 +122,7 @@
122122
MCMCThreads(),
123123
1,
124124
10;
125-
init_params=fill([4, -1], 10),
125+
initial_params=fill([4, -1], 10),
126126
progress=false,
127127
)
128128
for c in chains
@@ -132,7 +132,7 @@
132132
end
133133

134134
# set only m = -1
135-
chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false)
135+
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
136136
@test !ismissing(chain[1].metadata.s.vals[1])
137137
@test chain[1].metadata.m.vals == [-1]
138138

@@ -143,19 +143,19 @@
143143
MCMCThreads(),
144144
1,
145145
10;
146-
init_params=fill([missing, -1], 10),
146+
initial_params=fill([missing, -1], 10),
147147
progress=false,
148148
)
149149
for c in chains
150150
@test !ismissing(c[1].metadata.s.vals[1])
151151
@test c[1].metadata.m.vals == [-1]
152152
end
153153

154-
# specify `init_params=nothing`
154+
# specify `initial_params=nothing`
155155
Random.seed!(1234)
156156
chain1 = sample(model, sampler, 1; progress=false)
157157
Random.seed!(1234)
158-
chain2 = sample(model, sampler, 1; init_params=nothing, progress=false)
158+
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
159159
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
160160
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals
161161

@@ -164,7 +164,7 @@
164164
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
165165
Random.seed!(1234)
166166
chains2 = sample(
167-
model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false
167+
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
168168
)
169169
for (c1, c2) in zip(chains1, chains2)
170170
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals

0 commit comments

Comments
 (0)