Skip to content

Commit 078731f

Browse files
author
Carlos Parada
authored
Merge branch 'master' into enhance_wrapped_distr
2 parents e9291c7 + 8c8cfc6 commit 078731f

24 files changed

+1497
-477
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.20.2"
3+
version = "0.21.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -30,6 +30,6 @@ Distributions = "0.23.8, 0.24, 0.25"
3030
DocStringExtensions = "0.8, 0.9"
3131
MacroTools = "0.5.6"
3232
OrderedCollections = "1"
33-
Setfield = "0.7.1, 0.8"
33+
Setfield = "0.7.1, 0.8, 1"
3434
ZygoteRules = "0.2"
3535
julia = "1.6"

docs/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
[deps]
22
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
45
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
56
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
67

78
[compat]
89
Distributions = "0.25"
910
Documenter = "0.27"
10-
Setfield = "0.7.1, 0.8"
11+
FillArrays = "0.13"
12+
Setfield = "0.7.1, 0.8, 1"
1113
StableRNGs = "1"

docs/make.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@ using DynamicPPL
33
using DynamicPPL: AbstractPPL
44

55
# Doctest setup
6-
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
6+
DocMeta.setdocmeta!(
7+
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
8+
)
79

810
makedocs(;
911
sitename="DynamicPPL",
1012
format=Documenter.HTML(),
1113
modules=[DynamicPPL],
12-
pages=["Home" => "index.md", "API" => "api.md"],
14+
pages=[
15+
"Home" => "index.md",
16+
"API" => "api.md",
17+
"Tutorials" => ["tutorials/prob-interface.md"],
18+
],
1319
strict=true,
1420
checkdocs=:exports,
1521
)

docs/src/api.md

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,23 +156,56 @@ AbstractVarInfo
156156

157157
### Common API
158158

159+
#### Accumulation of log-probabilities
160+
159161
```@docs
160162
getlogp
161163
setlogp!!
162164
acclogp!!
163165
resetlogp!!
164166
```
165167

168+
#### Variables and their realizations
169+
166170
```@docs
171+
keys
167172
getindex
173+
DynamicPPL.getindex_raw
168174
push!!
169175
empty!!
176+
isempty
170177
```
171178

172179
```@docs
173180
values_as
174181
```
175182

183+
#### Transformations
184+
185+
```@docs
186+
DynamicPPL.AbstractTransformation
187+
DynamicPPL.NoTransformation
188+
DynamicPPL.DynamicTransformation
189+
DynamicPPL.StaticTransformation
190+
```
191+
192+
```@docs
193+
DynamicPPL.istrans
194+
DynamicPPL.settrans!!
195+
DynamicPPL.transformation
196+
DynamicPPL.link!!
197+
DynamicPPL.invlink!!
198+
DynamicPPL.default_transformation
199+
DynamicPPL.maybe_invlink_before_eval!!
200+
```
201+
202+
#### Utils
203+
204+
```@docs
205+
DynamicPPL.unflatten
206+
DynamicPPL.tonamedtuple
207+
```
208+
176209
#### `SimpleVarInfo`
177210

