Skip to content

Commit ee9a81f

Browse files
committed
Make benchmark times relative. Add benchmark documentation.
1 parent 0f7c924 commit ee9a81f

File tree

2 files changed

+78
-12
lines changed

2 files changed

+78
-12
lines changed

benchmarks/benchmarks.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
using DynamicPPLBenchmarks: Models, make_suite
2-
using BenchmarkTools: median, run
3-
using PrettyTables: PrettyTables
2+
using BenchmarkTools: @benchmark, median, run
3+
using PrettyTables: PrettyTables, ft_printf
44
using Random: seed!
55

66
seed!(23)
77

8+
# Create DynamicPPL.Model instances to run benchmarks on.
89
smorgasbord_instance = Models.smorgasbord(randn(100), randn(100))
910
loop_univariate1k, multivariate1k = begin
1011
data_1k = randn(1_000)
@@ -33,6 +34,8 @@ chosen_combinations = [
3334
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff),
3435
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff),
3536
("Smorgasbord", smorgasbord_instance, :typed, :reversediff),
37+
# TODO(mhauru) Add Mooncake once TuringBenchmarking.jl supports it. Consider changing
38+
# all the below :reversediffs to :mooncakes too.
3639
#("Smorgasbord", smorgasbord_instance, :typed, :mooncake),
3740
("Loop univariate 1k", loop_univariate1k, :typed, :reversediff),
3841
("Multivariate 1k", multivariate1k, :typed, :reversediff),
@@ -43,30 +46,54 @@ chosen_combinations = [
4346
("LDA", lda_instance, :typed, :reversediff),
4447
]
4548

49+
# Time running a model-like function that does not use DynamicPPL, as a reference point.
50+
# Eval timings will be relative to this.
51+
reference_time = begin
52+
obs = randn()
53+
median(@benchmark Models.simple_assume_observe_non_model(obs)).time
54+
end
55+
4656
results_table = Tuple{String,String,String,Float64,Float64}[]
4757

4858
for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
4959
suite = make_suite(model, varinfo_choice, adbackend)
5060
results = run(suite)
5161

5262
eval_time = median(results["evaluation"]["standard"]).time
63+
relative_eval_time = eval_time / reference_time
5364

5465
grad_group = results["gradient"]
5566
if isempty(grad_group)
56-
ad_eval_time = NaN
67+
relative_ad_eval_time = NaN
5768
else
5869
grad_backend_key = first(keys(grad_group))
5970
ad_eval_time = median(grad_group[grad_backend_key]["standard"]).time
71+
relative_ad_eval_time = ad_eval_time / eval_time
6072
end
6173

6274
push!(
6375
results_table,
64-
(model_name, string(adbackend), string(varinfo_choice), eval_time, ad_eval_time),
76+
(
77+
model_name,
78+
string(adbackend),
79+
string(varinfo_choice),
80+
relative_eval_time,
81+
relative_ad_eval_time,
82+
),
6583
)
6684
end
6785

6886
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
6987
header = [
70-
"Model", "AD Backend", "VarInfo Type", "Evaluation Time (ns)", "AD Eval Time (ns)"
88+
"Model",
89+
"AD Backend",
90+
"VarInfo Type",
91+
"Evaluation Time / Reference Time",
92+
"AD Time / Eval Time",
7193
]
72-
PrettyTables.pretty_table(table_matrix; header=header, tf=PrettyTables.tf_markdown)
94+
PrettyTables.pretty_table(
95+
table_matrix;
96+
header=header,
97+
tf=PrettyTables.tf_markdown,
98+
formatters=ft_printf("%.1f", [4, 5]),
99+
)

benchmarks/src/Models.jl

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,28 @@ export simple_assume_observe_non_model,
2424

2525
# This one is like simple_assume_observe, but explicitly does not use DynamicPPL.
2626
# Other runtimes are normalised by this one's runtime.
27-
function simple_assume_observe_non_model(x, obs)
28-
logp = logdf(x, Normal())
29-
logp += logpdf(obs, Normal(x, 1))
30-
return logp
27+
function simple_assume_observe_non_model(obs)
28+
x = rand(Normal())
29+
logp = logpdf(Normal(), x)
30+
logp += logpdf(Normal(x, 1), obs)
31+
return (; logp=logp, x=x)
3132
end
3233

34+
"""
35+
A simple model that does one scalar assumption and one scalar observation.
36+
"""
3337
@model function simple_assume_observe(obs)
3438
x ~ Normal()
3539
obs ~ Normal(x, 1)
3640
return (; x=x)
3741
end
3842

43+
"""
44+
A short model that tries to cover many DynamicPPL features.
45+
46+
Includes scalar, vector univariate, and multivariate variables; ~, .~, and loops; allocating
47+
a variable vector; observations passed as arguments, and as literals.
48+
"""
3949
@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
4050
@assert length(x) == length(y)
4151
m ~ truncated(Normal(); lower=0)
@@ -50,6 +60,13 @@ end
5060
return (; m=m, means=means, stds=stds)
5161
end
5262

63+
"""
64+
A model that loops over two vectors of univariate normals of length `num_dims`.
65+
66+
The second variable, `o`, is meant to be conditioned on after model instantiation.
67+
68+
See `multivariate` for a version that uses `product_distribution` rather than loops.
69+
"""
5370
@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
5471
a = TV(undef, num_dims)
5572
o = TV(undef, num_dims)
@@ -63,6 +80,13 @@ end
6380
return (; a=a)
6481
end
6582

83+
"""
84+
A model with two multivariate normal distributed variables of dimension `num_dims`.
85+
86+
The second variable, `o`, is meant to be conditioned on after model instantiation.
87+
88+
See `loop_univariate` for a version that uses loops rather than `product_distribution`.
89+
"""
6690
@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
6791
a = TV(undef, num_dims)
6892
o = TV(undef, num_dims)
@@ -72,17 +96,29 @@ end
7296
return (; a=a)
7397
end
7498

99+
"""
100+
A submodel for `parent`. Not exported.
101+
"""
75102
@model function sub()
76103
x ~ Normal()
77104
return x
78105
end
79106

80-
@model function parent(y)
107+
"""
108+
Like simple_assume_observe, but with a submodel for the assumed random variable.
109+
"""
110+
@model function parent(obs)
81111
x ~ to_submodel(sub())
82-
y ~ Normal(x, 1)
112+
obs ~ Normal(x, 1)
83113
return (; x=x)
84114
end
85115

116+
"""
117+
A model with random variables that have changing support.
118+
119+
Includes both variables the dimension of which depends on other variables, and variables
120+
the support of which changes under linking.
121+
"""
86122
@model function dynamic(::Type{T}=Vector{Float64}) where {T}
87123
eta ~ truncated(Normal(); lower=0.0)
88124
mat1 ~ LKJCholesky(4, eta)
@@ -93,6 +129,9 @@ end
93129
return (; eta=eta, mat1=mat1, mat2=mat2, vec=vec)
94130
end
95131

132+
"""
133+
A simple Linear Discriminant Analysis model.
134+
"""
96135
@model function lda(K, d, w)
97136
V = length(unique(w))
98137
D = length(unique(d))

0 commit comments

Comments
 (0)