@@ -3,9 +3,6 @@ module DynamicPPLMarginalLogDensitiesExt
3
3
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
4
4
using MarginalLogDensities: MarginalLogDensities
5
5
6
- _to_varname (n:: Symbol ) = VarName {n} ()
7
- _to_varname (n:: VarName ) = n
8
-
9
6
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
10
7
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
11
8
# below.
19
16
"""
20
17
marginalize(
21
18
model::DynamicPPL.Model,
22
- varnames ::AbstractVector{<:Union{Symbol,<: VarName} };
19
+ marginalized_varnames ::AbstractVector{<:VarName};
23
20
varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model),
24
21
getlogprob=DynamicPPL.getlogjoint,
25
22
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
@@ -93,15 +90,14 @@ julia> logpdf(Normal(2.0), 1.0)
93
90
"""
94
91
function DynamicPPL. marginalize (
95
92
model:: DynamicPPL.Model ,
96
- varnames :: AbstractVector{<:Union{Symbol,<: VarName} } ;
93
+ marginalized_varnames :: AbstractVector{<:VarName} ;
97
94
varinfo:: DynamicPPL.AbstractVarInfo = DynamicPPL. link (DynamicPPL. VarInfo (model), model),
98
95
getlogprob:: Function = DynamicPPL. getlogjoint,
99
96
method:: MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities. LaplaceApprox (),
100
97
kwargs... ,
101
98
)
102
99
# Determine the indices for the variables to marginalise out.
103
- vns = map (_to_varname, varnames)
104
- varindices = reduce (vcat, DynamicPPL. vector_getranges (varinfo, vns))
100
+ varindices = reduce (vcat, DynamicPPL. vector_getranges (varinfo, marginalized_varnames))
105
101
# Construct the marginal log-density model.
106
102
f = DynamicPPL. LogDensityFunction (model, getlogprob, varinfo)
107
103
mld = MarginalLogDensities. MarginalLogDensity (
0 commit comments