Skip to content

Commit ee863d6

Browse files
penelopeysmmhauru
andauthored
v0.39 (#1082)
* v0.39 * Update DPPL compats for benchmarks and docs * remove merge conflict markers * Remove `NodeTrait` (#1133) * Remove NodeTrait * Changelog * Fix exports * docs * fix a bug * Fix doctests * Fix test * tweak changelog * FastLDF / InitContext unified (#1132) * Fast Log Density Function * Make it work with AD * Optimise performance for identity VarNames * Mark `get_range_and_linked` as having zero derivative * Update comment * make AD testing / benchmarking use FastLDF * Fix tests * Optimise away `make_evaluate_args_and_kwargs` * const func annotation * Disable benchmarks on non-typed-Metadata-VarInfo * Fix `_evaluate!!` correctly to handle submodels * Actually fix submodel evaluate * Document thoroughly and organise code * Support more VarInfos, make it thread-safe (?) * fix bug in parsing ranges from metadata/VNV * Fix get_param_eltype for TSVI * Disable Enzyme benchmark * Don't override _evaluate!!, that breaks ForwardDiff (sometimes) * Move FastLDF to experimental for now * Fix imports, add tests, etc * More test fixes * Fix imports / tests * Remove AbstractFastEvalContext * Changelog and patch bump * Add correctness tests, fix imports * Concretise parameter vector in tests * Add zero-allocation tests * Add Chairmarks as test dep * Disable allocations tests on multi-threaded * Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation * Refactor FastLDF to use InitContext * note init breaking change * fix logjac sign * workaround Mooncake segfault * fix changelog too * Fix get_param_eltype for context stacks * Add a test for threaded observe * Export init * Remove dead code * fix transforms for pathological distributions * Tidy up loads of things * fix typed_identity spelling * fix definition order * Improve docstrings * Remove stray comment * export get_param_eltype (unfortunatley) * Add more comment * Update comment * Remove inlines, fix OAVI docstring * Improve docstrings * Simplify InitFromParams constructor * Replace map(identity, x[:]) with [i for i in x[:]] * Simplify implementation for InitContext/OAVI * Add another model to allocation tests Co-authored-by: Markus Hauru <[email protected]> * Revert removal of dist argument (oops) * Format * Update some outdated bits of FastLDF docstring * remove underscores --------- Co-authored-by: Markus Hauru <[email protected]> * implement `LogDensityProblems.dimension` * forgot about capabilities... * use interpolation in run_ad * Improvements to benchmark outputs (#1146) * print output * fix * reenable * add more lines to guide the eye * reorder table * print tgrad / trel as well * forgot this type * Allow generation of `ParamsWithStats` from `FastLDF` plus parameters, and also `bundle_samples` (#1129) * Implement `ParamsWithStats` for `FastLDF` * Add comments * Implement `bundle_samples` for ParamsWithStats -> MCMCChains * Remove redundant comment * don't need Statistics? * Make FastLDF the default (#1139) * Make FastLDF the default * Add miscellaneous LogDensityProblems tests * Use `init!!` instead of `fast_evaluate!!` * Rename files, rebalance tests * Implement `predict`, `returned`, `logjoint`, ... with `OnlyAccsVarInfo` (#1130) * Use OnlyAccsVarInfo for many re-evaluation functions * drop `fast_` prefix * Add a changelog * Improve FastLDF type stability when all parameters are linked or unlinked (#1141) * Improve type stability when all parameters are linked or unlinked * fix a merge conflict * fix enzyme gc crash (locally at least) * Fixes from review * Make threadsafe evaluation opt-in (#1151) * Make threadsafe evaluation opt-in * Reduce number of type parameters in methods * Make `warned_warn_about_threads_threads_threads_threads` shorter * Improve `setthreadsafe` docstring * warn on bare `@threads` as well * fix merge * Fix performance issues * Use maxthreadid() in TSVI * Move convert_eltype code to threadsafe eval function * Point to new Turing docs page * Add a test for setthreadsafe * Tidy up check_model * Apply suggestions from code review Fix outdated docstrings Co-authored-by: Markus Hauru <[email protected]> * Improve warning message * Export `requires_threadsafe` * Add an actual docstring for `requires_threadsafe` --------- Co-authored-by: Markus Hauru <[email protected]> * Standardise `:lp` -> `:logjoint` (#1161) * Standardise `:lp` -> `:logjoint` * changelog * fix a test --------- Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Markus Hauru <[email protected]>
1 parent 6c615ad commit ee863d6

39 files changed

+1656
-1069
lines changed

HISTORY.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,86 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.0
4+
5+
### Breaking changes
6+
7+
#### Fast Log Density Functions
8+
9+
This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
10+
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
11+
12+
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.
13+
14+
As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
15+
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
16+
If you were previously relying on this behaviour, you will need to store a VarInfo separately.
17+
18+
#### Threadsafe evaluation
19+
20+
DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel.
21+
Prior to DynamicPPL 0.39, thread safety for such models used to be enabled by default if Julia was launched with more than one thread.
22+
23+
In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**.
24+
If you need it (see below for more discussion of when you _do_ need it), you **must** now manually mark it as so, using:
25+
26+
```julia
27+
@model f() = ...
28+
model = f()
29+
model = setthreadsafe(model, true)
30+
```
31+
32+
The problem with the previous on-by-default is that it can sacrifice a huge amount of performance when thread safety is not needed.
33+
This is especially true when running Julia in a notebook, where multiple threads are often enabled by default.
34+
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.
35+
36+
**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
37+
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:
38+
39+
- tilde-statements
40+
- calls to `@addlogprob!`
41+
- any direct manipulation of the special `__varinfo__` variable
42+
43+
If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
44+
**Notably, the following do not require threadsafe evaluation:**
45+
46+
- Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
47+
- Sampling with `AbstractMCMC.MCMCThreads()`.
48+
49+
For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/).
50+
51+
When threadsafe evaluation is enabled for a model, an internal flag is set on the model.
52+
The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean.
53+
This function is newly exported in this version of DynamicPPL.
54+
55+
#### Parent and leaf contexts
56+
57+
The `DynamicPPL.NodeTrait` function has been removed.
58+
Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`.
59+
This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext`.
60+
61+
There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not really affect users.
62+
63+
Leaf contexts require no changes, apart from a removal of the `NodeTrait` function.
64+
65+
`ConditionContext` and `PrefixContext` are no longer exported.
66+
You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead.
67+
68+
#### ParamsWithStats
69+
70+
In the 'stats' part of `DynamicPPL.ParamsWithStats`, the log-joint is now consistently represented with the key `logjoint` instead of `lp`.
71+
72+
#### Miscellaneous
73+
74+
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
75+
76+
The unexported functions `supports_varname_indexing(chain)`, `getindex_varname(chain)`, and `varnames(chain)` have been removed.
77+
78+
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
79+
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
80+
81+
The family of functions `returned(model, chain)`, along with the same signatures of `pointwise_logdensities`, `logjoint`, `loglikelihood`, and `logprior`, have been changed such that if the chain does not contain all variables in the model, an error is thrown.
82+
Previously the behaviour would have been to sample missing variables.
83+
384
## 0.38.10
485

586
`returned(model, chain)` and `pointwise_logdensities(model, chain)` will now error if a value for a random variable cannot be found in the chain.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.10"
3+
version = "0.39.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

benchmarks/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ DynamicPPL = {path = "../"}
2424
ADTypes = "1.14.0"
2525
Chairmarks = "1.3.1"
2626
Distributions = "0.25.117"
27-
DynamicPPL = "0.38"
27+
DynamicPPL = "0.39"
2828
Enzyme = "0.13"
2929
ForwardDiff = "1"
3030
JSON = "1.3.0"

benchmarks/benchmarks.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,15 @@ function run(; to_json=false)
9898
}[]
9999

