Skip to content

Commit e26820f

Browse files
authored
Fix externalsampler interface (#2640)
* Fix externalsampler interface * Fix getparams docstring * Improve getparams docstrings * Fix tests
1 parent d0510b1 commit e26820f

File tree

5 files changed

+126
-6
lines changed

5 files changed

+126
-6
lines changed

HISTORY.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
# 0.39.9
2+
3+
Revert a bug introduced in 0.39.5 in the external sampler interface.
4+
For Turing 0.39, external samplers should define
5+
6+
```
7+
Turing.Inference.getparams(::DynamicPPL.Model, ::MySamplerTransition)`
8+
```
9+
10+
rather than
11+
12+
```
13+
AbstractMCMC.getparams(::DynamicPPL.Model, ::MySamplerState)
14+
```
15+
16+
to obtain a vector of parameters from the model.
17+
18+
Note that this may change in future breaking releases.
19+
120
# 0.39.8
221

322
MCMCChains.jl doesn't understand vector- or matrix-valued variables, and in Turing we split up such values into their individual components.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.39.8"
3+
version = "0.39.9"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,26 @@ metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),)
161161
# Chain making utilities #
162162
##########################
163163

164+
# TODO(penelopeysm): Separate Turing.Inference.getparams (should only be
165+
# defined for AbstractVarInfo and Turing.Inference.Transition; returns varname
166+
# => value maps) from AbstractMCMC.getparams (defined for any sampler transition,
167+
# returns vector).
164168
"""
165-
getparams(model, t)
169+
Turing.Inference.getparams(model::Any, t::Any)
166170
167-
Return a named tuple of parameters.
171+
Return a vector of parameter values from the given sampler transition `t` (i.e.,
172+
the first return value of AbstractMCMC.step). By default, returns the `t.θ` field.
173+
174+
!!! note
175+
This method only needs to be implemented for external samplers. It will be
176+
removed in future releases and replaced with `AbstractMCMC.getparams`.
168177
"""
169178
getparams(model, t) = t.θ
179+
"""
180+
Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo)
181+
182+
Return a key-value map of parameters from the varinfo.
183+
"""
170184
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
171185
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
172186
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints

src/mcmc/external_sampler.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ When implementing a new `MySampler <: AbstractSampler`,
1515
In particular, it must implement:
1616
1717
- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is documented in AbstractMCMC.jl)
18-
- `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`: How to extract the parameters from the state returned by your sampler (i.e., the second return value of `step`).
18+
- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the parameters from the transition returned by your sampler (i.e., the first return value of `step`).
19+
There is a default implementation for this method, which is to return `external_transition.θ`.
20+
21+
!!! note
22+
In a future breaking release of Turing, this is likely to change to `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method. `Turing.Inference.getparams` is technically an internal method, so the aim here is to unify the interface for samplers at a higher level.
1923
2024
There are a few more optional functions which you can implement to improve the integration with Turing.jl:
2125
@@ -119,7 +123,10 @@ function make_updated_varinfo(
119123
f::DynamicPPL.LogDensityFunction, external_transition, external_state
120124
)
121125
# Set the parameters.
122-
new_parameters = getparams(f.model, external_state)
126+
# NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
127+
# The latter uses the state rather than the transition.
128+
# TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
129+
new_parameters = getparams(f.model, external_transition)
123130
new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters)
124131
# Set (or recalculate, if needed) the log density.
125132
new_logp = getlogp_external(external_transition, external_state)

test/mcmc/external_sampler.jl

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ExternalSamplerTests
22

3+
using ..Models: gdemo_default
34
using AbstractMCMC: AbstractMCMC
45
using AdvancedMH: AdvancedMH
56
using Distributions: sample
@@ -14,6 +15,85 @@ using Test: @test, @test_throws, @testset
1415
using Turing
1516
using Turing.Inference: AdvancedHMC
1617

