Skip to content

Commit 3cd8d34

Browse files
authored
Improvements to benchmark outputs (#1146)
* print output * fix * reenable * add more lines to guide the eye * reorder table * print tgrad / trel as well * forgot this type
1 parent 8553e40 commit 3cd8d34

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

benchmarks/benchmarks.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,15 @@ function run(; to_json=false)
9898
}[]
9999

100100
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
101-
@info "Running benchmark for $model_name"
101+
@info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked"
102102
relative_eval_time, relative_ad_eval_time = try
103103
results = benchmark(model, varinfo_choice, adbackend, islinked)
104+
@info " t(eval) = $(results.primal_time)"
105+
@info " t(grad) = $(results.grad_time)"
104106
(results.primal_time / reference_time),
105107
(results.grad_time / results.primal_time)
106108
catch e
109+
@info "benchmark errored: $e"
107110
missing, missing
108111
end
109112
push!(
@@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String)
155158
all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases)))
156159
@info "$(length(all_testcases)) unique test cases found"
157160
sorted_testcases = sort(
158-
collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked))
161+
collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend))
159162
)
160163
results_table = Tuple{
161-
String,Int,String,String,Bool,String,String,String,String,String,String
164+
String,
165+
Int,
166+
String,
167+
String,
168+
Bool,
169+
String,
170+
String,
171+
String,
172+
String,
173+
String,
174+
String,
175+
String,
176+
String,
177+
String,
162178
}[]
179+
sublabels = ["base", "this PR", "speedup"]
163180
results_colnames = [
164181
[
165182
EmptyCells(5),
166183
MultiColumn(3, "t(eval) / t(ref)"),
167184
MultiColumn(3, "t(grad) / t(eval)"),
185+
MultiColumn(3, "t(grad) / t(ref)"),
168186
],
169-
[colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"],
187+
[colnames[1:5]..., sublabels..., sublabels..., sublabels...],
170188
]
171189
sprint_float(x::Float64) = @sprintf("%.2f", x)
172190
sprint_float(m::Missing) = "err"
@@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String)
183201
# Finally that lets us do this division safely
184202
speedup_eval = base_eval / head_eval
185203
speedup_grad = base_grad / head_grad
204+
# As well as this multiplication, which is t(grad) / t(ref)
205+
head_grad_vs_ref = head_grad * head_eval
206+
base_grad_vs_ref = base_grad * base_eval
207+
speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref
186208
push!(
187209
results_table,
188210
(
@@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String)
197219
sprint_float(base_grad),
198220
sprint_float(head_grad),
199221
sprint_float(speedup_grad),
222+
sprint_float(base_grad_vs_ref),
223+
sprint_float(head_grad_vs_ref),
224+
sprint_float(speedup_grad_vs_ref),
200225
),
201226
)
202227
end
@@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String)
212237
backend=:text,
213238
fit_table_in_display_horizontally=false,
214239
fit_table_in_display_vertically=false,
215-
table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true),
240+
table_format=TextTableFormat(;
241+
horizontal_line_at_merged_column_labels=true,
242+
horizontal_lines_at_data_rows=collect(3:3:length(results_table)),
243+
),
216244
)
217245
println("```")
218246
end

src/test_utils/ad.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ using Chairmarks: @be
55
import DifferentiationInterface as DI
66
using DocStringExtensions
77
using DynamicPPL:
8-
Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link
8+
DynamicPPL,
9+
Model,
10+
LogDensityFunction,
11+
VarInfo,
12+
AbstractVarInfo,
13+
getlogjoint_internal,
14+
link
915
using LogDensityProblems: logdensity, logdensity_and_gradient
1016
using Random: AbstractRNG, default_rng
1117
using Statistics: median
@@ -298,7 +304,9 @@ function run_ad(
298304

299305
# Benchmark
300306
grad_time, primal_time = if benchmark
307+
logdensity(ldf, params) # Warm-up
301308
primal_benchmark = @be logdensity($ldf, $params)
309+
logdensity_and_gradient(ldf, params) # Warm-up
302310
grad_benchmark = @be logdensity_and_gradient($ldf, $params)
303311
median_primal = median(primal_benchmark).time
304312
median_grad = median(grad_benchmark).time

0 commit comments

Comments
 (0)