Skip to content

Commit f135821

Browse files
alystmhauru
andauthored
reenable PriorContext for Optimization (#2165)
Co-authored-by: Markus Hauru <[email protected]>
1 parent 03d0e78 commit f135821

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

src/optimisation/Optimisation.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.Abstract
4242
context::C
4343

4444
function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext}
45-
if !(context isa Union{DynamicPPL.DefaultContext,DynamicPPL.LikelihoodContext})
45+
if !(context isa Union{DynamicPPL.DefaultContext,DynamicPPL.LikelihoodContext,DynamicPPL.PriorContext})
4646
msg = """
47-
`OptimizationContext` supports only leaf contexts of type
48-
`DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext`
49-
(given: `$(typeof(context)))`
47+
`OptimizationContext` supports only leaf contexts of type
48+
`DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`,
49+
and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`
5050
"""
5151
throw(ArgumentError(msg))
5252
end
@@ -60,7 +60,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
6060

6161
function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
6262
r = vi[vn, dist]
63-
lp = if ctx.context isa DynamicPPL.DefaultContext
63+
lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext}
6464
# MAP
6565
Distributions.logpdf(dist, r)
6666
else
@@ -83,7 +83,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
8383
r = DynamicPPL.get_and_set_val!(
8484
Random.default_rng(), vi, vns, right, DynamicPPL.SampleFromPrior()
8585
)
86-
lp = if ctx.context isa DynamicPPL.DefaultContext
86+
lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext}
8787
# MAP
8888
_loglikelihood(right, r)
8989
else
@@ -93,6 +93,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
9393
return r, lp, vi
9494
end
9595

96+
DynamicPPL.tilde_observe(ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args...) =
97+
DynamicPPL.tilde_observe(ctx.context, args...)
98+
99+
DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args...) =
100+
DynamicPPL.dot_tilde_observe(ctx.context, args...)
101+
96102
"""
97103
OptimLogDensity{M<:DynamicPPL.Model,C<:Context,V<:DynamicPPL.VarInfo}
98104

test/optimisation/Optimisation.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,22 @@ using Turing
9595
@test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) ==
9696
Turing.Optimisation.OptimLogDensity(m2, ctx)(w)
9797
end
98+
99+
@testset "Default, Likelihood, Prior Contexts" begin
100+
m1 = model1(x)
101+
defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
102+
llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
103+
prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext())
104+
a = [0.3]
105+
106+
@test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) ==
107+
Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) +
108+
Turing.Optimisation.OptimLogDensity(m1, prictx)(a)
109+
110+
# test that PriorContext is calculating the right thing
111+
@test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) -Distributions.logpdf(Uniform(0, 2), 0.3)
112+
@test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) -Distributions.logpdf(Uniform(0, 2), -0.3)
113+
end
98114
end
99115

100116
@testset "gdemo" begin

0 commit comments

Comments
 (0)