18+
@testset "External sampler interface" begin
19+
# Turing declares an interface for external samplers (see docstring for
20+
# ExternalSampler). We should check that implementing this interface
21+
# and only this interface allows us to use the sampler in Turing.
22+
struct MyTransition{V<:AbstractVector}
23+
params::V
24+
end
25+
# Samplers need to implement `Turing.Inference.getparams`.
26+
Turing.Inference.getparams(::DynamicPPL.Model, t::MyTransition) = t.params
27+
# State doesn't matter (but we need to carry the params through to the next
28+
# iteration).
29+
struct MyState{V<:AbstractVector}
30+
params::V
31+
end
32+
33+
# externalsamplers must accept LogDensityModel inside their step function.
34+
# By default Turing gives the externalsampler a LDF constructed with
35+
# adtype=ForwardDiff, so we should expect that inside the sampler we can
36+
# call both `logdensity` and `logdensity_and_gradient`.
37+
#
38+
# The behaviour of this sampler is to simply calculate logp and its
39+
# gradient, and then return the same values.
40+
#
41+
# TODO: Do we also want to run ADTypeCheckContext to make sure that it is
42+
# indeed using the adtype provided from Turing?
43+
struct MySampler <: AbstractMCMC.AbstractSampler end
44+
function AbstractMCMC.step(
45+
rng::Random.AbstractRNG,
46+
model::AbstractMCMC.LogDensityModel,
47+
sampler::MySampler;
48+
initial_params::AbstractVector,
49+
kwargs...,
50+
)
51+
# Step 1
52+
ldf = model.logdensity
53+
lp = LogDensityProblems.logdensity(ldf, initial_params)
54+
@test lp isa Real
55+
lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, initial_params)
56+
@test lp isa Real
57+
@test grad isa AbstractVector{<:Real}
58+
return MyTransition(initial_params), MyState(initial_params)
59+
end
60+
function AbstractMCMC.step(
61+
rng::Random.AbstractRNG,
62+
model::AbstractMCMC.LogDensityModel,
63+
sampler::MySampler,
64+
state::MyState;
65+
kwargs...,
66+
)
67+
# Step >= 1
68+
params = state.params
69+
ldf = model.logdensity
70+
lp = LogDensityProblems.logdensity(ldf, params)
71+
@test lp isa Real
72+
lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
73+
@test lp isa Real
74+
@test grad isa AbstractVector{<:Real}
75+
return MyTransition(params), MyState(params)
76+
end
77+
78+
@model function test_external_sampler()
79+
a ~ Beta(2, 2)
80+
return b ~ Normal(a)
81+
end
82+
model = test_external_sampler()
83+
a, b = 0.5, 0.0
84+
85+
chn = sample(model, externalsampler(MySampler()), 10; initial_params=[a, b])
86+
@test chn isa MCMCChains.Chains
87+
@test all(chn[:a] .== a)
88+
@test all(chn[:b] .== b)
89+
# TODO: Uncomment this once Turing v0.40 is released. In that version, logpdf
90+
# will be recalculated correctly for external samplers.
91+
# expected_logpdf = logpdf(Beta(2, 2), a) + logpdf(Normal(a), b)
92+
# @test all(chn[:lp] .== expected_logpdf)
93+
# @test all(chn[:logprior] .== expected_logpdf)
94+
# @test all(chn[:loglikelihood] .== 0.0)
95+
end
96+
1797
function initialize_nuts(model::DynamicPPL.Model)
1898
# Create a linked varinfo
1999
vi = DynamicPPL.VarInfo(model)
@@ -107,7 +187,7 @@ function test_initial_params(
107187
end
108188
end
109189

110-
@testset verbose = true "External samplers" begin
190+
@testset verbose = true "Implementation of externalsampler interface for known packages" begin
111191
@testset "AdvancedHMC.jl" begin
112192
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
113193
adtype = Turing.DEFAULT_ADTYPE

0 commit comments

Comments
 (0)