Skip to content

Commit e9eabb4

Browse files
committed
Make the non-essential stuff all keyword arguments
1 parent 00c08b2 commit e9eabb4

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ _to_varname(n::VarName) = n
99
"""
1010
marginalize(
1111
model::DynamicPPL.Model,
12-
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
12+
varnames::AbstractVector{<:Union{Symbol,<:VarName}};
1313
varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model),
1414
getlogprob=DynamicPPL.getlogjoint,
1515
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
@@ -23,17 +23,22 @@ log-density of the given `model`, after marginalizing out the variables specifie
2323
The resulting object can be called with a vector of parameter values to compute the marginal
2424
log-density.
2525
26-
You can specify the `varinfo` to use for the model. By default we use a linked `VarInfo`,
27-
meaning that the resulting log-density function accepts parameters that have been
28-
transformed to unconstrained space.
26+
## Keyword arguments
2927
30-
The `getlogprob` argument can be used to specify which kind of marginal log-density to
31-
compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
32-
probability.
28+
- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`,
29+
meaning that the resulting log-density function accepts parameters that have bee_FWDn
30+
transformed to unconstrained space.
3331
34-
By default the marginalization is performed with a Laplace approximation. Please see [the
35-
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
36-
for other options.
32+
- `getlogprob`: A function which specifies which kind of marginal log-density to compute.
33+
Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
34+
probability.
35+
36+
- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the
37+
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
38+
for other options.
39+
40+
- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity`
41+
constructor.
3742
3843
## Example
3944
@@ -58,10 +63,10 @@ julia> logpdf(Normal(2.0), 1.0)
5863
"""
5964
function DynamicPPL.marginalize(
6065
model::DynamicPPL.Model,
61-
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
66+
varnames::AbstractVector{<:Union{Symbol,<:VarName}};
6267
varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model),
6368
getlogprob::Function=DynamicPPL.getlogjoint,
64-
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
69+
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
6570
kwargs...,
6671
)
6772
# Determine the indices for the variables to marginalise out.

test/ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ using ADTypes: AutoForwardDiff
1717
for vn in [@varname(x), :x]
1818
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint]
1919
marginalized = marginalize(
20-
model, [vn], vi, getlogprob; hess_adtype=AutoForwardDiff()
20+
model,
21+
[vn];
22+
varinfo=vi,
23+
getlogprob=getlogprob,
24+
hess_adtype=AutoForwardDiff(),
2125
)
2226
for y in range(-5, 5; length=100)
2327
@test marginalized([y]) logpdf(Normal(0, 1), y) atol = 1e-5
@@ -36,27 +40,29 @@ using ADTypes: AutoForwardDiff
3640
vi_linked = DynamicPPL.link(vi_unlinked, model)
3741

3842
@testset "unlinked VarInfo" begin
39-
mx = marginalize(model, [@varname(x)], vi_unlinked)
43+
mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked)
4044
for x in range(0.01, 0.99; length=10)
4145
@test mx([x]) logpdf(Beta(2, 2), x)
4246
end
4347
# generally when marginalising Beta it doesn't go to zero
44-
my = marginalize(model, [@varname(y)], vi_unlinked)
48+
# https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067
49+
my = marginalize(model, [@varname(y)]; varinfo=vi_unlinked)
4550
diff = my([0.0]) - logpdf(Normal(), 0.0)
4651
for x in range(-5, 5; length=10)
4752
@test my([x]) logpdf(Normal(), x) + diff
4853
end
4954
end
5055

5156
@testset "linked VarInfo" begin
52-
mx = marginalize(model, [@varname(x)], vi_linked)
57+
mx = marginalize(model, [@varname(x)]; varinfo=vi_linked)
5358
binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
5459
for y_linked in range(-5, 5; length=10)
5560
y_unlinked = binv(y_linked)
5661
@test mx([y_linked]) logpdf(Beta(2, 2), y_unlinked)
5762
end
5863
# generally when marginalising Beta it doesn't go to zero
59-
my = marginalize(model, [@varname(y)], vi_linked)
64+
# https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067
65+
my = marginalize(model, [@varname(y)]; varinfo=vi_linked)
6066
diff = my([0.0]) - logpdf(Normal(), 0.0)
6167
for x in range(-5, 5; length=10)
6268
@test my([x]) logpdf(Normal(), x) + diff

0 commit comments

Comments
 (0)