100100
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
101-
@info "Running benchmark for $model_name"
101+
@info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked"
102102
relative_eval_time, relative_ad_eval_time = try
103103
results = benchmark(model, varinfo_choice, adbackend, islinked)
104+
@info " t(eval) = $(results.primal_time)"
105+
@info " t(grad) = $(results.grad_time)"
104106
(results.primal_time / reference_time),
105107
(results.grad_time / results.primal_time)
106108
catch e
109+
@info "benchmark errored: $e"
107110
missing, missing
108111
end
109112
push!(
@@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String)
155158
all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases)))
156159
@info "$(length(all_testcases)) unique test cases found"
157160
sorted_testcases = sort(
158-
collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked))
161+
collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend))
159162
)
160163
results_table = Tuple{
161-
String,Int,String,String,Bool,String,String,String,String,String,String
164+
String,
165+
Int,
166+
String,
167+
String,
168+
Bool,
169+
String,
170+
String,
171+
String,
172+
String,
173+
String,
174+
String,
175+
String,
176+
String,
177+
String,
162178
}[]
179+
sublabels = ["base", "this PR", "speedup"]
163180
results_colnames = [
164181
[
165182
EmptyCells(5),
166183
MultiColumn(3, "t(eval) / t(ref)"),
167184
MultiColumn(3, "t(grad) / t(eval)"),
185+
MultiColumn(3, "t(grad) / t(ref)"),
168186
],
169-
[colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"],
187+
[colnames[1:5]..., sublabels..., sublabels..., sublabels...],
170188
]
171189
sprint_float(x::Float64) = @sprintf("%.2f", x)
172190
sprint_float(m::Missing) = "err"
@@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String)
183201
# Finally that lets us do this division safely
184202
speedup_eval = base_eval / head_eval
185203
speedup_grad = base_grad / head_grad
204+
# As well as this multiplication, which is t(grad) / t(ref)
205+
head_grad_vs_ref = head_grad * head_eval
206+
base_grad_vs_ref = base_grad * base_eval
207+
speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref
186208
push!(
187209
results_table,
188210
(
@@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String)
197219
sprint_float(base_grad),
198220
sprint_float(head_grad),
199221
sprint_float(speedup_grad),
222+
sprint_float(base_grad_vs_ref),
223+
sprint_float(head_grad_vs_ref),
224+
sprint_float(speedup_grad_vs_ref),
200225
),
201226
)
202227
end
@@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String)
212237
backend=:text,
213238
fit_table_in_display_horizontally=false,
214239
fit_table_in_display_vertically=false,
215-
table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true),
240+
table_format=TextTableFormat(;
241+
horizontal_line_at_merged_column_labels=true,
242+
horizontal_lines_at_data_rows=collect(3:3:length(results_table)),
243+
),
216244
)
217245
println("```")
218246
end

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Accessors = "0.1"
2121
Distributions = "0.25"
2222
Documenter = "1"
2323
DocumenterMermaid = "0.1, 0.2"
24-
DynamicPPL = "0.38"
24+
DynamicPPL = "0.39"
2525
FillArrays = "0.13, 1"
2626
ForwardDiff = "0.10, 1"
2727
JET = "0.9, 0.10, 0.11"

