Commit 8d460c0

Make run_ad return both primal and grad
1 parent 5d9e934 commit 8d460c0

4 files changed: 66 additions and 80 deletions

HISTORY.md

Lines changed: 4 additions & 1 deletion
@@ -24,9 +24,12 @@ Please see the API documentation for more details.

 There is now also an `rng` keyword argument to help seed parameter generation.

-Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
+Instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
 Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.

+Finally, the `ADResult` object returned by `run_ad` now has both `grad_time` and `primal_time` fields, which contain the time it took to calculate the gradient of logp and logp itself.
+Previously there was only a single `time_vs_primal` field which represented the ratio of these two.
+
 ### `DynamicPPL.TestUtils.check_model`

 You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
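
To make the `atol`/`rtol` semantics mentioned in this changelog entry concrete, here is a minimal sketch that uses only Julia's built-in `isapprox`. The numbers are made up for illustration and are not taken from DynamicPPL's tests.

```julia
# For finite numbers, `isapprox(x, y; atol, rtol)` holds when
#     |x - y| <= max(atol, rtol * max(|x|, |y|)),
# so a comparison passes if either tolerance is satisfied.

# Passes through `atol` alone (rtol is switched off).
small_abs_diff = isapprox(1.0, 1.0 + 1e-9; atol=1e-8, rtol=0.0)

# Passes through `rtol` alone: the absolute difference (0.1) is far above atol.
small_rel_diff = isapprox(1000.0, 1000.1; atol=1e-8, rtol=1e-3)

# Fails: neither tolerance is satisfied.
too_far_apart = isapprox(1000.0, 1000.1; atol=1e-8, rtol=1e-6)

println((small_abs_diff, small_rel_diff, too_far_apart))
```

Because a single pair of tolerances now applies to both the value and the gradient, large-magnitude gradients are effectively governed by `rtol` and near-zero entries by `atol`.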

benchmarks/benchmarks.jl

Lines changed: 24 additions & 29 deletions
@@ -1,7 +1,6 @@
-using Pkg
-
-using DynamicPPLBenchmarks: Models, make_suite, model_dimension
-using BenchmarkTools: @benchmark, median, run
+using DynamicPPLBenchmarks: Models, to_backend, make_varinfo
+using DynamicPPL.TestUtils.AD: run_ad, NoTest
+using Chairmarks: @be
 using PrettyTables: PrettyTables, ft_printf
 using StableRNGs: StableRNG

@@ -35,48 +34,45 @@ chosen_combinations = [
         Models.simple_assume_observe(randn(rng)),
         :typed,
         :forwarddiff,
-        false,
     ),
-    ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
-    ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
-    ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
-    ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
-    ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
-    ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
-    ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
-    ("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
-    ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
-    ("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
-    ("Dynamic", Models.dynamic(), :typed, :mooncake, true),
-    ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true),
-    ("LDA", lda_instance, :typed, :reversediff, true),
+    ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff),
+    ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff),
+    ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff),
+    ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff),
+    ("Smorgasbord", smorgasbord_instance, :typed, :reversediff),
+    ("Smorgasbord", smorgasbord_instance, :typed, :mooncake),
+    ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake),
+    ("Multivariate 1k", multivariate1k, :typed, :mooncake),
+    ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake),
+    ("Multivariate 10k", multivariate10k, :typed, :mooncake),
+    ("Dynamic", Models.dynamic(), :typed, :mooncake),
+    ("Submodel", Models.parent(randn(rng)), :typed, :mooncake),
+    ("LDA", lda_instance, :typed, :reversediff),
 ]

 # Time running a model-like function that does not use DynamicPPL, as a reference point.
 # Eval timings will be relative to this.
 reference_time = begin
     obs = randn(rng)
-    median(@benchmark Models.simple_assume_observe_non_model(obs)).time
+    median(@be Models.simple_assume_observe_non_model(obs)).time
 end

-results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
+results_table = Tuple{String,Int,String,String,Float64,Float64}[]

-for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
+for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
     @info "Running benchmark for $model_name"
-    suite = make_suite(model, varinfo_choice, adbackend, islinked)
-    results = run(suite)
-    eval_time = median(results["evaluation"]).time
-    relative_eval_time = eval_time / reference_time
-    ad_eval_time = median(results["gradient"]).time
-    relative_ad_eval_time = ad_eval_time / eval_time
+    adtype = to_backend(adbackend)
+    varinfo = make_varinfo(model, varinfo_choice)
+    ad_result = run_ad(model, adtype; test=NoTest(), benchmark=true, varinfo=varinfo)
+    relative_eval_time = ad_result.primal_time / reference_time
+    relative_ad_eval_time = ad_result.grad_time / ad_result.primal_time
     push!(
         results_table,
         (
             model_name,
-            model_dimension(model, islinked),
+            length(varinfo[:]),
             string(adbackend),
             string(varinfo_choice),
-            islinked,
             relative_eval_time,
             relative_ad_eval_time,
         ),
@@ -89,14 +85,13 @@ header = [
     "Dimension",
     "AD Backend",
     "VarInfo Type",
-    "Linked",
     "Eval Time / Ref Time",
     "AD Time / Eval Time",
 ]
 PrettyTables.pretty_table(
     table_matrix;
     header=header,
     tf=PrettyTables.tf_markdown,
-    formatters=ft_printf("%.1f", [6, 7]),
+    formatters=ft_printf("%.1f", [5, 6]),
     crop=:none,  # Always print the whole table, even if it doesn't fit in the terminal.
 )
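
The two ratio columns in the table are derived from three timings. The sketch below shows that arithmetic in isolation, using only the `Printf` standard library. The timing values are hypothetical placeholders, not measurements, and the variable names simply mirror those in the script.

```julia
using Printf: @sprintf

# Hypothetical timings in seconds; real values would come from a benchmark run.
reference_time = 2.0e-8   # plain Julia function that does not use DynamicPPL
primal_time = 5.0e-7      # one evaluation of logp
grad_time = 3.0e-6        # one evaluation of logp together with its gradient

# "Eval Time / Ref Time" and "AD Time / Eval Time" respectively.
relative_eval_time = primal_time / reference_time
relative_ad_eval_time = grad_time / primal_time

# One decimal place, matching the "%.1f" format applied to the two ratio columns.
formatted = (
    @sprintf("%.1f", relative_eval_time), @sprintf("%.1f", relative_ad_eval_time)
)
println(formatted)
```

The second ratio is the quantity that the removed `time_vs_primal` field used to hold, so callers that still need it can recompute it from the two separate times.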

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 6 additions & 43 deletions
@@ -14,21 +14,7 @@ using StableRNGs: StableRNG
 include("./Models.jl")
 using .Models: Models

-export Models, make_suite, model_dimension
-
-"""
-    model_dimension(model, islinked)
-
-Return the dimension of `model`, accounting for linking, if any.
-"""
-function model_dimension(model, islinked)
-    vi = VarInfo()
-    model(StableRNG(23), vi)
-    if islinked
-        vi = DynamicPPL.link(vi, model)
-    end
-    return length(vi[:])
-end
+export Models, to_backend, make_varinfo

 # Utility functions for representing AD backends using symbols.
 # Copied from TuringBenchmarking.jl.
@@ -48,24 +34,20 @@ function to_backend(x::Union{AbstractString,Symbol})
 end

 """
-    make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
+    make_varinfo(model, varinfo_choice::Symbol)

-Create a benchmark suite for `model` using the selected varinfo type and AD backend.
+Create a VarInfo for the given `model` using the selected varinfo type.
 Available varinfo choices:
   • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
   • `:typed` → uses `DynamicPPL.typed_varinfo(model)`
   • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
   • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)

-The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
-
-`islinked` determines whether to link the VarInfo for evaluation.
+The VarInfo is always linked.
 """
-function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
+function make_varinfo(model::Model, varinfo_choice::Symbol)
     rng = StableRNG(23)

-    suite = BenchmarkGroup()
-
     vi = if varinfo_choice == :untyped
         DynamicPPL.untyped_varinfo(rng, model)
     elseif varinfo_choice == :typed
@@ -80,26 +62,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
         error("Unknown varinfo choice: $varinfo_choice")
     end

-    adbackend = to_backend(adbackend)
-
-    if islinked
-        vi = DynamicPPL.link(vi, model)
-    end
-
-    f = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
-    )
-    # The parameters at which we evaluate f.
-    θ = vi[:]
-
-    # Run once to trigger compilation.
-    LogDensityProblems.logdensity_and_gradient(f, θ)
-    suite["gradient"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ)
-
-    # Also benchmark just standard model evaluation because why not.
-    suite["evaluation"] = @benchmarkable $(LogDensityProblems.logdensity)($f, $θ)
-
-    return suite
+    return DynamicPPL.link(vi, model)
 end

 end # module

src/test_utils/ad.jl

Lines changed: 32 additions & 7 deletions
@@ -109,8 +109,11 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
     value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
     grad_actual::Vector{Tresult}
-    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Tresult}
+    "If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
+    of logp"
+    grad_time::Union{Nothing,Tresult}
+    "If benchmarking was requested, the time taken by the AD backend to evaluate logp"
+    primal_time::Union{Nothing,Tresult}
 end

 """
@@ -121,6 +124,8 @@ end
         benchmark=false,
         atol::AbstractFloat=1e-8,
         rtol::AbstractFloat=sqrt(eps()),
+        getlogdensity::Function=getlogjoint_internal,
+        rng::AbstractRNG=default_rng(),
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
         verbose=true,
@@ -174,6 +179,21 @@ Everything else is optional, and can be categorised into several groups:
       prep_params)`. You could then evaluate the gradient at a different set of
       parameters using the `params` keyword argument.

+3. _Which type of logp is being calculated._
+
+   By default, `run_ad` evaluates the 'internal log joint density' of the model,
+   i.e., the log joint density in the unconstrained space. Thus, for example, in
+
+       @model f() = x ~ LogNormal()
+
+   the internal log joint density is `logpdf(Normal(), log(x))`. This is the
+   relevant log density for e.g. Hamiltonian Monte Carlo samplers and is therefore
+   the most useful to test.
+
+   If you want the log joint density in the original model parameterisation, you
+   can use `getlogjoint`. Likewise, if you want only the prior or likelihood,
+   you can use `getlogprior` or `getloglikelihood`, respectively.
+
-3. _How to specify the results to compare against._
+4. _How to specify the results to compare against._

-   Once logp and its gradient has been calculated with the specified `adtype`,
+   Once logp and its gradient have been calculated with the specified `adtype`,
@@ -277,12 +297,16 @@ function run_ad(
     end

     # Benchmark
-    time_vs_primal = if benchmark
+    grad_time, primal_time = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
-        t = median(grad_benchmark).time / median(primal_benchmark).time
-        verbose && println("grad / primal : $(t)")
-        t
+        median_primal = median(primal_benchmark).time
+        median_grad = median(grad_benchmark).time
+        r(f) = round(f; sigdigits=4)
+        verbose && println(
+            "grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
+        )
+        (median_grad, median_primal)
     else
-        nothing
+        (nothing, nothing)
     end
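
The pattern in this hunk, assigning two variables from an `if`/`else` expression, can be sketched without any benchmarking dependency. In the example below, `summarise_timings` and its sample vectors are hypothetical stand-ins (made-up timings in seconds, not Chairmarks output), and only `Statistics.median` from the standard library is used.

```julia
using Statistics: median

# Hypothetical stand-in for the benchmarking step in `run_ad`.
function summarise_timings(benchmark::Bool; verbose::Bool=false)
    grad_time, primal_time = if benchmark
        primal_samples = [1.0e-6, 1.2e-6, 0.9e-6]
        grad_samples = [4.0e-6, 4.4e-6, 3.8e-6]
        median_primal = median(primal_samples)
        median_grad = median(grad_samples)
        r(f) = round(f; sigdigits=4)
        verbose && println("grad / primal : $(r(median_grad / median_primal))")
        (median_grad, median_primal)
    else
        # Both branches must yield two elements: a bare `nothing` is not
        # iterable, so it cannot be unpacked into two variables.
        (nothing, nothing)
    end
    return (; grad_time, primal_time)
end

with_benchmark = summarise_timings(true)
without_benchmark = summarise_timings(false)
```

Returning the two medians separately keeps the raw timings available, and a ratio can still be formed afterwards when both are present.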
@@ -299,7 +323,8 @@ function run_ad(
         grad_true,
         value,
         grad,
-        time_vs_primal,
+        grad_time,
+        primal_time,
     )
 end
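
The docstring added in this file states that for `x ~ LogNormal()` the internal log joint density is `logpdf(Normal(), log(x))`. A small numerical check of that identity is sketched below. It uses hand-written densities rather than Distributions.jl or DynamicPPL, so the function names are illustrative assumptions and not part of either package.

```julia
# Log density of a standard normal at y.
normal_logpdf(y) = -0.5 * y^2 - 0.5 * log(2π)

# Log density of LogNormal(0, 1) at x > 0, written out directly.
lognormal_logpdf(x) = -log(x) - 0.5 * log(2π) - 0.5 * log(x)^2

# Density in the unconstrained space y = log(x): the original log density
# evaluated at x = exp(y), plus the log absolute Jacobian of x = exp(y), which is y.
internal_logjoint(y) = lognormal_logpdf(exp(y)) + y

x = 2.5
y = log(x)
identity_holds = internal_logjoint(y) ≈ normal_logpdf(log(x))
println(identity_holds)
```

The Jacobian term cancels the `-log(x)` in the log-normal density, which is why the unconstrained density reduces to a standard normal in `log(x)`.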
