
Commit 27aab23

Merge branch 'breaking' into mhauru/dppl-0.37
2 parents: 8fdecc0 + eb2b7a7

4 files changed: +128 −5 lines

HISTORY.md

Lines changed: 19 additions & 0 deletions
````diff
@@ -2,6 +2,25 @@
 [...]
+# 0.39.9
+
+Revert a bug introduced in 0.39.5 in the external sampler interface.
+For Turing 0.39, external samplers should define
+
+```
+Turing.Inference.getparams(::DynamicPPL.Model, ::MySamplerTransition)
+```
+
+rather than
+
+```
+AbstractMCMC.getparams(::DynamicPPL.Model, ::MySamplerState)
+```
+
+to obtain a vector of parameters from the model.
+
+Note that this may change in future breaking releases.
+
 # 0.39.8
 
 MCMCChains.jl doesn't understand vector- or matrix-valued variables, and in Turing we split up such values into their individual components.
````
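For illustration, here is a minimal sketch of the interface this changelog entry describes. `MySamplerTransition` is the changelog's placeholder name rather than a real Turing type, and the `params` field is an assumption about where such a sampler keeps its draw:

```julia
using DynamicPPL, Turing

# Hypothetical transition type returned by an external sampler's `AbstractMCMC.step`.
struct MySamplerTransition{V<:AbstractVector}
    params::V
end

# Under Turing 0.39, tell Turing how to extract a parameter vector from the transition.
Turing.Inference.getparams(::DynamicPPL.Model, t::MySamplerTransition) = t.params
```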

src/mcmc/Inference.jl

Lines changed: 21 additions & 2 deletions
```diff
@@ -206,8 +206,27 @@ end
 # Chain making utilities #
 ##########################
 
-getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
-function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
+# TODO(penelopeysm): Separate Turing.Inference.getparams (should only be
+# defined for AbstractVarInfo and Turing.Inference.Transition; returns varname
+# => value maps) from AbstractMCMC.getparams (defined for any sampler transition,
+# returns vector).
+"""
+    Turing.Inference.getparams(model::DynamicPPL.Model, t::Any)
+
+Return a vector of parameter values from the given sampler transition `t` (i.e.,
+the first return value of AbstractMCMC.step). By default, returns the `t.θ` field.
+
+!!! note
+    This method only needs to be implemented for external samplers. It will be
+    removed in future releases and replaced with `AbstractMCMC.getparams`.
+"""
+getparams(::DynamicPPL.Model, t) = t.θ
+"""
+    Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo)
+
+Return a key-value map of parameters from the varinfo.
+"""
+function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
     t = Transition(model, vi, nothing)
     return getparams(model, t)
 end
```
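Taken together, the two methods behave differently depending on what they are handed. The following is a rough usage sketch under the definitions in this hunk (the one-variable `demo` model is made up for illustration): the fallback accepts any object carrying a `θ` field and returns that field as-is, while the `VarInfo` method goes through `Transition` and yields a key-value map:

```julia
using Turing, DynamicPPL

@model function demo()
    return x ~ Normal()
end
model = demo()

# Fallback method: any object with a `θ` field (here a NamedTuple) yields a raw vector.
Turing.Inference.getparams(model, (; θ=[0.5]))  # == [0.5]

# VarInfo method: returns a key-value map of parameters rather than a bare vector.
vi = DynamicPPL.VarInfo(model)
Turing.Inference.getparams(model, vi)
```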

src/mcmc/external_sampler.jl

Lines changed: 9 additions & 2 deletions
```diff
@@ -15,7 +15,11 @@ When implementing a new `MySampler <: AbstractSampler`,
 In particular, it must implement:
 
 - `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is documented in AbstractMCMC.jl)
-- `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`: How to extract the parameters from the state returned by your sampler (i.e., the second return value of `step`).
+- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the parameters from the transition returned by your sampler (i.e., the first return value of `step`).
+  There is a default implementation for this method, which is to return `external_transition.θ`.
+
+!!! note
+    In a future breaking release of Turing, this is likely to change to `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method. `Turing.Inference.getparams` is technically an internal method, so the aim here is to unify the interface for samplers at a higher level.
 
 There are a few more optional functions which you can implement to improve the integration with Turing.jl:
 
@@ -153,7 +157,10 @@ function AbstractMCMC.step(
         )
     end
 
-    new_parameters = getparams(f.model, state_inner)
+    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
+    # The latter uses the state rather than the transition.
+    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
+    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     return (
         Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
```
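To put the docstring requirement above in concrete terms, here is a hedged sketch (both transition types are hypothetical): a transition that already stores its parameter vector in a field named `θ` is covered by the default implementation, while any other layout needs a one-line `Turing.Inference.getparams` method:

```julia
using DynamicPPL, Turing

# Covered by the default implementation, which returns `external_transition.θ`.
struct ThetaTransition{V<:AbstractVector}
    θ::V
end

# Not covered: the parameters live in a differently named field, so we add a method.
struct ValuesTransition{V<:AbstractVector}
    values::V
end
Turing.Inference.getparams(::DynamicPPL.Model, t::ValuesTransition) = t.values
```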

test/mcmc/external_sampler.jl

Lines changed: 79 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 module ExternalSamplerTests
 
+using ..Models: gdemo_default
 using AbstractMCMC: AbstractMCMC
 using AdvancedMH: AdvancedMH
 using Distributions: sample
@@ -14,6 +15,83 @@ using Test: @test, @test_throws, @testset
 using Turing
 using Turing.Inference: AdvancedHMC
 
+@testset "External sampler interface" begin
+    # Turing declares an interface for external samplers (see docstring for
+    # ExternalSampler). We should check that implementing this interface
+    # and only this interface allows us to use the sampler in Turing.
+    struct MyTransition{V<:AbstractVector}
+        params::V
+    end
+    # Samplers need to implement `Turing.Inference.getparams`.
+    Turing.Inference.getparams(::DynamicPPL.Model, t::MyTransition) = t.params
+    # State doesn't matter (but we need to carry the params through to the next
+    # iteration).
+    struct MyState{V<:AbstractVector}
+        params::V
+    end
+
+    # externalsamplers must accept LogDensityModel inside their step function.
+    # By default Turing gives the externalsampler a LDF constructed with
+    # adtype=ForwardDiff, so we should expect that inside the sampler we can
+    # call both `logdensity` and `logdensity_and_gradient`.
+    #
+    # The behaviour of this sampler is to simply calculate logp and its
+    # gradient, and then return the same values.
+    #
+    # TODO: Do we also want to run ADTypeCheckContext to make sure that it is
+    # indeed using the adtype provided from Turing?
+    struct MySampler <: AbstractMCMC.AbstractSampler end
+    function AbstractMCMC.step(
+        rng::Random.AbstractRNG,
+        model::AbstractMCMC.LogDensityModel,
+        sampler::MySampler;
+        initial_params::AbstractVector,
+        kwargs...,
+    )
+        # Step 1
+        ldf = model.logdensity
+        lp = LogDensityProblems.logdensity(ldf, initial_params)
+        @test lp isa Real
+        lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, initial_params)
+        @test lp isa Real
+        @test grad isa AbstractVector{<:Real}
+        return MyTransition(initial_params), MyState(initial_params)
+    end
+    function AbstractMCMC.step(
+        rng::Random.AbstractRNG,
+        model::AbstractMCMC.LogDensityModel,
+        sampler::MySampler,
+        state::MyState;
+        kwargs...,
+    )
+        # Step >= 1
+        params = state.params
+        ldf = model.logdensity
+        lp = LogDensityProblems.logdensity(ldf, params)
+        @test lp isa Real
+        lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
+        @test lp isa Real
+        @test grad isa AbstractVector{<:Real}
+        return MyTransition(params), MyState(params)
+    end
+
+    @model function test_external_sampler()
+        a ~ Beta(2, 2)
+        return b ~ Normal(a)
+    end
+    model = test_external_sampler()
+    a, b = 0.5, 0.0
+
+    chn = sample(model, externalsampler(MySampler()), 10; initial_params=[a, b])
+    @test chn isa MCMCChains.Chains
+    @test all(chn[:a] .== a)
+    @test all(chn[:b] .== b)
+    expected_logpdf = logpdf(Beta(2, 2), a) + logpdf(Normal(a), b)
+    @test all(chn[:lp] .== expected_logpdf)
+    @test all(chn[:logprior] .== expected_logpdf)
+    @test all(chn[:loglikelihood] .== 0.0)
+end
+
 function initialize_nuts(model::DynamicPPL.Model)
     # Create a linked varinfo
     vi = DynamicPPL.VarInfo(model)
@@ -109,7 +187,7 @@ function test_initial_params(
     end
 end
 
-@testset verbose = true "External samplers" begin
+@testset verbose = true "Implementation of externalsampler interface for known packages" begin
     @testset "AdvancedHMC.jl" begin
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
             adtype = Turing.DEFAULT_ADTYPE
```