docs/src/api.md

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ The context of a model can be set using [`contextualize`](@ref):
4242
contextualize
4343
```
4444

45+
Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary).
46+
If this is the case, one must enable threadsafe evaluation for a model:
47+
48+
```@docs
49+
setthreadsafe
50+
requires_threadsafe
51+
```
52+
4553
## Evaluation
4654

4755
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
@@ -66,6 +74,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
6674
LogDensityFunction
6775
```
6876

77+
Internally, this is accomplished using [`init!!`](@ref) on:
78+
79+
```@docs
80+
OnlyAccsVarInfo
81+
```
82+
6983
## Condition and decondition
7084

7185
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
@@ -170,6 +184,12 @@ DynamicPPL.prefix
170184

171185
## Utilities
172186

187+
`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors.
188+
189+
```@docs
190+
typed_identity
191+
```
192+
173193
It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
174194

175195
```@docs
@@ -352,13 +372,6 @@ Base.empty!
352372
SimpleVarInfo
353373
```
354374

355-
### Tilde-pipeline
356-
357-
```@docs
358-
tilde_assume!!
359-
tilde_observe!!
360-
```
361-
362375
### Accumulators
363376

364377
The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.
@@ -463,22 +476,55 @@ By default, it does not perform any actual sampling: it only evaluates the model
463476
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.
464477

465478
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
466-
Contexts are subtypes of `AbstractPPL.AbstractContext`.
479+
480+
All contexts are subtypes of `AbstractPPL.AbstractContext`.
481+
482+
Contexts are split into two kinds:
483+
484+
**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds.
485+
For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters.
486+
DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported.
467487

468488
```@docs
469489
DefaultContext
470-
PrefixContext
471-
ConditionContext
472490
InitContext
473491
```
474492

493+
To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context.
494+
495+
```@docs
496+
tilde_assume!!
497+
tilde_observe!!
498+
```
499+
500+
**Parent contexts**: These essentially act as 'modifiers' for leaf contexts.
501+
For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed.
502+
503+
To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods.
504+
If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context.
505+
This is optional; the default implementation is to simply delegate to the child context.
506+
507+
```@docs
508+
AbstractParentContext
509+
childcontext
510+
setchildcontext
511+
```
512+
513+
Since contexts form a tree structure, these functions are automatically defined for manipulating context stacks.
514+
They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts).
515+
516+
```@docs
517+
leafcontext
518+
setleafcontext
519+
```
520+
475521
### VarInfo initialisation
476522

477523
The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
478524
It is really a thin wrapper around using `evaluate!!` with an `InitContext`.
479525

480526
```@docs
481-
DynamicPPL.init!!
527+
init!!
482528
```
483529

484530
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
@@ -491,10 +537,12 @@ InitFromParams
491537
```
492538

493539
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.
540+
In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy.
494541

495542
```@docs
496-
DynamicPPL.AbstractInitStrategy
497-
DynamicPPL.init
543+
AbstractInitStrategy
544+
init
545+
get_param_eltype
498546
```
499547

500548
### Choosing a suitable VarInfo

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
module DynamicPPLEnzymeCoreExt
22

3-
if isdefined(Base, :get_extension)
4-
using DynamicPPL: DynamicPPL
5-
using EnzymeCore
6-
else
7-
using ..DynamicPPL: DynamicPPL
8-
using ..EnzymeCore
9-
end
3+
using DynamicPPL: DynamicPPL
4+
using EnzymeCore
105

116
# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
127
# only checks whether such a method exists, and never runs it.
138
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
149
nothing
10+
# Likewise for get_range_and_linked.
11+
@inline EnzymeCore.EnzymeRules.inactive(
12+
::typeof(DynamicPPL._get_range_and_linked), args...
13+
) = nothing
1514

1615
end

0 commit comments

Comments
 (0)