Skip to content

Commit 8fab001

Browse files
committed
Add documentation
1 parent 4b820f5 commit 8fab001

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1111
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1212
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
13+
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
1314
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1415

1516
[compat]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using Distributions
1111
using DocumenterMermaid
1212
# load MCMCChains package extension to make `predict` available
1313
using MCMCChains
14+
using MarginalLogDensities: MarginalLogDensities
1415

1516
# Doctest setup
1617
DocMeta.setdocmeta!(

docs/src/api.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ The `predict` function has two main methods:
123123
predict
124124
```
125125

126+
## Marginalization
127+
128+
DynamicPPL provides the `marginalize` function to marginalize out variables from a model.
129+
This requires `MarginalLogDensities.jl` to be loaded in your environment.
130+
131+
```@docs
132+
marginalize
133+
```
134+
126135
### Basic Usage
127136

128137
The typical workflow for posterior prediction involves:

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,51 @@ using MarginalLogDensities: MarginalLogDensities
66
_to_varname(n::Symbol) = VarName{n}()
77
_to_varname(n::VarName) = n
88

9+
"""
10+
marginalize(
11+
model::DynamicPPL.Model,
12+
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
13+
getlogprob=DynamicPPL.getlogjoint,
14+
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
15+
kwargs...,
16+
)
17+
18+
Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal
19+
log-density of the given `model`, after marginalizing out the variables specified in
20+
`varnames`.
21+
22+
The resulting object can be called with a vector of parameter values to compute the marginal
23+
log-density.
24+
25+
The `getlogprob` argument can be used to specify which kind of marginal log-density to
26+
compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
27+
probability.
28+
29+
By default the marginalization is performed with a Laplace approximation. Please see [the
30+
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
31+
for other options.
32+
33+
## Example
34+
35+
```jldoctest
36+
julia> using DynamicPPL, Distributions, MarginalLogDensities
37+
38+
julia> @model function demo()
39+
x ~ Normal(1.0)
40+
y ~ Normal(2.0)
41+
end
42+
demo (generic function with 2 methods)
43+
44+
julia> marginalized = marginalize(demo(), [:x]);
45+
46+
julia> # The resulting callable computes the marginal log-density of `y`.
47+
marginalized([1.0])
48+
-1.4189385332046727
49+
50+
julia> logpdf(Normal(2.0), 1.0)
51+
-1.4189385332046727
52+
```
53+
"""
954
function DynamicPPL.marginalize(
1055
model::DynamicPPL.Model,
1156
varnames::AbstractVector{<:Union{Symbol,<:VarName}},

0 commit comments

Comments
 (0)