Commit f4db67a

Merge branch 'main' into breaking
2 parents: 5a9e9d2 + 0cf3440

File tree

4 files changed: +47 additions, -17 deletions


HISTORY.md

Lines changed: 11 additions & 2 deletions
```diff
@@ -1,5 +1,9 @@
 # DynamicPPL Changelog
 
+## 0.37.1
+
+Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147.
+
 ## 0.37.0
 
 DynamicPPL 0.37 comes with a substantial reworking of its internals.
@@ -25,9 +29,14 @@ Please see the API documentation for more details.
 
 There is now also an `rng` keyword argument to help seed parameter generation.
 
-Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
+Instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol`, which are used for both value and gradient.
 Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
 
+Finally, the `ADResult` object returned by `run_ad` now has both `grad_time` and `primal_time` fields, which contain (respectively) the time taken to calculate the gradient of logp and the time taken to calculate logp itself.
+Times are reported in seconds.
+Previously there was only a single `time_vs_primal` field, which represented the ratio of these two.
+You can of course access the same quantity by dividing `grad_time` by `primal_time`.
+
 ### `DynamicPPL.TestUtils.check_model`
 
 You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
@@ -95,7 +104,7 @@ Because this is one of the more arcane features of DynamicPPL, some extra explan
 For example, the particle Gibbs method has a _reference particle_, for which variables are never resampled.
 However, if the reference particle is _forked_ (i.e., if the reference particle is selected by a resampling step multiple times and thereby copied), then the variables that have not yet been evaluated must be sampled anew to ensure that the new particle is independent of the reference particle.
 
-Previousy, this was accomplished by setting the `del` flag in the `VarInfo` object for all variables with `order` greater or equal to than `num_produce`.
+Previously, this was accomplished by setting the `del` flag in the `VarInfo` object for all variables with `order` greater than or equal to `num_produce`.
 Note that setting the `del` flag does not itself trigger a new value to be sampled; rather, it indicates that a new value should be sampled _if the variable is encountered again_.
 [This Turing.jl PR](https://github.com/TuringLang/Turing.jl/pull/2629) changes the implementation to set the `del` flag for _all_ variables in the `VarInfo`.
 Since the `del` flag only makes a difference when encountering a variable, this approach is entirely equivalent as long as the same variable is not seen multiple times in the model.
```
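For code that read the removed `time_vs_primal` field, the migration implied by this changelog entry is a one-liner; a minimal sketch, assuming `res` is an `ADResult` returned by `run_ad` with `benchmark=true`:

```julia
# Hypothetical `res` from run_ad(...; benchmark=true).
# Old code:  ratio = res.time_vs_primal
# New code:  both timings are in seconds, so the old ratio is recovered by
ratio = res.grad_time / res.primal_time
```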

Project.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -67,7 +67,7 @@ LinearAlgebra = "1.6"
 LogDensityProblems = "2"
 MCMCChains = "6, 7"
 MacroTools = "0.5.6"
-Mooncake = "0.4.95"
+Mooncake = "0.4.147"
 OrderedCollections = "1"
 Printf = "1.10"
 Random = "1.6"
```
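Note that Pkg compat entries use caret semantics, so `Mooncake = "0.4.147"` allows any version in `[0.4.147, 0.5.0)`. The lower bound is raised because the extension below now uses `Mooncake.@zero_derivative`, which replaces the older `@zero_adjoint`.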

ext/DynamicPPLMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,6 +4,6 @@ using DynamicPPL: DynamicPPL, istrans
 using Mooncake: Mooncake
 
 # This is purely an optimisation.
-Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
+Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
 
 end # module
```
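For context, `@zero_derivative` (like the `@zero_adjoint` it replaces) tells Mooncake that any call matching the given signature has zero derivative, so no rule needs to be derived for it. A minimal sketch of the same pattern, applied to a hypothetical predicate (the function `is_flagged` is an illustrative assumption, not part of this diff):

```julia
using Mooncake: Mooncake

# Hypothetical helper: a Boolean query that AD should never differentiate through.
is_flagged(x) = x > 0

# Declare that calls to is_flagged (with any argument types) always have zero
# derivative, mirroring the istrans declaration in the extension above.
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_flagged),Vararg}
```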

src/test_utils/ad.jl

Lines changed: 34 additions & 13 deletions
```diff
@@ -109,8 +109,11 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
     value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
     grad_actual::Vector{Tresult}
-    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Tresult}
+    "If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
+    of logp"
+    grad_time::Union{Nothing,Tresult}
+    "If benchmarking was requested, the time taken by the AD backend to evaluate logp"
+    primal_time::Union{Nothing,Tresult}
 end
 
 """
@@ -121,6 +124,8 @@ end
         benchmark=false,
         atol::AbstractFloat=1e-8,
         rtol::AbstractFloat=sqrt(eps()),
+        getlogdensity::Function=getlogjoint_internal,
+        rng::Random.AbstractRNG=Random.default_rng(),
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
         verbose=true,
@@ -143,7 +148,7 @@ ReverseDiff`.
 There are two positional arguments, which absolutely must be provided:
 
 1. `model` - The model being tested.
-2. `adtype` - The AD backend being tested.
+1. `adtype` - The AD backend being tested.
 
 Everything else is optional, and can be categorised into several groups:
 
@@ -156,7 +161,7 @@ Everything else is optional, and can be categorised into several groups:
    means that the parameters in the VarInfo have been transformed to
    unconstrained Euclidean space if they aren't already in that space.
 
-2. _How to specify the parameters._
+1. _How to specify the parameters._
 
    For maximum control over this, generate a vector of parameters yourself and
   pass this as the `params` argument. If you don't specify this, it will be
@@ -174,7 +179,18 @@ Everything else is optional, and can be categorised into several groups:
   prep_params)`. You could then evaluate the gradient at a different set of
   parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._
+1. _Which type of logp is being calculated._
+   By default, `run_ad` evaluates the 'internal log joint density' of the model,
+   i.e., the log joint density in the unconstrained space. Thus, for example, in
+   `@model f() = x ~ LogNormal()`
+   the internal log joint density is `logpdf(Normal(), log(x))`. This is the
+   relevant log density for e.g. Hamiltonian Monte Carlo samplers and is therefore
+   the most useful to test.
+   If you want the log joint density in the original model parameterisation, you
+   can use `getlogjoint`. Likewise, if you want only the prior or likelihood,
+   you can use `getlogprior` or `getloglikelihood`, respectively.
+
+1. _How to specify the results to compare against._
 
    Once logp and its gradient has been calculated with the specified `adtype`,
   it can optionally be tested for correctness. The exact way this is tested
@@ -192,7 +208,7 @@ Everything else is optional, and can be categorised into several groups:
   - `test=false` and `test=true` are synonyms for
     `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.
 
-4. _How to specify the tolerances._ (Only if testing is enabled.)
+1. _How to specify the tolerances._ (Only if testing is enabled.)
 
   Both absolute and relative tolerances can be specified using the `atol` and
   `rtol` keyword arguments respectively. The behaviour of these is similar to
@@ -204,7 +220,7 @@ Everything else is optional, and can be categorised into several groups:
  we cannot know the magnitude of logp and its gradient a priori. The `atol`
   value is supplied to handle the case where gradients are equal to zero.
 
-5. _Whether to output extra logging information._
+1. _Whether to output extra logging information._
 
   By default, this function prints messages when it runs. To silence it, set
   `verbose=false`.
@@ -277,14 +293,18 @@ function run_ad(
     end
 
     # Benchmark
-    time_vs_primal = if benchmark
+    grad_time, primal_time = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
-        t = median(grad_benchmark).time / median(primal_benchmark).time
-        verbose && println("grad / primal : $(t)")
-        t
+        median_primal = median(primal_benchmark).time
+        median_grad = median(grad_benchmark).time
+        r(f) = round(f; sigdigits=4)
+        verbose && println(
+            "grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
+        )
+        (median_grad, median_primal)
     else
-        nothing
+        nothing, nothing
     end
 
     return ADResult(
@@ -299,7 +319,8 @@ function run_ad(
         grad_true,
         value,
         grad,
-        time_vs_primal,
+        grad_time,
+        primal_time,
     )
 end
 
```
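Putting the docstring changes together, usage looks roughly like the following; a hedged sketch in which the model, the ForwardDiff backend, and the `DynamicPPL.TestUtils.AD` module path are assumptions based on the docstring above, not guaranteed by this diff:

```julia
using DynamicPPL, Distributions, Random
using ADTypes: AutoForwardDiff
import ForwardDiff  # the chosen AD backend package must be loaded

@model function demo()
    x ~ LogNormal()
end

# Benchmark the gradient against the primal, seeding parameter generation
# with the new `rng` keyword. `getlogdensity` is left at its default,
# i.e. the internal (unconstrained-space) log joint density.
res = DynamicPPL.TestUtils.AD.run_ad(
    demo(),
    AutoForwardDiff();
    benchmark=true,
    rng=Random.Xoshiro(468),
)

# Both timings are reported in seconds.
res.grad_time, res.primal_time
```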