Skip to content

Commit 14f9f46

Browse files
committed
Allow user to specify VarInfo used for marginalisation
1 parent aaca138 commit 14f9f46

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ _to_varname(n::VarName) = n
99
"""
1010
marginalize(
1111
model::DynamicPPL.Model,
12+
varinfo::DynamicPPL.AbstractVarInfo=VarInfo(model),
1213
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
1314
getlogprob=DynamicPPL.getlogjoint,
1415
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
@@ -54,22 +55,19 @@ julia> logpdf(Normal(2.0), 1.0)
5455
function DynamicPPL.marginalize(
5556
model::DynamicPPL.Model,
5657
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
57-
getlogprob=DynamicPPL.getlogjoint,
58+
varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model),
59+
getlogprob::Function=DynamicPPL.getlogjoint,
5860
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
5961
kwargs...,
6062
)
6163
# Determine the indices for the variables to marginalise out.
62-
varinfo = DynamicPPL.typed_varinfo(model)
6364
vns = map(_to_varname, varnames)
6465
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
6566
# Construct the marginal log-density model.
66-
# Use linked `varinfo` to that we're working in unconstrained space
67-
varinfo_linked = DynamicPPL.link(varinfo, model)
68-
69-
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked)
67+
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
7068
mdl = MarginalLogDensities.MarginalLogDensity(
7169
(x, _) -> LogDensityProblems.logdensity(f, x),
72-
varinfo_linked[:],
70+
varinfo[:],
7371
varindices,
7472
(),
7573
method;

test/ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,66 @@
11
module MarginalLogDensitiesExtTests
22

3+
using Bijectors: Bijectors
34
using DynamicPPL, Distributions, Test
45
using MarginalLogDensities
56
using ADTypes: AutoForwardDiff
67

78
@testset "MarginalLogDensities" begin
8-
# Simple test case.
9-
@model function demo()
10-
x ~ MvNormal(zeros(2), [1, 1])
11-
return y ~ Normal(0, 1)
9+
@testset "Basic usage" begin
10+
@model function demo()
11+
x ~ MvNormal(zeros(2), [1, 1])
12+
return y ~ Normal(0, 1)
13+
end
14+
model = demo()
15+
vi = VarInfo(model)
16+
# Marginalize out `x`.
17+
for vn in [@varname(x), :x]
18+
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint]
19+
marginalized = marginalize(
20+
model, [vn], vi, getlogprob; hess_adtype=AutoForwardDiff()
21+
)
22+
for y in range(-5, 5; length=100)
23+
@test marginalized([y]) logpdf(Normal(0, 1), y) atol = 1e-5
24+
end
25+
end
26+
end
1227
end
13-
model = demo()
14-
# Marginalize out `x`.
1528

16-
for vn in [@varname(x), :x]
17-
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint]
18-
marginalized = marginalize(
19-
model, [vn], getlogprob; hess_adtype=AutoForwardDiff()
20-
)
21-
# Compute the marginal log-density of `y = 0.0`.
22-
@test marginalized([0.0]) logpdf(Normal(0, 1), 0.0) atol = 1e-5
29+
@testset "Respects linked status of VarInfo" begin
30+
@model function f()
31+
x ~ Normal()
32+
return y ~ Beta(2, 2)
33+
end
34+
model = f()
35+
vi_unlinked = VarInfo(model)
36+
vi_linked = DynamicPPL.link(vi_unlinked, model)
37+
38+
@testset "unlinked VarInfo" begin
39+
mx = marginalize(model, [@varname(x)], vi_unlinked)
40+
for x in range(0.01, 0.99; length=10)
41+
@test mx([x]) logpdf(Beta(2, 2), x)
42+
end
43+
# generally when marginalising Beta it doesn't go to zero
44+
my = marginalize(model, [@varname(y)], vi_unlinked)
45+
diff = my([0.0]) - logpdf(Normal(), 0.0)
46+
for x in range(-5, 5; length=10)
47+
@test my([x]) logpdf(Normal(), x) + diff
48+
end
49+
end
50+
51+
@testset "linked VarInfo" begin
52+
mx = marginalize(model, [@varname(x)], vi_linked)
53+
binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
54+
for y_linked in range(-5, 5; length=10)
55+
y_unlinked = binv(y_linked)
56+
@test mx([y_linked]) logpdf(Beta(2, 2), y_unlinked)
57+
end
58+
# generally when marginalising Beta it doesn't go to zero
59+
my = marginalize(model, [@varname(y)], vi_linked)
60+
diff = my([0.0]) - logpdf(Normal(), 0.0)
61+
for x in range(-5, 5; length=10)
62+
@test my([x]) logpdf(Normal(), x) + diff
63+
end
2364
end
2465
end
2566
end

0 commit comments

Comments
 (0)