Commit 6f255d1

Merge branch 'master' of https://github.com/TuringLang/DynamicPPL.jl into tor/benchmark-update
2 parents 0291c2f + 29a6c7e commit 6f255d1

21 files changed: +712 −231 lines

Project.toml

Lines changed: 8 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.32.2"
+version = "0.34.2"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
+# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -46,7 +49,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
 [compat]
 ADTypes = "1"
 AbstractMCMC = "5"
-AbstractPPL = "0.8.4, 0.9"
+AbstractPPL = "0.10.1"
 Accessors = "0.1"
 BangBang = "0.4.1"
 Bijectors = "0.13.18, 0.14, 0.15"
@@ -55,6 +58,9 @@ Compat = "4"
 ConstructionBase = "1.5.4"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
+# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
+# for why KernelAbstractions is pinned like this.
+KernelAbstractions = "< 0.9.32"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"

docs/src/api.md

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction
 A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

 ```@docs
-|(::Model, ::Any)
+|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}})
 condition
 DynamicPPL.conditioned
 ```
@@ -403,6 +403,7 @@ LikelihoodContext
 PriorContext
 MiniBatchContext
 PrefixContext
+ConditionContext
 ```

 ### Samplers
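
As context for the narrowed `|` signature documented above, here is a minimal sketch of conditioning a model on observations; the `demo` model and the value `1.5` are illustrative and not part of this commit:

```julia
using DynamicPPL, Distributions

@model function demo()
    μ ~ Normal(0, 1)
    x ~ Normal(μ, 1)
end

# Condition on an observation using the `|` alias for `condition`;
# a NamedTuple is one of the accepted right-hand-side types.
conditioned_model = demo() | (x = 1.5,)

# Equivalent explicit call:
conditioned_model2 = condition(demo(), (x = 1.5,))
```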

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 142 additions & 0 deletions
@@ -42,6 +42,148 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
     return keys(c.info.varname_to_symbol)
 end
 
+"""
+    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
+
+Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
+in `chain`, and return the resulting `Chains`.
+
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
+simulate what new, unobserved data might look like, given your posterior beliefs.
+
+For each parameter configuration in `chain`:
+1. All random variables present in `chain` are fixed to their sampled values.
+2. Any variables not included in `chain` are sampled from their prior distributions.
+
+If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+predictive distribution.
+
+# Examples
+```jldoctest
+using AbstractMCMC, Distributions, DynamicPPL, Random
+
+@model function linear_reg(x, y, σ = 0.1)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], σ)
+    end
+end
+
+# Generate synthetic chain using known ground truth parameter
+ground_truth_β = 2.0
+
+# Create chain of samples from a normal distribution centered on ground truth
+β_chain = MCMCChains.Chains(
+    rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
+)
+
+# Generate predictions for two test points
+xs_test = [10.1, 10.2]
+
+m_train = linear_reg(xs_test, fill(missing, length(xs_test)))
+
+predictions = DynamicPPL.AbstractPPL.predict(
+    Random.default_rng(), m_train, β_chain
+)
+
+ys_pred = vec(mean(Array(predictions); dims=1))
+
+# Check if predictions match expected values within tolerance
+(
+    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
+    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
+)
+
+# output
+
+(true, true)
+```
+"""
+function DynamicPPL.predict(
+    rng::DynamicPPL.Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    chain::MCMCChains.Chains;
+    include_all=false,
+)
+    parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
+    varinfo = DynamicPPL.VarInfo(model)
+
+    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
+    predictive_samples = map(iters) do (sample_idx, chain_idx)
+        DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
+        model(rng, varinfo, DynamicPPL.SampleFromPrior())
+
+        vals = DynamicPPL.values_as_in_model(model, false, varinfo)
+        varname_vals = mapreduce(
+            collect,
+            vcat,
+            map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
+        )
+
+        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+    end
+
+    chain_result = reduce(
+        MCMCChains.chainscat,
+        [
+            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
+            chain_idx in 1:size(predictive_samples, 2)
+        ],
+    )
+    parameter_names = if include_all
+        MCMCChains.names(chain_result, :parameters)
+    else
+        filter(
+            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
+            names(chain_result, :parameters),
+        )
+    end
+    return chain_result[parameter_names]
+end
+
+function _predictive_samples_to_arrays(predictive_samples)
+    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
+
+    sample_dicts = map(predictive_samples) do sample
+        varname_value_pairs = sample.varname_and_values
+        varnames = map(first, varname_value_pairs)
+        values = map(last, varname_value_pairs)
+        for varname in varnames
+            push!(variable_names_set, varname)
+        end
+
+        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
+    end
+
+    variable_names = collect(variable_names_set)
+    variable_values = [
+        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
+        key in variable_names
+    ]
+
+    return variable_names, variable_values
+end
+
+function _predictive_samples_to_chains(predictive_samples)
+    variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
+    variable_names_symbols = map(Symbol, variable_names)
+
+    internal_parameters = [:lp]
+    log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
+
+    parameter_names = [variable_names_symbols; internal_parameters]
+    parameter_values = hcat(variable_values, log_probabilities)
+    parameter_values = MCMCChains.concretize(parameter_values)
+
+    return MCMCChains.Chains(
+        parameter_values, parameter_names, (internals=internal_parameters,)
+    )
+end
+
 """
     returned(model::Model, chain::MCMCChains.Chains)
 
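A minimal usage sketch of the new `predict` method with `include_all=true`, reusing the `m_train` and `β_chain` names from the docstring example above (illustrative only, not part of the diff): with `include_all=true` the returned `Chains` also keeps the variables that were fixed from the input chain, rather than only the newly simulated ones.

```julia
# Sketch: assumes `m_train` and `β_chain` have been constructed as in the
# docstring example above.
using DynamicPPL, Random

predictions_all = DynamicPPL.predict(
    Random.default_rng(), m_train, β_chain; include_all=true
)
# `predictions_all` now contains β (copied from β_chain) as well as the
# posterior predictive draws for y[1] and y[2].
```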
src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
@@ -5,7 +5,7 @@ using AbstractPPL
 using Bijectors
 using Compat
 using Distributions
-using OrderedCollections: OrderedDict
+using OrderedCollections: OrderedCollections, OrderedDict
 
 using AbstractMCMC: AbstractMCMC
 using ADTypes: ADTypes
@@ -40,6 +40,8 @@ import Base:
     keys,
     haskey
 
+import AbstractPPL: predict
+
 # VarInfo
 export AbstractVarInfo,
     VarInfo,
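
As context for the `import AbstractPPL: predict` line, a small sketch of what this import implies (illustrative, not from the diff): the methods defined in the MCMCChains extension via `function DynamicPPL.predict(...)` extend AbstractPPL's generic `predict` function rather than a separate DynamicPPL-owned one.

```julia
using AbstractPPL, DynamicPPL

# Because `predict` is imported from AbstractPPL, the two names refer to the
# same generic function, so extension methods attach to it.
AbstractPPL.predict === DynamicPPL.predict  # true
```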
