Skip to content

Commit 90c8a08

Browse files
committed
Merge branch 'breaking' into mhauru/accumulators-stage2
2 parents f2e676b + f4dd46a commit 90c8a08

25 files changed

+704
-526
lines changed

HISTORY.md

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,57 @@
22

33
## 0.37.0
44

5-
**Breaking changes**
5+
DynamicPPL 0.37 comes with a substantial reworking of its internals.
6+
Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release is unlikely to affect you much.
7+
However, if you are a package developer or someone who uses DynamicPPL's functionality directly, you will notice a number of changes.
8+
9+
To avoid overwhelming the reader, we begin by listing the most important, user-facing changes, before explaining the changes to the internals in more detail.
10+
11+
Note that virtually all changes listed here are breaking.
12+
13+
**Public-facing changes**
614

715
### Submodel macro
816

917
The `@submodel` macro is fully removed; please use `to_submodel` instead.
1018

19+
### `DynamicPPL.TestUtils.AD.run_ad`
20+
21+
The three keyword arguments, `test`, `reference_backend`, and `expected_value_and_grad` have been merged into a single `test` keyword argument.
22+
Please see the API documentation for more details.
23+
(The old `test=true` and `test=false` values are still valid, and you only need to adjust the invocation if you were explicitly passing the `reference_backend` or `expected_value_and_grad` arguments.)
24+
25+
There is now also an `rng` keyword argument to help seed parameter generation.
26+
27+
Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
28+
Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
29+
30+
### `DynamicPPL.TestUtils.check_model`
31+
32+
You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
33+
Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`).
34+
35+
### Removal of `PriorContext` and `LikelihoodContext`
36+
37+
A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`.
38+
Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below.
39+
40+
Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field.
41+
`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively.
42+
43+
Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below).
44+
If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood).
45+
46+
If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`.
47+
`LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want.
48+
Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`.
49+
50+
The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior.
51+
Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
52+
Now, you can write `@addlogprob (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood.
53+
54+
**Internals**
55+
1156
### Accumulators
1257

1358
This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
@@ -59,6 +104,18 @@ And a couple of more internal changes:
59104
- The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument
60105
- The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched.
61106

107+
## 0.36.15
108+
109+
Bumped minimum Julia version to 1.10.8 to avoid potential crashes with `Core.Compiler.widenconst` (which Mooncake uses).
110+
111+
## 0.36.14
112+
113+
Added compatibility with [email protected].
114+
115+
## 0.36.13
116+
117+
Added documentation for the `returned(::Model, ::MCMCChains.Chains)` method.
118+
62119
## 0.36.12
63120

64121
Removed several unexported functions.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4747
[compat]
4848
ADTypes = "1"
4949
AbstractMCMC = "5"
50-
AbstractPPL = "0.11"
50+
AbstractPPL = "0.11, 0.12"
5151
Accessors = "0.1"
5252
BangBang = "0.4.1"
5353
Bijectors = "0.13.18, 0.14, 0.15"
@@ -74,4 +74,4 @@ Random = "1.6"
7474
Requires = "1"
7575
Statistics = "1"
7676
Test = "1.6"
77-
julia = "1.10"
77+
julia = "1.10.8"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1414
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1515

1616
[compat]
17-
AbstractPPL = "0.11"
17+
AbstractPPL = "0.11, 0.12"
1818
Accessors = "0.1"
1919
DataStructures = "0.18"
2020
Distributions = "0.25"

docs/src/api.md

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
160160
@addlogprob!
161161
```
162162

163-
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).
163+
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.
164164

165165
```@docs
166+
returned(::DynamicPPL.Model, ::MCMCChains.Chains)
166167
returned(::DynamicPPL.Model, ::NamedTuple)
167168
```
168169

@@ -205,6 +206,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL
205206

