
Commit 8fb70d4

Make run_ad return timings for both primal and gradient (#1009)
1 parent 084f5af commit 8fb70d4

File tree

2 files changed: +40 −14 lines


HISTORY.md

Lines changed: 6 additions & 1 deletion
```diff
@@ -25,9 +25,14 @@ Please see the API documentation for more details.
 
 There is now also an `rng` keyword argument to help seed parameter generation.
 
-Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
+Instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
 Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
 
+Finally, the `ADResult` object returned by `run_ad` now has both `grad_time` and `primal_time` fields, which contain (respectively) the time it took to calculate the gradient of logp, and the time taken to calculate logp itself.
+Times are reported in seconds.
+Previously there was only a single `time_vs_primal` field which represented the ratio of these two.
+You can of course access the same quantity by dividing `grad_time` by `primal_time`.
+
 ### `DynamicPPL.TestUtils.check_model`
 
 You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
```
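The `atol`/`rtol` semantics described in the changelog match Julia's `isapprox`: a comparison passes if either tolerance is satisfied. A minimal Python sketch of the same check (illustrative only, not DynamicPPL code):

```python
import math
import sys

def approx_equal(a, b, atol=1e-8, rtol=math.sqrt(sys.float_info.epsilon)):
    # Mirrors Julia's isapprox: the comparison passes if EITHER the
    # absolute or the relative tolerance is satisfied. The absolute
    # tolerance covers values (e.g. gradients) that are exactly zero,
    # where no relative tolerance can ever be met.
    return abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))
```

The default `rtol` here is `sqrt(eps())`, matching the default in the diff below.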

src/test_utils/ad.jl

Lines changed: 34 additions & 13 deletions
```diff
@@ -109,8 +109,11 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
     value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
     grad_actual::Vector{Tresult}
-    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Tresult}
+    "If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
+    of logp"
+    grad_time::Union{Nothing,Tresult}
+    "If benchmarking was requested, the time taken by the AD backend to evaluate logp"
+    primal_time::Union{Nothing,Tresult}
 end
 
 """
```
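Since benchmarking is optional, both new fields are `nothing` when it is skipped, and the old `time_vs_primal` ratio is easy to recover from them. A small Python sketch of that bookkeeping (field names follow the diff above; the record type itself is hypothetical, not DynamicPPL's `ADResult`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Timings:
    # Mirrors the two new ADResult fields; times are in seconds.
    grad_time: Optional[float]
    primal_time: Optional[float]

def time_vs_primal(t: Timings) -> Optional[float]:
    # Recovers the old `time_vs_primal` quantity: the gradient time
    # divided by the primal (logp evaluation) time.
    if t.grad_time is None or t.primal_time is None:
        return None  # benchmarking was not requested
    return t.grad_time / t.primal_time
```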
```diff
@@ -121,6 +124,8 @@ end
     benchmark=false,
     atol::AbstractFloat=1e-8,
     rtol::AbstractFloat=sqrt(eps()),
+    getlogdensity::Function=getlogjoint_internal,
+    rng::Random.AbstractRNG=Random.default_rng(),
     varinfo::AbstractVarInfo=link(VarInfo(model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
     verbose=true,
```
```diff
@@ -143,7 +148,7 @@ ReverseDiff`.
 There are two positional arguments, which absolutely must be provided:
 
 1. `model` - The model being tested.
-2. `adtype` - The AD backend being tested.
+1. `adtype` - The AD backend being tested.
 
 Everything else is optional, and can be categorised into several groups:
 
@@ -156,7 +161,7 @@ Everything else is optional, and can be categorised into several groups:
    means that the parameters in the VarInfo have been transformed to
    unconstrained Euclidean space if they aren't already in that space.
 
-2. _How to specify the parameters._
+1. _How to specify the parameters._
 
    For maximum control over this, generate a vector of parameters yourself and
    pass this as the `params` argument. If you don't specify this, it will be
```
```diff
@@ -174,7 +179,18 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._
+1. _Which type of logp is being calculated._
+   By default, `run_ad` evaluates the 'internal log joint density' of the model,
+   i.e., the log joint density in the unconstrained space. Thus, for example, in
+   `@model f() = x ~ LogNormal()`
+   the internal log joint density is `logpdf(Normal(), log(x))`. This is the
+   relevant log density for e.g. Hamiltonian Monte Carlo samplers and is therefore
+   the most useful to test.
+   If you want the log joint density in the original model parameterisation, you
+   can use `getlogjoint`. Likewise, if you want only the prior or likelihood,
+   you can use `getlogprior` or `getloglikelihood`, respectively.
+
+1. _How to specify the results to compare against._
 
    Once logp and its gradient has been calculated with the specified `adtype`,
    it can optionally be tested for correctness. The exact way this is tested
```
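The `LogNormal` example in the docstring can be checked numerically. A Python sketch of the two densities (the function names here are hypothetical, written out from the standard change-of-variables formula rather than taken from DynamicPPL):

```python
import math

def normal_logpdf(z):
    # Log-density of the standard normal distribution.
    return -0.5 * math.log(2 * math.pi) - 0.5 * z * z

def internal_logjoint(x):
    # Internal (unconstrained-space) log joint: with y = log(x),
    # this is logpdf(Normal(), y), with no Jacobian term.
    return normal_logpdf(math.log(x))

def original_logjoint(x):
    # Log joint in the original parameterisation: x ~ LogNormal(), i.e.
    # logpdf(LogNormal(), x) = logpdf(Normal(), log(x)) - log(x),
    # where -log(x) is the log-Jacobian of the transform y = log(x).
    return internal_logjoint(x) - math.log(x)
```

The Jacobian term is exactly what distinguishes the default `getlogjoint_internal` from `getlogjoint`.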
```diff
@@ -192,7 +208,7 @@ Everything else is optional, and can be categorised into several groups:
    - `test=false` and `test=true` are synonyms for
      `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.
 
-4. _How to specify the tolerances._ (Only if testing is enabled.)
+1. _How to specify the tolerances._ (Only if testing is enabled.)
 
    Both absolute and relative tolerances can be specified using the `atol` and
    `rtol` keyword arguments respectively. The behaviour of these is similar to
@@ -204,7 +220,7 @@ Everything else is optional, and can be categorised into several groups:
    we cannot know the magnitude of logp and its gradient a priori. The `atol`
    value is supplied to handle the case where gradients are equal to zero.
 
-5. _Whether to output extra logging information._
+1. _Whether to output extra logging information._
 
    By default, this function prints messages when it runs. To silence it, set
    `verbose=false`.
```
```diff
@@ -277,14 +293,18 @@ function run_ad(
     end
 
     # Benchmark
-    time_vs_primal = if benchmark
+    grad_time, primal_time = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
-        t = median(grad_benchmark).time / median(primal_benchmark).time
-        verbose && println("grad / primal : $(t)")
-        t
+        median_primal = median(primal_benchmark).time
+        median_grad = median(grad_benchmark).time
+        r(f) = round(f; sigdigits=4)
+        verbose && println(
+            "grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
+        )
+        (median_grad, median_primal)
     else
-        nothing
+        nothing, nothing
     end
 
     return ADResult(
```
```diff
@@ -299,7 +319,8 @@ function run_ad(
         grad_true,
         value,
         grad,
-        time_vs_primal,
+        grad_time,
+        primal_time,
     )
 end
 
```
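For readers skimming the benchmark change: the printed report rounds each number to four significant digits via `round(...; sigdigits=4)`. A rough Python equivalent of that formatting (the timing values below are made up):

```python
import math

def round_sig(x, sigdigits=4):
    # Round to a fixed number of significant digits, analogous to
    # Julia's round(x; sigdigits=4) used in the benchmark code above.
    if x == 0:
        return 0.0
    return round(x, sigdigits - 1 - math.floor(math.log10(abs(x))))

# Hypothetical timings in seconds:
grad_time, primal_time = 3.2e-4, 8.0e-5
print(f"grad / primal : {round_sig(grad_time)}/{round_sig(primal_time)} "
      f"= {round_sig(grad_time / primal_time)}")
```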
0 commit comments