Skip to content

Commit 4a02088

Browse files
committed
Choose whether to show linked or unlinked benchmark times
1 parent ee9a81f commit 4a02088

File tree

3 files changed

+38
-32
lines changed

3 files changed

+38
-32
lines changed

benchmarks/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ version = "0.1.0"
66
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
9+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1011
TuringBenchmarking = "0db1332d-5c25-4deb-809f-459bc696f94f"

benchmarks/benchmarks.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,30 @@ lda_instance = begin
2626
end
2727

2828
# Specify the combinations to test:
29-
# (Model Name, model instance, VarInfo choice, AD backend)
29+
# (Model Name, model instance, VarInfo choice, AD backend, linked)
3030
chosen_combinations = [
31-
("Simple assume observe", Models.simple_assume_observe(randn()), :typed, :forwarddiff),
32-
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff),
33-
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff),
34-
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff),
35-
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff),
36-
("Smorgasbord", smorgasbord_instance, :typed, :reversediff),
31+
(
32+
"Simple assume observe",
33+
Models.simple_assume_observe(randn()),
34+
:typed,
35+
:forwarddiff,
36+
false,
37+
),
38+
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
39+
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
40+
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
41+
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
42+
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
3743
# TODO(mhauru) Add Mooncake once TuringBenchmarking.jl supports it. Consider changing
3844
# all the below :reversediffs to :mooncakes too.
39-
#("Smorgasbord", smorgasbord_instance, :typed, :mooncake),
40-
("Loop univariate 1k", loop_univariate1k, :typed, :reversediff),
41-
("Multivariate 1k", multivariate1k, :typed, :reversediff),
42-
("Loop univariate 10k", loop_univariate10k, :typed, :reversediff),
43-
("Multivariate 10k", multivariate10k, :typed, :reversediff),
44-
("Dynamic", Models.dynamic(), :typed, :reversediff),
45-
("Submodel", Models.parent(randn()), :typed, :reversediff),
46-
("LDA", lda_instance, :typed, :reversediff),
45+
#("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
46+
("Loop univariate 1k", loop_univariate1k, :typed, :reversediff, true),
47+
("Multivariate 1k", multivariate1k, :typed, :reversediff, true),
48+
("Loop univariate 10k", loop_univariate10k, :typed, :reversediff, true),
49+
("Multivariate 10k", multivariate10k, :typed, :reversediff, true),
50+
("Dynamic", Models.dynamic(), :typed, :reversediff, true),
51+
("Submodel", Models.parent(randn()), :typed, :reversediff, true),
52+
("LDA", lda_instance, :typed, :reversediff, true),
4753
]
4854

4955
# Time running a model-like function that does not use DynamicPPL, as a reference point.
@@ -53,21 +59,22 @@ reference_time = begin
5359
median(@benchmark Models.simple_assume_observe_non_model(obs)).time
5460
end
5561

56-
results_table = Tuple{String,String,String,Float64,Float64}[]
62+
results_table = Tuple{String,String,String,Bool,Float64,Float64}[]
5763

58-
for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
64+
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
5965
suite = make_suite(model, varinfo_choice, adbackend)
6066
results = run(suite)
67+
result_key = islinked ? "linked" : "standard"
6168

62-
eval_time = median(results["evaluation"]["standard"]).time
69+
eval_time = median(results["evaluation"][result_key]).time
6370
relative_eval_time = eval_time / reference_time
6471

6572
grad_group = results["gradient"]
6673
if isempty(grad_group)
6774
relative_ad_eval_time = NaN
6875
else
6976
grad_backend_key = first(keys(grad_group))
70-
ad_eval_time = median(grad_group[grad_backend_key]["standard"]).time
77+
ad_eval_time = median(grad_group[grad_backend_key][result_key]).time
7178
relative_ad_eval_time = ad_eval_time / eval_time
7279
end
7380

@@ -77,6 +84,7 @@ for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
7784
model_name,
7885
string(adbackend),
7986
string(varinfo_choice),
87+
islinked,
8088
relative_eval_time,
8189
relative_ad_eval_time,
8290
),
@@ -88,12 +96,13 @@ header = [
8896
"Model",
8997
"AD Backend",
9098
"VarInfo Type",
99+
"Linked",
91100
"Evaluation Time / Reference Time",
92101
"AD Time / Eval Time",
93102
]
94103
PrettyTables.pretty_table(
95104
table_matrix;
96105
header=header,
97106
tf=PrettyTables.tf_markdown,
98-
formatters=ft_printf("%.1f", [4, 5]),
107+
formatters=ft_printf("%.1f", [5, 6]),
99108
)

benchmarks/src/Models.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@ observed (this is used for constructing SimpleVarInfos).
66
"""
77
module Models
88

9-
using DynamicPPL: @model, to_submodel
109
using Distributions:
1110
Categorical,
1211
Dirichlet,
1312
Exponential,
1413
Gamma,
1514
LKJCholesky,
16-
MatrixBeta,
15+
InverseWishart,
1716
Normal,
1817
logpdf,
1918
product_distribution,
2019
truncated
20+
using DynamicPPL: @model, to_submodel
21+
using LinearAlgebra: cholesky
2122

2223
export simple_assume_observe_non_model,
2324
simple_assume_observe, smorgasbord, loop_univariate, multivariate, parent, dynamic, lda
@@ -114,19 +115,14 @@ Like simple_assume_observe, but with a submodel for the assumed random variable.
114115
end
115116

116117
"""
117-
A model with random variables that have changing support.
118-
119-
Includes both variables the dimension of which depends on other variables, and variables
120-
the support of which changes under linking.
118+
A model with random variables that have changing support under linking, or otherwise
119+
complicated bijectors.
121120
"""
122121
@model function dynamic(::Type{T}=Vector{Float64}) where {T}
123-
eta ~ truncated(Normal(); lower=0.0)
122+
eta ~ truncated(Normal(); lower=0.0, upper=0.1)
124123
mat1 ~ LKJCholesky(4, eta)
125-
mat2 ~ MatrixBeta(5, 6.0, 8.0)
126-
dim = eta > 0.2 ? 2 : 3
127-
vec = T(undef, dim)
128-
vec .~ truncated(Exponential(0.5); lower=0.0, upper=1.0)
129-
return (; eta=eta, mat1=mat1, mat2=mat2, vec=vec)
124+
mat2 ~ InverseWishart(3.2, cholesky([1.0 0.5; 0.5 1.0]))
125+
return (; eta=eta, mat1=mat1, mat2=mat2)
130126
end
131127

132128
"""

0 commit comments

Comments
 (0)