Commit 3b5e448

updated benchmarking setup
1 parent 6f255d1 commit 3b5e448

17 files changed: +98 / -122,173 lines

benchmarks/Project.toml

Lines changed: 1 addition & 12 deletions
@@ -3,19 +3,8 @@ uuid = "d94a1522-c11e-44a7-981a-42bf5dc1a001"
 version = "0.1.0"
 
 [deps]
-AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
-ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
-DiffUtils = "8294860b-85a6-42f8-8c35-d911f667b5f6"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
-Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
-Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
-Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"
+TuringBenchmarking = "0db1332d-5c25-4deb-809f-459bc696f94f"
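
The dependency list now reduces to BenchmarkTools, Distributions, DynamicPPL, PrettyTables, and TuringBenchmarking. A minimal sketch of using this trimmed environment follows; the paths and the Pkg.develop step for the local DynamicPPL checkout are assumptions here, not something this commit prescribes.

    # Sketch only: from the repository root, activate the benchmark environment
    # defined by the Project.toml above and run the benchmark script.
    using Pkg
    Pkg.activate("benchmarks")
    Pkg.develop(path=".")      # assumption: use the local DynamicPPL checkout
    Pkg.instantiate()
    include(joinpath("benchmarks", "benchmarks.jl"))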

benchmarks/benchmark_body.jmd

Lines changed: 0 additions & 72 deletions
This file was deleted.

benchmarks/benchmarks.jl

Lines changed: 60 additions & 67 deletions
@@ -1,24 +1,11 @@
-using BenchmarkTools
 using DynamicPPL
+using DynamicPPLBenchmarks
+using BenchmarkTools
+using TuringBenchmarking
 using Distributions
-using DynamicPPLBenchmarks: time_model_def, make_suite
 using PrettyTables
-using Dates
-using LibGit2
-
-const RESULTS_DIR = "results"
-const BENCHMARK_NAME = let
-    repo = try
-        LibGit2.GitRepo(joinpath(pkgdir(DynamicPPL), ".."))
-    catch
-        nothing
-    end
-    isnothing(repo) ? "benchmarks_$(Dates.format(now(), "yyyy-mm-dd_HH-MM-SS"))" :
-        "$(LibGit2.headname(repo))_$(string(LibGit2.GitHash(LibGit2.peel(LibGit2.GitCommit, LibGit2.head(repo))))[1:6])"
-end
-
-mkpath(joinpath(RESULTS_DIR, BENCHMARK_NAME))
 
+# Define models
 @model function demo1(x)
     m ~ Normal()
     x ~ Normal(m, 1)
@@ -34,61 +21,67 @@ end
     return (; p)
 end
 
-models = [
-    (name = "demo1", model = demo1, data = (1.0,)),
-    (name = "demo2", model = demo2, data = (rand(0:1, 10),))
-]
+demo1_data = randn()
+demo2_data = rand(Bool, 10)
 
-results = []
-for (model_name, model_def, data) in models
-    println(">> Running benchmarks for model: $model_name")
-    m = time_model_def(model_def, data...)
-    println()
-    suite = make_suite(m)
-    bench_results = run(suite, seconds=10)
-
-    output_path = joinpath(RESULTS_DIR, BENCHMARK_NAME, "$(model_name)_benchmarks.json")
-    BenchmarkTools.save(output_path, bench_results)
-
-    for (eval_type, trial) in bench_results
-        push!(results, (
-            Model = model_name,
-            Evaluation = eval_type,
-            Time = minimum(trial).time,
-            Memory = trial.memory,
-            Allocations = trial.allocs,
-            Samples = length(trial.times)
-        ))
-    end
-end
+# Create model instances with the data
+demo1_instance = demo1(demo1_data)
+demo2_instance = demo2(demo2_data)
 
-formatted = map(results) do r
-    (Model = r.Model,
-     Evaluation = replace(r.Evaluation, "_" => " "),
-     Time = BenchmarkTools.prettytime(r.Time),
-     Memory = BenchmarkTools.prettymemory(r.Memory),
-     Allocations = string(r.Allocations),
-     Samples = string(r.Samples))
-end
+# Define available AD backends
+available_ad_backends = Dict(
+    :forwarddiff => :forwarddiff,
+    :reversediff => :reversediff,
+    :zygote => :zygote
+)
 
-md_output = """
-## DynamicPPL Benchmark Results ($BENCHMARK_NAME)
+# Define available VarInfo types.
+# Each entry is (Name, function to produce the VarInfo)
+available_varinfo_types = Dict(
+    :untyped => ("UntypedVarInfo", VarInfo),
+    :typed => ("TypedVarInfo", m -> VarInfo(m)),
+    :simple_namedtuple => ("SimpleVarInfo (NamedTuple)", m -> SimpleVarInfo{Float64}(m())),
+    :simple_dict => ("SimpleVarInfo (Dict)", m -> begin
+        retvals = m()
+        varnames = map(keys(retvals)) do k
+            VarName{k}()
+        end
+        SimpleVarInfo{Float64}(Dict(zip(varnames, values(retvals))))
+    end)
+)
 
-### Execution Environment
-- Julia version: $(VERSION)
-- DynamicPPL version: $(pkgversion(DynamicPPL))
-- Benchmark date: $(now())
+# Specify the combinations to test:
+# (Model Name, model instance, VarInfo choice, AD backend)
+chosen_combinations = [
+    ("Demo1", demo1_instance, :typed, :forwarddiff),
+    ("Demo1", demo1_instance, :simple_namedtuple, :zygote),
+    ("Demo2", demo2_instance, :untyped, :reversediff),
+    ("Demo2", demo2_instance, :simple_dict, :forwarddiff)
+]
 
-$(pretty_table(String, formatted,
-    tf = tf_markdown,
-    header = ["Model", "Evaluation Type", "Time", "Memory", "Allocs", "Samples"],
-    alignment = [:l, :l, :r, :r, :r, :r]
-))
-"""
+# Store results as tuples: (Model, AD Backend, VarInfo Type, Eval Time, AD Eval Time)
+results_table = Tuple{String, String, String, Float64, Float64}[]
 
-println(md_output)
-open(joinpath(RESULTS_DIR, BENCHMARK_NAME, "REPORT.md"), "w") do io
-    write(io, md_output)
+for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
+    suite = make_suite(model, varinfo_choice, adbackend)
+    results = run(suite)
+    eval_time = median(results["evaluation"]).time
+    ad_eval_time = median(results["AD_Benchmarking"]["evaluation"]["standard"]).time
+    push!(results_table, (model_name, string(adbackend), string(varinfo_choice), eval_time, ad_eval_time))
+end
+
+# Convert results to a 2D array for PrettyTables
+function to_matrix(tuples::Vector{<:NTuple{5,Any}})
+    n = length(tuples)
+    data = Array{Any}(undef, n, 5)
+    for i in 1:n
+        for j in 1:5
+            data[i, j] = tuples[i][j]
+        end
+    end
+    return data
 end
 
-println("Benchmark results saved to: $RESULTS_DIR/$BENCHMARK_NAME")
+table_matrix = to_matrix(results_table)
+header = ["Model", "AD Backend", "VarInfo Type", "Evaluation Time (ns)", "AD Eval Time (ns)"]
+pretty_table(table_matrix; header=header, tf=PrettyTables.tf_markdown)
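
The same call pattern can be reused interactively for a one-off check. The snippet below is a sketch that assumes benchmarks.jl has already been run (so demo2_instance and make_suite are in scope) and simply mirrors the suite keys used in the loop above for one extra combination.

    # Sketch: benchmark one additional (model, VarInfo, AD backend) combination
    # and read off the same two timings collected in results_table.
    suite = make_suite(demo2_instance, :typed, :forwarddiff)
    results = run(suite)
    println("evaluation:    ", median(results["evaluation"]).time, " ns")
    println("AD evaluation: ", median(results["AD_Benchmarking"]["evaluation"]["standard"]).time, " ns")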
