Skip to content

Commit 6ac3922

Browse files
devmotiontorfjelde
andauthored
Update to AbstractMCMC 2 (#150)
Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent cf95183 commit 6ac3922

23 files changed

+856
-1093
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.9.8"
3+
version = "0.10.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -12,7 +12,7 @@ NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313

1414
[compat]
15-
AbstractMCMC = "1"
15+
AbstractMCMC = "2"
1616
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
1717
ChainRulesCore = "0.9.7"
1818
Distributions = "0.23.8"

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ export AbstractVarInfo,
6161
Sample,
6262
init,
6363
vectorize,
64-
set_resume!,
6564
# Model
6665
Model,
6766
getmissings,
@@ -122,6 +121,4 @@ include("prob_macro.jl")
122121
include("compat/ad.jl")
123122
include("loglikelihoods.jl")
124123

125-
include("deprecations.jl")
126-
127124
end # module

src/deprecations.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/model.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,6 @@ See also: [`evaluate_threadsafe`](@ref)
124124
"""
125125
function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
126126
resetlogp!(varinfo)
127-
if has_eval_num(sampler)
128-
sampler.state.eval_num += 1
129-
end
130127
return _evaluate(rng, model, varinfo, sampler, context)
131128
end
132129

@@ -143,9 +140,6 @@ See also: [`evaluate_threadunsafe`](@ref)
143140
"""
144141
function evaluate_threadsafe(rng, model, varinfo, sampler, context)
145142
resetlogp!(varinfo)
146-
if has_eval_num(sampler)
147-
sampler.state.eval_num += 1
148-
end
149143
wrapper = ThreadSafeVarInfo(varinfo)
150144
result = _evaluate(rng, model, wrapper, sampler, context)
151145
setlogp!(varinfo, getlogp(wrapper))

src/sampler.jl

Lines changed: 105 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
2+
# That would let us use all defaults for Sampler, combine it with other samplers etc.
13
"""
24
Robust initialization method for model parameters in Hamiltonian samplers.
35
"""
@@ -17,55 +19,123 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
1719
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
1820
end
1921

20-
"""
21-
has_eval_num(spl::AbstractSampler)
22-
23-
Check whether `spl` has a field called `eval_num` in its state variables or not.
24-
"""
25-
has_eval_num(spl::SampleFromUniform) = false
26-
has_eval_num(spl::SampleFromPrior) = false
27-
has_eval_num(spl::AbstractSampler) = :eval_num in fieldnames(typeof(spl.state))
28-
29-
"""
30-
An abstract type that mutable sampler state structs inherit from.
31-
"""
32-
abstract type AbstractSamplerState end
33-
3422
"""
3523
Sampler{T}
3624
37-
Generic interface for implementing inference algorithms.
38-
An implementation of an algorithm should include the following:
39-
40-
1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
41-
2. A method of `sample` function that produces results of inference, which is where actual inference happens.
25+
Generic sampler type for inference algorithms of type `T` in DynamicPPL.
4226
43-
DynamicPPL translates models to chunks that call the modelling functions at specified points.
44-
The dispatch is based on the value of a `sampler` variable.
45-
To include a new inference algorithm implements the requirements mentioned above in a separate file,
46-
then include that file at the end of this one.
27+
`Sampler` should implement the AbstractMCMC interface, and in particular
28+
[`AbstractMCMC.step`](@ref). A default implementation of the initial sampling step is
29+
provided that supports resuming sampling from a previous state and setting initial
30+
parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
31+
for loading previous states and actually performing the initial sampling step,
32+
respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
33+
that specifies how the initial parameter values are sampled if they are not provided.
34+
By default, values are sampled from the prior.
4735
"""
48-
mutable struct Sampler{T, S<:AbstractSamplerState} <: AbstractSampler
49-
alg :: T
50-
info :: Dict{Symbol, Any} # sampler infomation
51-
selector :: Selector
52-
state :: S
36+
struct Sampler{T} <: AbstractSampler
37+
alg::T
38+
selector::Selector # Can we remove it?
39+
# TODO: add space such that we can integrate existing external samplers in DynamicPPL
5340
end
5441
Sampler(alg) = Sampler(alg, Selector())
5542
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
56-
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
43+
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)
5744

