
Commit 3e8d97f

devmotion and torfjelde authored
Fix MLE with ConditionContext (#2022)
* Fix MLE with ConditionContext
* Update ModeEstimation.jl
* Update tests
* Improve tests

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent a0b8999 commit 3e8d97f
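
Below is a minimal, illustrative sketch (mirroring the new tests, not part of the commit itself) of the user-facing behaviour this fixes: maximum-likelihood estimation should give the same answer whether the observation is passed as a model argument or attached afterwards with `|`, which wraps the model in a `ConditionContext`.

```julia
using Turing, Optim

@model function model1(x)
    μ ~ Uniform(0, 2)
    x ~ LogNormal(μ, 1)
end

@model function model2()
    μ ~ Uniform(0, 2)
    x ~ LogNormal(μ, 1)
end

# Both ways of supplying the observation should now yield the same MLE of μ.
optimize(model1(1.0), MLE())
optimize(model2() | (x = 1.0,), MLE())
```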

3 files changed: +147 −50 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.26.1"
+version = "0.26.2"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
```

src/modes/ModeEstimation.jl

Lines changed: 26 additions & 19 deletions

```diff
@@ -39,41 +39,48 @@ intended to allow an optimizer to sample in R^n freely.
 """
 struct OptimizationContext{C<:AbstractContext} <: AbstractContext
     context::C
+
+    function OptimizationContext{C}(context::C) where {C<:AbstractContext}
+        if !(context isa Union{DefaultContext,LikelihoodContext})
+            throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` (given: `$(typeof(context)))`"))
+        end
+        return new{C}(context)
+    end
 end
 
-DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsParent()
-DynamicPPL.childcontext(context::OptimizationContext) = context.context
-DynamicPPL.setchildcontext(::OptimizationContext, child) = OptimizationContext(child)
+OptimizationContext(context::AbstractContext) = OptimizationContext{typeof(context)}(context)
 
-# assume
-function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, dist, vn, vi)
-    r = vi[vn, dist]
-    return r, 0, vi
-end
+DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
 
+# assume
 function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
     r = vi[vn, dist]
-    return r, Distributions.logpdf(dist, r), vi
+    lp = if ctx.context isa DefaultContext
+        # MAP
+        Distributions.logpdf(dist, r)
+    else
+        # MLE
+        0
+    end
+    return r, lp, vi
 end
 
 # dot assume
-function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, right, left, vns, vi)
-    # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
-    # affect anything.
-    # TODO: Stop using `get_and_set_val!`.
-    r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
-    return r, 0, vi
-end
-
 _loglikelihood(dist::Distribution, x) = loglikelihood(dist, x)
 _loglikelihood(dists::AbstractArray{<:Distribution}, x) = loglikelihood(arraydist(dists), x)
-
 function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, vi)
     # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
     # affect anything.
     # TODO: Stop using `get_and_set_val!`.
     r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
-    return r, _loglikelihood(right, r), vi
+    lp = if ctx.context isa DefaultContext
+        # MAP
+        _loglikelihood(right, r)
+    else
+        # MLE
+        0
+    end
+    return r, lp, vi
 end
 
 """
```

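In short (as a reading aid, not part of the diff): `OptimizationContext` is now a leaf context that may only wrap `DefaultContext` or `LikelihoodContext`, and the wrapped context decides whether `~` assumptions contribute their prior density to the objective.

```julia
using Turing, DynamicPPL

# MAP objective: each `~` assumption contributes `logpdf(dist, r)`.
map_ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())

# MLE objective: assumptions contribute 0; only observations count.
mle_ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())

# Any other leaf context is rejected by the new inner constructor:
# Turing.OptimizationContext(DynamicPPL.PriorContext())  # throws ArgumentError
```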
test/modes/OptimInterface.jl

Lines changed: 120 additions & 30 deletions

```diff
@@ -1,29 +1,60 @@
-function find_map(model::DynamicPPL.TestUtils.DemoModels)
-    # Set up.
-    true_values = rand(NamedTuple, model)
-    d = length(true_values.s)
-    s_size, m_size = size(true_values.s), size(true_values.m)
-    s_isunivariate = true_values.s isa Real
-    m_isunivariate = true_values.m isa Real
-
-    # Cosntruct callable.
-    function f_wrapped(x)
-        s = s_isunivariate ? x[1] : reshape(x[1:d], s_size)
-        m = m_isunivariate ? x[2] : reshape(x[d + 1:end], m_size)
-        return -DynamicPPL.TestUtils.logjoint_true(model, s, m)
-    end
+# TODO: Remove these once the equivalent is present in `DynamicPPL.TestUtils.
+function likelihood_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
+    return (s=1/16, m=7/4)
+end
+function posterior_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
+    # TODO: Figure out exact for `s`.
+    return (s=0.907407, m=7/6)
+end
+
+function likelihood_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
+    # Get some containers to fill.
+    vals = Random.rand(model)
+
+    # NOTE: These are "as close to zero as we can get".
+    vals.s[1] = 1e-32
+    vals.s[2] = 1e-32
+
+    vals.m[1] = 1.5
+    vals.m[2] = 2.0
+
+    return vals
+end
+function posterior_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
+    # Get some containers to fill.
+    vals = Random.rand(model)
+
+    # TODO: Figure out exact for `s[1]`.
+    vals.s[1] = 0.890625
+    vals.s[2] = 1
+    vals.m[1] = 3/4
+    vals.m[2] = 1
 
-    # Optimize.
-    lbs = vcat(fill(0, d), fill(-Inf, d))
-    ubs = fill(Inf, 2d)
-    result = optimize(f_wrapped, lbs, ubs, rand(2d), Fminbox(NelderMead()))
-    @assert Optim.converged(result) "optimization didn't converge"
-
-    # Extract the result.
-    x = Optim.minimizer(result)
-    s = s_isunivariate ? x[1] : reshape(x[1:d], s_size)
-    m = m_isunivariate ? x[2] : reshape(x[d + 1:end], m_size)
-    return -Optim.minimum(result), (s = s, m = m)
+    return vals
+end
+
+# Used for testing how well it works with nested contexts.
+struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext
+    context::C
+    logprior_weight::T1
+    loglikelihood_weight::T2
+end
+DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent()
+DynamicPPL.childcontext(parent::OverrideContext) = parent.context
+DynamicPPL.setchildcontext(parent::OverrideContext, child) = OverrideContext(
+    child,
+    parent.logprior_weight,
+    parent.loglikelihood_weight
+)
+
+# Only implement what we need for the models above.
+function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi)
+    value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
+    return value, context.logprior_weight, vi
+end
+function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi)
+    logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
+    return context.loglikelihood_weight, vi
 end
 
 @testset "OptimInterface.jl" begin
@@ -126,20 +157,79 @@ end
     # FIXME: Some models doesn't work for Tracker and ReverseDiff.
     if Turing.Essential.ADBACKEND[] === :forwarddiff
         @testset "MAP for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
-            maximum_true, maximizer_true = find_map(model)
+            result_true = posterior_optima(model)
 
             @testset "$(optimizer)" for optimizer in [LBFGS(), NelderMead()]
                 result = optimize(model, MAP(), optimizer)
                 vals = result.values
 
                 for vn in DynamicPPL.TestUtils.varnames(model)
-                    for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(maximizer_true, vn))
-                        sym = DynamicPPL.AbstractPPL.getsym(vn_leaf)
-                        true_value_vn = get(maximizer_true, vn_leaf)
-                        @test vals[Symbol(vn_leaf)] ≈ true_value_vn rtol = 0.05
+                    for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
+                        @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol=0.05
+                    end
+                end
+            end
+        end
+        @testset "MLE for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
+            result_true = likelihood_optima(model)
+
+            # `NelderMead` seems to struggle with convergence here, so we exclude it.
+            @testset "$(optimizer)" for optimizer in [LBFGS(),]
+                result = optimize(model, MLE(), optimizer)
+                vals = result.values
+
+                for vn in DynamicPPL.TestUtils.varnames(model)
+                    for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
+                        @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol=0.05
                     end
                 end
             end
+        end
+    end
+
+    # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320
+    @testset "OptimizationContext" begin
+        @model function model1(x)
+            μ ~ Uniform(0, 2)
+            x ~ LogNormal(μ, 1)
+        end
+
+        @model function model2()
+            μ ~ Uniform(0, 2)
+            x ~ LogNormal(μ, 1)
+        end
+
+        x = 1.0
+        w = [1.0]
+
+        @testset "With ConditionContext" begin
+            m1 = model1(x)
+            m2 = model2() | (x = x,)
+            ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
+            @test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
+        end
+
+        @testset "With prefixes" begin
+            function prefix_μ(model)
+                return DynamicPPL.contextualize(model, DynamicPPL.PrefixContext{:inner}(model.context))
+            end
+            m1 = prefix_μ(model1(x))
+            m2 = prefix_μ(model2() | (var"inner.x" = x,))
+            ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
+            @test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
+        end
+
+        @testset "Weighted" begin
+            function override(model)
+                return DynamicPPL.contextualize(
+                    model,
+                    OverrideContext(model.context, 100, 1)
+                )
+            end
+            m1 = override(model1(x))
+            m2 = override(model2() | (x = x,))
+            ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
+            @test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
         end
     end
 end
```
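
As a rough sanity check of the hand-picked univariate `likelihood_optima` above (assuming, as the multivariate optima suggest, that the `DynamicPPL.TestUtils` demo models observe the two data points 1.5 and 2.0 from `Normal(m, sqrt(s))`; this assumption is not stated in the diff):

```julia
xs = [1.5, 2.0]
m_mle = sum(xs) / length(xs)                 # 1.75 == 7/4
s_mle = sum(abs2, xs .- m_mle) / length(xs)  # 0.0625 == 1/16
```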
