Skip to content

Commit add8870

Browse files
authored
Update logp in varinfo when external samplers are used (#2616)
* Update logp in varinfo when external samplers are used * Add tests * More tests * Bump patch, changelog * Only do one extra model evaluation instead of two * Don't run the same test 13 times * Fix tests * Simplify interface * isapprox for floats... * Update comment / changelog
1 parent 1aa95ac commit add8870

File tree

6 files changed

+104
-24
lines changed

6 files changed

+104
-24
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,9 @@
1+
# 0.39.5
2+
3+
Fixed a bug where sampling with an `externalsampler` would not set the log probability density inside the resulting chain.
4+
Note that there are still potentially bugs with the log-Jacobian term not being correctly included.
5+
A fix is being worked on.
6+
17
# 0.39.4
28

39
Bumped compatibility of AbstractPPL to include 0.12.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.39.4"
3+
version = "0.39.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/external_sampler.jl

Lines changed: 64 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -7,6 +7,23 @@ The `Unconstrained` type-parameter is to indicate whether the sampler requires u
77
88
# Fields
99
$(TYPEDFIELDS)
10+
11+
# Turing.jl's interface for external samplers
12+
13+
When implementing a new `MySampler <: AbstractSampler`,
14+
`MySampler` must first and foremost conform to the `AbstractMCMC` interface to work with Turing.jl's `externalsampler` function.
15+
In particular, it must implement:
16+
17+
- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is documented in AbstractMCMC.jl)
18+
- `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`: How to extract the parameters from the state returned by your sampler (i.e., the second return value of `step`).
19+
20+
There are a few more optional functions which you can implement to improve the integration with Turing.jl:
21+
22+
- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
23+
24+
- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.
25+
26+
- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method.
1027
"""
1128
struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
1229
InferenceAlgorithm
@@ -68,30 +85,28 @@ function externalsampler(
6885
return ExternalSampler(sampler, adtype, Val(unconstrained))
6986
end
7087

71-
struct TuringState{S,M,V,C}
88+
"""
89+
getlogp_external(external_transition, external_state)
90+
91+
Get the log probability density associated with the external sampler's
92+
transition and state. Returns `missing` by default; in this case, an extra
93+
model evaluation will be needed to calculate the correct log density.
94+
"""
95+
getlogp_external(::Any, ::Any) = missing
96+
getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
97+
getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density
98+
99+
struct TuringState{S,V1<:AbstractVarInfo,M,V,C}
72100
state::S
101+
# Note that this varinfo has the correct parameters and logp obtained from
102+
# the state, whereas `ldf.varinfo` will in general have junk inside it.
103+
varinfo::V1
73104
ldf::DynamicPPL.LogDensityFunction{M,V,C}
74105
end
75106

76-
state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
77-
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
78-
# TODO: We should probably rename this `getparams` since it returns something
79-
# very different from `Turing.Inference.getparams`.
80-
θ = getparams(f.model, transition)
81-
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
82-
return Transition(f.model, varinfo, transition)
83-
end
84-
85-
function varinfo(state::TuringState)
86-
θ = getparams(state.ldf.model, state.state)
87-
# TODO: Do we need to link here first?
88-
return DynamicPPL.unflatten(state.ldf.varinfo, θ)
89-
end
107+
varinfo(state::TuringState) = state.varinfo
90108
varinfo(state::AbstractVarInfo) = state
91109

92-
# NOTE: Only thing that depends on the underlying sampler.
93-
# Something similar should be part of AbstractMCMC at some point:
94-
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
95110
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
96111
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
97112
return getparams(model, state.transition)
@@ -100,6 +115,21 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat
100115

101116
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
102117

118+
function make_updated_varinfo(
119+
f::DynamicPPL.LogDensityFunction, external_transition, external_state
120+
)
121+
# Set the parameters.
122+
new_parameters = getparams(f.model, external_state)
123+
new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters)
124+
# Set (or recalculate, if needed) the log density.
125+
new_logp = getlogp_external(external_transition, external_state)
126+
return if ismissing(new_logp)
127+
last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context))
128+
else
129+
DynamicPPL.setlogp!!(new_varinfo, new_logp)
130+
end
131+
end
132+
103133
# TODO: Do we also support `resume`, etc?
104134
function AbstractMCMC.step(
105135
rng::Random.AbstractRNG,
@@ -143,8 +173,15 @@ function AbstractMCMC.step(
143173
kwargs...,
144174
)
145175
end
176+
177+
# Get the parameters and log density, and set them in the varinfo.
178+
new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
179+
146180
# Update the `state`
147-
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
181+
return (
182+
Transition(f.model, new_varinfo, transition_inner),
183+
TuringState(state_inner, new_varinfo, f),
184+
)
148185
end
149186

150187
function AbstractMCMC.step(
@@ -157,11 +194,17 @@ function AbstractMCMC.step(
157194
sampler = sampler_wrapper.alg.sampler
158195
f = state.ldf
159196

160-
# Then just call `AdvancedHMC.step` with the right arguments.
197+
# Then just call `AbstractMCMC.step` with the right arguments.
161198
transition_inner, state_inner = AbstractMCMC.step(
162199
rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
163200
)
164201

202+
# Get the parameters and log density, and set them in the varinfo.
203+
new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
204+
165205
# Update the `state`
166-
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
206+
return (
207+
Transition(f.model, new_varinfo, transition_inner),
208+
TuringState(state_inner, new_varinfo, f),
209+
)
167210
end

src/mcmc/gibbs.jl

Lines changed: 3 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -18,8 +18,9 @@ isgibbscomponent(::PG) = true
1818
isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler)
1919

2020
isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
21-
isgibbscomponent(::AdvancedHMC.HMC) = true
21+
isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
2222
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
23+
isgibbscomponent(spl) = false
2324

2425
function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
2526
return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
@@ -561,7 +562,7 @@ function setparams_varinfo!!(
561562
new_inner_state = setparams_varinfo!!(
562563
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
563564
)
564-
return TuringState(new_inner_state, logdensity)
565+
return TuringState(new_inner_state, params, logdensity)
565566
end
566567

567568
function setparams_varinfo!!(

test/mcmc/external_sampler.jl

Lines changed: 21 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -141,6 +141,17 @@ end
141141
)
142142
end
143143
end
144+
145+
@testset "logp is set correctly" begin
146+
@model logp_check() = x ~ Normal()
147+
model = logp_check()
148+
sampler = initialize_nuts(model)
149+
sampler_ext = externalsampler(
150+
sampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained=true
151+
)
152+
chn = sample(logp_check(), Gibbs(@varname(x) => sampler_ext), 100)
153+
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
154+
end
144155
end
145156

146157
@testset "AdvancedMH.jl" begin
@@ -167,7 +178,17 @@ end
167178
)
168179
end
169180
end
181+
182+
@testset "logp is set correctly" begin
183+
@model logp_check() = x ~ Normal()
184+
model = logp_check()
185+
sampler = initialize_mh_rw(model)
186+
sampler_ext = externalsampler(sampler; unconstrained=true)
187+
chn = sample(logp_check(), Gibbs(@varname(x) => sampler_ext), 100)
188+
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
189+
end
170190
end
191+
171192
# NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls
172193
# it with `NamedTuple` instead of `AbstractVector`.
173194
# @testset "MH with prior proposal" begin

test/mcmc/gibbs.jl

Lines changed: 9 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -825,6 +825,14 @@ end
825825
end
826826

827827
@testset "externalsampler" begin
828+
function check_logp_correct(sampler)
829+
@testset "logp is set correctly" begin
830+
@model logp_check() = x ~ Normal()
831+
chn = sample(logp_check(), Gibbs(@varname(x) => sampler), 100)
832+
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
833+
end
834+
end
835+
828836
@model function demo_gibbs_external()
829837
m1 ~ Normal()
830838
m2 ~ Normal()
@@ -851,6 +859,7 @@ end
851859
model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
852860
)
853861
check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
862+
check_logp_correct(sampler_inner)
854863
end
855864
end
856865

0 commit comments

Comments (0)