206207
```@docs
207208
DynamicPPL.TestUtils.AD.run_ad
209+
```
210+
211+
The default test setting is to compare against ForwardDiff.
212+
You can have more fine-grained control over how to test the AD backend using the following types:
213+
214+
```@docs
215+
DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting
216+
DynamicPPL.TestUtils.AD.WithBackend
217+
DynamicPPL.TestUtils.AD.WithExpectedResult
218+
DynamicPPL.TestUtils.AD.NoTest
219+
```
220+
221+
These are returned / thrown by the `run_ad` function:
222+
223+
```@docs
208224
DynamicPPL.TestUtils.AD.ADResult
209225
DynamicPPL.TestUtils.AD.ADIncorrectException
210226
```
@@ -325,7 +341,7 @@ get_num_produce
325341
set_num_produce!!
326342
increment_num_produce!!
327343
reset_num_produce!!
328-
setorder!
344+
setorder!!
329345
set_retained_vns_del!
330346
```
331347

@@ -352,7 +368,7 @@ DynamicPPL provides the following default accumulators.
352368
```@docs
353369
LogPriorAccumulator
354370
LogLikelihoodAccumulator
355-
NumProduceAccumulator
371+
VariableOrderAccumulator
356372
```
357373

358374
### Common API

src/DynamicPPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export AbstractVarInfo,
5050
AbstractAccumulator,
5151
LogLikelihoodAccumulator,
5252
LogPriorAccumulator,
53-
NumProduceAccumulator,
53+
VariableOrderAccumulator,
5454
push!!,
5555
empty!!,
5656
subset,
@@ -73,7 +73,7 @@ export AbstractVarInfo,
7373
is_flagged,
7474
set_flag!,
7575
unset_flag!,
76-
setorder!,
76+
setorder!!,
7777
istrans,
7878
link,
7979
link!!,

src/abstract_varinfo.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo)
374374
return vi
375375
end
376376

377+
"""
378+
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
379+
380+
Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe
381+
statements run before sampling `vn`.
382+
"""
383+
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
384+
return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder))
385+
end
386+
387+
"""
388+
getorder(vi::VarInfo, vn::VarName)
389+
390+
Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
391+
run before sampling `vn`.
392+
"""
393+
getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn]
394+
377395
# Variables and their realizations.
378396
@doc """
379397
keys(vi::AbstractVarInfo)
@@ -725,7 +743,15 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
725743
"""
726744
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
727745
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
728-
return !isempty(vns) && all(Base.Fix1(istrans, vi), vns)
746+
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
747+
# In theory that should work perfectly fine. For unbeknownst reasons,
748+
# Julia 1.10 fails to infer its return type correctly. Thus we use this
749+
# slightly longer definition.
750+
isempty(vns) && return false
751+
for vn in vns
752+
istrans(vi, vn) || return false
753+
end
754+
return true
729755
end
730756

731757
"""
@@ -972,29 +998,37 @@ end
972998
973999
Return the `num_produce` of `vi`.
9741000
"""
975-
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num
1001+
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce
9761002

9771003
"""
9781004
set_num_produce!!(vi::AbstractVarInfo, n::Int)
9791005
9801006
Set the `num_produce` field of `vi` to `n`.
9811007
"""
982-
set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n))
1008+
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
1009+
if hasacc(vi, Val(:VariableOrder))
1010+
acc = getacc(vi, Val(:VariableOrder))
1011+
acc = VariableOrderAccumulator(n, acc.order)
1012+
else
1013+
acc = VariableOrderAccumulator(n)
1014+
end
1015+
return setacc!!(vi, acc)
1016+
end
9831017

9841018
"""
9851019
increment_num_produce!!(vi::AbstractVarInfo)
9861020
9871021
Add 1 to `num_produce` in `vi`.
9881022
"""
9891023
increment_num_produce!!(vi::AbstractVarInfo) =
990-
map_accumulator!!(increment, vi, Val(:NumProduce))
1024+
map_accumulator!!(increment, vi, Val(:VariableOrder))
9911025

9921026
"""
9931027
reset_num_produce!!(vi::AbstractVarInfo)
9941028
9951029
Reset the value of `num_produce` in `vi` to 0.
9961030
"""
997-
reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
1031+
reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))
9981032

9991033
"""
10001034
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])

src/accumulators.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
1313
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
1414
- `accumulate_observe!!(acc::T, right, left, vn)`
1515
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
16+
- `Base.copy(acc::T)`
1617
1718
To be able to work with multi-threading, it should also implement:
1819
- `split(acc::T)`
@@ -53,10 +54,11 @@ function accumulate_observe!! end
5354
5455
Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`.
5556
56-
`vn` is the name of the variable being assumed, `val` is the value of the variable, and
57-
`right` is the distribution on the RHS of the tilde statement. `logjac` is the log
58-
determinant of the Jacobian of the transformation that was done to convert the value of `vn`
59-
as it was given (e.g. by sampler operating in linked space) to `val`.
57+
`vn` is the name of the variable being assumed, `val` is the value of the variable (in the
58+
original, unlinked space), and `right` is the distribution on the RHS of the tilde
59+
statement. `logjac` is the log determinant of the Jacobian of the transformation that was
60+
done to convert the value of `vn` as it was given to `val`: for example, if the sampler is
61+
operating in linked (Euclidean) space, then logjac will be nonzero.
6062
6163
`accumulate_assume!!` may mutate `acc`, but not any of the other arguments.
6264
@@ -71,7 +73,7 @@ Return a new accumulator like `acc` but empty.
7173
7274
The precise meaning of "empty" is that that the returned value should be such that
7375
`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading
74-
where different threads may accumulate independently and the results are the combined.
76+
where different threads may accumulate independently and the results are then combined.
7577
7678
See also: [`combine`](@ref)
7779
"""
@@ -80,7 +82,8 @@ function split end
8082
"""
8183
combine(acc::AbstractAccumulator, acc2::AbstractAccumulator)
8284
83-
Combine two accumulators of the same type. Returns a new accumulator.
85+
Combine two accumulators which have the same type (but may, in general, have different type
86+
parameters). Returns a new accumulator of the same type.
8487
8588
See also: [`split`](@ref)
8689
"""
@@ -136,6 +139,9 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname}
136139
@inline return haskey(at.nt, accname)
137140
end
138141
Base.keys(at::AccumulatorTuple) = keys(at.nt)
142+
Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt
143+
Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h)
144+
Base.copy(at::AccumulatorTuple) = AccumulatorTuple(map(copy, at.nt))
139145

140146
function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
141147
return AccumulatorTuple(convert(T, accs.nt))

src/context_implementations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ function assume(
148148
f = to_maybe_linked_internal_transform(vi, vn, dist)
149149
# TODO(mhauru) This should probably be call a function called setindex_internal!
150150
vi = BangBang.setindex!!(vi, f(r), vn)
151-
setorder!(vi, vn, get_num_produce(vi))
152151
else
153152
# Otherwise we just extract it.
154153
r = vi[vn, dist]

0 commit comments

Comments
 (0)