5845
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
59-
60-
function AbstractMCMC.step!(
46+
function AbstractMCMC.step(
6147
rng::Random.AbstractRNG,
6248
model::Model,
6349
sampler::Union{SampleFromUniform,SampleFromPrior},
64-
::Integer,
65-
transition;
50+
state = nothing;
6651
kwargs...
6752
)
6853
vi = VarInfo()
69-
model(vi, sampler)
70-
return vi
54+
model(rng, vi, sampler)
55+
return vi, nothing
56+
end
57+
58+
# initial step: general interface for resuming and
59+
function AbstractMCMC.step(
60+
rng::Random.AbstractRNG,
61+
model::Model,
62+
spl::Sampler;
63+
resume_from = nothing,
64+
kwargs...
65+
)
66+
if resume_from !== nothing
67+
state = loadstate(resume_from)
68+
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
69+
end
70+
71+
# Sample initial values.
72+
_spl = initialsampler(spl)
73+
vi = VarInfo(rng, model, _spl)
74+
75+
# Update the parameters if provided.
76+
if haskey(kwargs, :init_params)
77+
initialize_parameters!(vi, kwargs[:init_params], spl)
78+
79+
# Update joint log probability.
80+
model(rng, vi, _spl)
81+
end
82+
83+
return initialstep(rng, model, spl, vi; kwargs...)
84+
end
85+
86+
"""
87+
loadstate(data)
88+
89+
Load sampler state from `data`.
90+
"""
91+
function loadstate end
92+
93+
"""
94+
initialsampler(sampler::Sampler)
95+
96+
Return the sampler that is used for generating the initial parameters when sampling with
97+
`sampler`.
98+
99+
By default, it returns an instance of [`SampleFromPrior`](@ref).
100+
"""
101+
initialsampler(spl::Sampler) = SampleFromPrior()
102+
103+
function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler)
104+
@debug "Using passed-in initial variable values" init_params
105+
106+
# Flatten parameters.
107+
init_theta = mapreduce(vcat, init_params) do x
108+
vec([x;])
109+
end
110+
111+
# Get all values.
112+
linked = islinked(vi, spl)
113+
linked && invlink!(vi, spl)
114+
theta = vi[spl]
115+
length(theta) == length(init_theta_flat) ||
116+
error("Provided initial value doesn't match the dimension of the model")
117+
118+
# Update values that are provided.
119+
for i in 1:length(init_theta)
120+
x = init_theta[i]
121+
if x !== missing
122+
theta[i] = x
123+
end
124+
end
125+
126+
# Update in `vi`.
127+
vi[spl] = theta
128+
linked && link!(vi, spl)
129+
130+
return
71131
end
132+
133+
"""
134+
initialstep(rng, model, sampler, varinfo; kwargs...)
135+
136+
Perform the initial sampling step of the `sampler` for the `model`.
137+
138+
The `varinfo` contains the initial samples, which can be provided by the user or
139+
sampled randomly.
140+
"""
141+
function initialstep end

src/varinfo.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,38 @@ end
105105
const UntypedVarInfo = VarInfo{<:Metadata}
106106
const TypedVarInfo = VarInfo{<:NamedTuple}
107107

108-
function VarInfo(model::Model, ctx = DefaultContext())
109-
vi = VarInfo()
110-
model(vi, SampleFromPrior(), ctx)
111-
return TypedVarInfo(vi)
112-
end
113-
114108
function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
115109
new_vi = deepcopy(old_vi)
116110
new_vi[spl] = x
117111
return new_vi
118112
end
113+
119114
function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
120115
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
121116
VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)))
122117
end
118+
119+
function VarInfo(
120+
rng::Random.AbstractRNG,
121+
model::Model,
122+
sampler::AbstractSampler = SampleFromPrior(),
123+
context::AbstractContext = DefaultContext(),
124+
)
125+
varinfo = VarInfo()
126+
model(rng, varinfo, sampler, context)
127+
return TypedVarInfo(varinfo)
128+
end
129+
VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)
130+
131+
# without AbstractSampler
132+
function VarInfo(
133+
rng::Random.AbstractRNG,
134+
model::Model,
135+
context::AbstractContext,
136+
)
137+
return VarInfo(rng, model, SampleFromPrior(), context)
138+
end
139+
123140
@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space}
124141
exprs = []
125142
offset = :(0)
@@ -1000,7 +1017,6 @@ from a distribution `dist` to `VarInfo` `vi`.
10001017
The sampler is passed here to invalidate its cache where defined.
10011018
"""
10021019
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler)
1003-
spl.info[:cache_updated] = CACHERESET
10041020
return push!(vi, vn, r, dist, spl.selector)
10051021
end
10061022
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler)

test/Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
44
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
55
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
6+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
67
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
78
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -31,15 +32,16 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3132
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3233

3334
[compat]
34-
AbstractMCMC = "1.0.1"
35+
AbstractMCMC = "2.1"
3536
AdvancedHMC = "0.2.25"
36-
AdvancedMH = "0.5.1"
37+
AdvancedMH = "0.5.2"
3738
AdvancedVI = "0.1"
39+
BangBang = "0.3"
3840
Bijectors = "0.8.2"
3941
Distributions = "0.23.8"
4042
DistributionsAD = "0.6.3"
4143
DocStringExtensions = "0.8.2"
42-
EllipticalSliceSampling = "0.2.2"
44+
EllipticalSliceSampling = "0.3"
4345
ForwardDiff = "0.10.12"
4446
Libtask = "0.4.1, 0.5"
4547
LogDensityProblems = "0.10.3"

0 commit comments

Comments
 (0)