Skip to content

Commit f4049c1

Browse files
committed
Add VarInfo(::MarginalLogDensity) method
1 parent 844ec4c commit f4049c1

File tree

4 files changed

+170
-9
lines changed

4 files changed

+170
-9
lines changed

docs/make.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ using DocumenterMermaid
1313
using MCMCChains
1414
using MarginalLogDensities: MarginalLogDensities
1515

16+
# Need this to document a method which uses a type inside the extension...
17+
DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)
18+
1619
# Doctest setup
1720
DocMeta.setdocmeta!(
1821
DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,13 @@ This requires `MarginalLogDensities.jl` to be loaded in your environment.
145145
marginalize
146146
```
147147

148+
A `MarginalLogDensity` object acts as a function which maps non-marginalized parameter values to a marginal log-probability.
149+
To retrieve a VarInfo object from it, you can use:
150+
151+
```@docs
152+
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
153+
```
154+
148155
## Models within models
149156

150157
One can include models and call another model inside the model function with `left ~ to_submodel(model)`.

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@ using MarginalLogDensities: MarginalLogDensities
66
_to_varname(n::Symbol) = VarName{n}()
77
_to_varname(n::VarName) = n
88

9+
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
10+
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
11+
# below.
12+
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
13+
logdensity::L
14+
end
15+
function (lw::LogDensityFunctionWrapper)(x, _)
16+
return LogDensityProblems.logdensity(lw.logdensity, x)
17+
end
18+
919
"""
1020
marginalize(
1121
model::DynamicPPL.Model,
@@ -26,7 +36,7 @@ log-density.
2636
## Keyword arguments
2737
2838
- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`,
29-
meaning that the resulting log-density function accepts parameters that have bee_FWDn
39+
meaning that the resulting log-density function accepts parameters that have been
3040
transformed to unconstrained space.
3141
3242
- `getlogprob`: A function which specifies which kind of marginal log-density to compute.
@@ -60,6 +70,26 @@ julia> # The resulting callable computes the marginal log-density of `y`.
6070
julia> logpdf(Normal(2.0), 1.0)
6171
-1.4189385332046727
6272
```
73+
74+
75+
!!! warning
76+
77+
The default usage of linked VarInfo means that, for example, optimization of the
78+
marginal log-density can be performed in unconstrained space. However, care must be
79+
taken if the model contains variables where the link transformation depends on a
80+
marginalized variable. For example:
81+
82+
```julia
83+
@model function f()
84+
x ~ Normal()
85+
y ~ truncated(Normal(); lower=x)
86+
end
87+
```
88+
89+
Here, the support of `y`, and hence the link transformation used, depends on the value
90+
of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of
91+
`y` to log-probabilities. However, it will not be possible to use DynamicPPL to
92+
correctly retrieve _unlinked_ values of `y`.
6393
"""
6494
function DynamicPPL.marginalize(
6595
model::DynamicPPL.Model,
@@ -74,15 +104,104 @@ function DynamicPPL.marginalize(
74104
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
75105
# Construct the marginal log-density model.
76106
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
77-
mdl = MarginalLogDensities.MarginalLogDensity(
78-
(x, _) -> LogDensityProblems.logdensity(f, x),
79-
varinfo[:],
80-
varindices,
81-
(),
82-
method;
83-
kwargs...,
107+
mld = MarginalLogDensities.MarginalLogDensity(
108+
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
109+
)
110+
return mld
111+
end
112+
113+
"""
114+
VarInfo(
115+
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
116+
unmarginalized_params::Union{AbstractVector,Nothing}=nothing
84117
)
85-
return mdl
118+
119+
Retrieve the `VarInfo` object used in the marginalisation process.
120+
121+
If a Laplace approximation was used for the marginalisation, the values of the marginalized
122+
parameters are also set to their mode (note that this only happens if the `mld` object has
123+
been used to compute the marginal log-density at least once, so that the mode has been
124+
computed).
125+
126+
If a vector of `unmarginalized_params` is specified, the values for the corresponding
127+
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
128+
performing an optimization of the marginal log-density.
129+
130+
All other aspects of the VarInfo, such as link status, are preserved from the original
131+
VarInfo used in the marginalisation.
132+
133+
!!! note
134+
135+
The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
136+
updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the
137+
model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model,
138+
vi))`).
139+
140+
## Example
141+
142+
```jldoctest
143+
julia> using DynamicPPL, Distributions, MarginalLogDensities
144+
145+
julia> @model function demo()
146+
x ~ Normal()
147+
y ~ Beta(2, 2)
148+
end
149+
demo (generic function with 2 methods)
150+
151+
julia> # Note that by default `marginalize` uses a linked VarInfo.
152+
mld = marginalize(demo(), [@varname(x)]);
153+
154+
julia> using MarginalLogDensities: Optimization, OptimizationOptimJL
155+
156+
julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
157+
y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
158+
OptimizationProblem. In-place: true
159+
u0: 1-element Vector{Float64}:
160+
2.0
161+
162+
julia> # This tells us the optimal (linked) value of `y` is around 0.
163+
opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
164+
retcode: Success
165+
u: 1-element Vector{Float64}:
166+
4.88281250001733e-5
167+
168+
julia> # Get the VarInfo corresponding to the mode of `y`.
169+
vi = VarInfo(mld, opt_solution.u);
170+
171+
julia> # `x` is set to its mode (which for `Normal()` is zero).
172+
vi[@varname(x)]
173+
0.0
174+
175+
julia> # `y` is set to the optimal value we found above.
176+
DynamicPPL.getindex_internal(vi, @varname(y))
177+
1-element Vector{Float64}:
178+
4.88281250001733e-5
179+
180+
julia> # To obtain values in the original constrained space, we can either
181+
# use `getindex`:
182+
vi[@varname(y)]
183+
0.5000122070312476
184+
185+
julia> # Or invlink the entire VarInfo object using the model:
186+
vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
187+
2-element Vector{Float64}:
188+
0.0
189+
0.5000122070312476
190+
```
191+
"""
192+
function DynamicPPL.VarInfo(
193+
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
194+
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
195+
)
196+
# Extract the original VarInfo. Its contents will in general be junk.
197+
original_vi = mld.logdensity.logdensity.varinfo
198+
# `mld.u` will contain the modes for any marginalized parameters
199+
full_params = mld.u
200+
# We can then set the values for any non-marginalized parameters
201+
if unmarginalized_params !== nothing
202+
full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
203+
end
204+
return DynamicPPL.unflatten(original_vi, full_params)
86205
end
87206

88207
end

test/ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,38 @@ using ADTypes: AutoForwardDiff
6969
end
7070
end
7171
end
72+
73+
@testset "retrieving VarInfo from MLD" begin
74+
@model function f()
75+
x ~ Normal()
76+
return y ~ Beta(2, 2)
77+
end
78+
model = f()
79+
vi_unlinked = VarInfo(model)
80+
vi_linked = DynamicPPL.link(vi_unlinked, model)
81+
82+
@testset "unlinked VarInfo" begin
83+
mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked)
84+
mx([0.5]) # evaluate at some point to force calculation of Laplace approx
85+
vi = VarInfo(mx)
86+
@test vi[@varname(x)] mode(Normal())
87+
vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked
88+
@test vi[@varname(x)] mode(Normal())
89+
@test vi[@varname(y)] 0.5
90+
end
91+
92+
@testset "linked VarInfo" begin
93+
mx = marginalize(model, [@varname(x)]; varinfo=vi_linked)
94+
mx([0.5]) # evaluate at some point to force calculation of Laplace approx
95+
vi = VarInfo(mx)
96+
@test vi[@varname(x)] mode(Normal())
97+
vi = VarInfo(mx, [0.5]) # this 0.5 is linked
98+
binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
99+
@test vi[@varname(x)] mode(Normal())
100+
# when using getindex it always returns unlinked values
101+
@test vi[@varname(y)] binv(0.5)
102+
end
103+
end
72104
end
73105

74106
end

0 commit comments

Comments
 (0)