@@ -6,6 +6,16 @@ using MarginalLogDensities: MarginalLogDensities
6
6
_to_varname (n:: Symbol ) = VarName {n} ()
7
7
_to_varname (n:: VarName ) = n
8
8
9
+ # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
10
+ # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
11
+ # below.
12
+ struct LogDensityFunctionWrapper{L<: DynamicPPL.LogDensityFunction }
13
+ logdensity:: L
14
+ end
15
+ function (lw:: LogDensityFunctionWrapper )(x, _)
16
+ return LogDensityProblems. logdensity (lw. logdensity, x)
17
+ end
18
+
9
19
"""
10
20
marginalize(
11
21
model::DynamicPPL.Model,
@@ -26,7 +36,7 @@ log-density.
26
36
## Keyword arguments
27
37
28
38
- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`,
29
- meaning that the resulting log-density function accepts parameters that have bee_FWDn
39
+ meaning that the resulting log-density function accepts parameters that have been
30
40
transformed to unconstrained space.
31
41
32
42
- `getlogprob`: A function which specifies which kind of marginal log-density to compute.
@@ -60,6 +70,26 @@ julia> # The resulting callable computes the marginal log-density of `y`.
60
70
julia> logpdf(Normal(2.0), 1.0)
61
71
-1.4189385332046727
62
72
```
73
+
74
+
75
+ !!! warning
76
+
77
+ The default usage of linked VarInfo means that, for example, optimization of the
78
+ marginal log-density can be performed in unconstrained space. However, care must be
79
+ taken if the model contains variables where the link transformation depends on a
80
+ marginalized variable. For example:
81
+
82
+ ```julia
83
+ @model function f()
84
+ x ~ Normal()
85
+ y ~ truncated(Normal(); lower=x)
86
+ end
87
+ ```
88
+
89
+ Here, the support of `y`, and hence the link transformation used, depends on the value
90
+ of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of
91
+ `y` to log-probabilities. However, it will not be possible to use DynamicPPL to
92
+ correctly retrieve _unlinked_ values of `y`.
63
93
"""
64
94
function DynamicPPL. marginalize (
65
95
model:: DynamicPPL.Model ,
@@ -74,15 +104,104 @@ function DynamicPPL.marginalize(
74
104
varindices = reduce (vcat, DynamicPPL. vector_getranges (varinfo, vns))
75
105
# Construct the marginal log-density model.
76
106
f = DynamicPPL. LogDensityFunction (model, getlogprob, varinfo)
77
- mdl = MarginalLogDensities. MarginalLogDensity (
78
- (x, _) -> LogDensityProblems. logdensity (f, x),
79
- varinfo[:],
80
- varindices,
81
- (),
82
- method;
83
- kwargs... ,
107
+ mld = MarginalLogDensities. MarginalLogDensity (
108
+ LogDensityFunctionWrapper (f), varinfo[:], varindices, (), method; kwargs...
109
+ )
110
+ return mld
111
+ end
112
+
113
+ """
114
+ VarInfo(
115
+ mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
116
+ unmarginalized_params::Union{AbstractVector,Nothing}=nothing
84
117
)
85
- return mdl
118
+
119
+ Retrieve the `VarInfo` object used in the marginalisation process.
120
+
121
+ If a Laplace approximation was used for the marginalisation, the values of the marginalized
122
+ parameters are also set to their mode (note that this only happens if the `mld` object has
123
+ been used to compute the marginal log-density at least once, so that the mode has been
124
+ computed).
125
+
126
+ If a vector of `unmarginalized_params` is specified, the values for the corresponding
127
+ parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
128
+ performing an optimization of the marginal log-density.
129
+
130
+ All other aspects of the VarInfo, such as link status, are preserved from the original
131
+ VarInfo used in the marginalisation.
132
+
133
+ !!! note
134
+
135
+ The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
136
+ updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the
137
+ model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model,
138
+ vi))`).
139
+
140
+ ## Example
141
+
142
+ ```jldoctest
143
+ julia> using DynamicPPL, Distributions, MarginalLogDensities
144
+
145
+ julia> @model function demo()
146
+ x ~ Normal()
147
+ y ~ Beta(2, 2)
148
+ end
149
+ demo (generic function with 2 methods)
150
+
151
+ julia> # Note that by default `marginalize` uses a linked VarInfo.
152
+ mld = marginalize(demo(), [@varname(x)]);
153
+
154
+ julia> using MarginalLogDensities: Optimization, OptimizationOptimJL
155
+
156
+ julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
157
+ y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
158
+ OptimizationProblem. In-place: true
159
+ u0: 1-element Vector{Float64}:
160
+ 2.0
161
+
162
+ julia> # This tells us the optimal (linked) value of `y` is around 0.
163
+ opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
164
+ retcode: Success
165
+ u: 1-element Vector{Float64}:
166
+ 4.88281250001733e-5
167
+
168
+ julia> # Get the VarInfo corresponding to the mode of `y`.
169
+ vi = VarInfo(mld, opt_solution.u);
170
+
171
+ julia> # `x` is set to its mode (which for `Normal()` is zero).
172
+ vi[@varname(x)]
173
+ 0.0
174
+
175
+ julia> # `y` is set to the optimal value we found above.
176
+ DynamicPPL.getindex_internal(vi, @varname(y))
177
+ 1-element Vector{Float64}:
178
+ 4.88281250001733e-5
179
+
180
+ julia> # To obtain values in the original constrained space, we can either
181
+ # use `getindex`:
182
+ vi[@varname(y)]
183
+ 0.5000122070312476
184
+
185
+ julia> # Or invlink the entire VarInfo object using the model:
186
+ vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
187
+ 2-element Vector{Float64}:
188
+ 0.0
189
+ 0.5000122070312476
190
+ ```
191
+ """
192
+ function DynamicPPL. VarInfo (
193
+ mld:: MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper} ,
194
+ unmarginalized_params:: Union{AbstractVector,Nothing} = nothing ,
195
+ )
196
+ # Extract the original VarInfo. Its contents will in general be junk.
197
+ original_vi = mld. logdensity. logdensity. varinfo
198
+ # `mld.u` will contain the modes for any marginalized parameters
199
+ full_params = mld. u
200
+ # We can then set the values for any non-marginalized parameters
201
+ if unmarginalized_params != = nothing
202
+ full_params[MarginalLogDensities. ijoint (mld)] = unmarginalized_params
203
+ end
204
+ return DynamicPPL. unflatten (original_vi, full_params)
86
205
end
87
206
88
207
end
0 commit comments