@@ -42,11 +42,11 @@ struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.Abstract
42
42
context:: C
43
43
44
44
function OptimizationContext {C} (context:: C ) where {C<: DynamicPPL.AbstractContext }
45
- if ! (context isa Union{DynamicPPL. DefaultContext,DynamicPPL. LikelihoodContext})
45
+ if ! (context isa Union{DynamicPPL. DefaultContext,DynamicPPL. LikelihoodContext,DynamicPPL . PriorContext })
46
46
msg = """
47
- `OptimizationContext` supports only leaf contexts of type
48
- `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext`
49
- (given: `$(typeof (context)) )`
47
+ `OptimizationContext` supports only leaf contexts of type
48
+ `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`,
49
+ and `DynamicPPL.PriorContext` (given: `$(typeof (context)) )`
50
50
"""
51
51
throw (ArgumentError (msg))
52
52
end
@@ -60,7 +60,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
60
60
61
61
function DynamicPPL. tilde_assume (ctx:: OptimizationContext , dist, vn, vi)
62
62
r = vi[vn, dist]
63
- lp = if ctx. context isa DynamicPPL. DefaultContext
63
+ lp = if ctx. context isa Union{ DynamicPPL. DefaultContext,DynamicPPL . PriorContext}
64
64
# MAP
65
65
Distributions. logpdf (dist, r)
66
66
else
@@ -83,7 +83,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
83
83
r = DynamicPPL. get_and_set_val! (
84
84
Random. default_rng (), vi, vns, right, DynamicPPL. SampleFromPrior ()
85
85
)
86
- lp = if ctx. context isa DynamicPPL. DefaultContext
86
+ lp = if ctx. context isa Union{ DynamicPPL. DefaultContext,DynamicPPL . PriorContext}
87
87
# MAP
88
88
_loglikelihood (right, r)
89
89
else
@@ -93,6 +93,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
93
93
return r, lp, vi
94
94
end
95
95
96
+ DynamicPPL. tilde_observe (ctx:: OptimizationContext{<:DynamicPPL.PriorContext} , args... ) =
97
+ DynamicPPL. tilde_observe (ctx. context, args... )
98
+
99
+ DynamicPPL. dot_tilde_observe (ctx:: OptimizationContext{<:DynamicPPL.PriorContext} , args... ) =
100
+ DynamicPPL. dot_tilde_observe (ctx. context, args... )
101
+
96
102
"""
97
103
OptimLogDensity{M<:DynamicPPL.Model,C<:Context,V<:DynamicPPL.VarInfo}
98
104
0 commit comments