178211
```@docs
@@ -191,10 +224,8 @@ TypedVarInfo
191224
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
192225

193226
```@docs
194-
tonamedtuple
195227
link!
196228
invlink!
197-
istrans
198229
```
199230

200231
```@docs
File renamed without changes.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# The Probability Interface
2+
3+
The easiest way to manipulate and query DynamicPPL models is via the DynamicPPL probability
4+
interface.
5+
6+
Let's use a simple model of normally-distributed data as an example.
7+
```@example probinterface
8+
using DynamicPPL
9+
using Distributions
10+
using FillArrays
11+
using LinearAlgebra
12+
using Random
13+
14+
Random.seed!(1776) # Set seed for reproducibility
15+
16+
@model function gdemo(n)
17+
μ ~ Normal(0, 1)
18+
x ~ MvNormal(Fill(μ, n), I)
19+
return nothing
20+
end
21+
nothing # hide
22+
```
23+
24+
We generate some data using `μ = 0` and `σ = 1`:
25+
26+
```@example probinterface
27+
dataset = randn(100)
28+
nothing # hide
29+
```
30+
31+
## Conditioning and Deconditioning
32+
33+
Bayesian models can be transformed with two main operations, conditioning and deconditioning (also known as marginalization).
34+
Conditioning takes a variable and fixes its value as known.
35+
We do this by passing a model and a named tuple of conditioned variables to `|`:
36+
```@example probinterface
37+
model = gdemo(length(dataset)) | (x=dataset, μ=0, σ=1)
38+
nothing # hide
39+
```
40+
41+
This operation can be reversed by applying `decondition`:
42+
```@example probinterface
43+
decondition(model)
44+
nothing # hide
45+
```
46+
47+
We can also decondition only some of the variables:
48+
```@example probinterface
49+
decondition(model, :μ)
50+
nothing # hide
51+
```
52+
53+
## Probabilities and Densities
54+
55+
We often want to calculate the (unnormalized) probability density for an event.
56+
This probability might be a prior, a likelihood, or a posterior (joint) density.
57+
DynamicPPL provides convenient functions for this.
58+
For example, if we wanted to calculate the probability of a draw from the prior:
59+
```@example probinterface
60+
model = gdemo(length(dataset)) | (x=dataset,)
61+
x1 = rand(model)
62+
logjoint(model, x1)
63+
```
64+
65+
For convenience, we provide the functions `loglikelihood` and `logjoint` to calculate probabilities for a named tuple, given a model:
66+
```@example probinterface
67+
@assert logjoint(model, x1) ≈ loglikelihood(model, x1) + logprior(model, x1)
68+
```
69+
70+
## Example: Cross-validation
71+
72+
To give an example of the probability interface in use, we can use it to estimate the performance of our model using cross-validation. In cross-validation, we split the dataset into several equal parts. Then, we choose one of these sets to serve as the validation set. Here, we measure fit using the cross entropy (Bayes loss).¹
73+
``` @example probinterface
74+
function cross_val(model, dataset)
75+
training_loss = zero(logjoint(model, rand(model)))
76+
77+
# Partition our dataset into 5 folds with 20 observations:
78+
test_folds = collect(Iterators.partition(dataset, 20))
79+
train_folds = setdiff.((dataset,), test_folds)
80+
81+
for (train, test) in zip(train_folds, test_folds)
82+
# First, we train the model on the training set.
83+
# For normally-distributed data, the posterior can be solved in closed form:
84+
posterior = Normal(mean(train), 1)
85+
# Sample from the posterior
86+
samples = NamedTuple{(:μ,)}.(rand(posterior, 1000))
87+
# Test
88+
testing_model = gdemo(length(test)) | (x = test,)
89+
training_loss += sum(samples) do sample
90+
logjoint(testing_model, sample)
91+
end
92+
end
93+
return training_loss
94+
end
95+
cross_val(model, dataset)
96+
```
97+
98+
¹See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.

src/DynamicPPL.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ export AbstractVarInfo,
5959
setorder!,
6060
istrans,
6161
link!,
62+
link!!,
6263
invlink!,
64+
invlink!!,
6365
tonamedtuple,
6466
values_as,
6567
# VarName (reexport from AbstractPPL)
@@ -126,27 +128,33 @@ export loglikelihood
126128
# Used here and overloaded in Turing
127129
function getspace end
128130

129-
# Necessary forward declarations
130131
"""
131132
AbstractVarInfo
132133
133134
Abstract supertype for data structures that capture random variables when executing a
134135
probabilistic model and accumulate log densities such as the log likelihood or the
135136
log joint probability of the model.
136137
137-
See also: [`VarInfo`](@ref)
138+
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
138139
"""
139140
abstract type AbstractVarInfo <: AbstractModelTrace end
140141

142+
const LEGACY_WARNING = """
143+
!!! warning
144+
This method is considered legacy, and is likely to be deprecated in the future.
145+
"""
146+
147+
# Necessary forward declarations
141148
include("utils.jl")
142149
include("selector.jl")
143150
include("model.jl")
144151
include("sampler.jl")
145152
include("varname.jl")
146153
include("distribution_wrappers.jl")
147154
include("contexts.jl")
148-
include("varinfo.jl")
155+
include("abstract_varinfo.jl")
149156
include("threadsafe.jl")
157+
include("varinfo.jl")
150158
include("simple_varinfo.jl")
151159
include("context_implementations.jl")
152160
include("compiler.jl")
@@ -155,5 +163,6 @@ include("compat/ad.jl")
155163
include("loglikelihoods.jl")
156164
include("submodel_macro.jl")
157165
include("test_utils.jl")
166+
include("transforming.jl")
158167

159168
end # module

0 commit comments

Comments
 (0)