Skip to content

Commit 640aa45

Browse files
applied suggested changes
1 parent 1e61025 commit 640aa45

File tree

2 files changed

+48
-70
lines changed

2 files changed

+48
-70
lines changed

benchmarks/benchmarks.jl

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
using DynamicPPL
2-
using DynamicPPLBenchmarks
3-
using BenchmarkTools
4-
using TuringBenchmarking
5-
using Distributions
6-
using PrettyTables
1+
using DynamicPPL: @model
2+
using DynamicPPLBenchmarks: make_suite
3+
using BenchmarkTools: median, run
4+
using Distributions: Normal, Beta, Bernoulli
5+
using PrettyTables: pretty_table, PrettyTables
76

87
# Define models
98
@model function demo1(x)
109
m ~ Normal()
1110
x ~ Normal(m, 1)
12-
return (m = m, x = x)
11+
return (m=m, x=x)
1312
end
1413

1514
@model function demo2(y)
@@ -28,60 +27,39 @@ demo2_data = rand(Bool, 10)
2827
demo1_instance = demo1(demo1_data)
2928
demo2_instance = demo2(demo2_data)
3029

31-
# Define available AD backends
32-
available_ad_backends = Dict(
33-
:forwarddiff => :forwarddiff,
34-
:reversediff => :reversediff,
35-
:zygote => :zygote
36-
)
37-
38-
# Define available VarInfo types.
39-
# Each entry is (Name, function to produce the VarInfo)
40-
available_varinfo_types = Dict(
41-
:untyped => ("UntypedVarInfo", VarInfo),
42-
:typed => ("TypedVarInfo", m -> VarInfo(m)),
43-
:simple_namedtuple => ("SimpleVarInfo (NamedTuple)", m -> SimpleVarInfo{Float64}(m())),
44-
:simple_dict => ("SimpleVarInfo (Dict)", m -> begin
45-
retvals = m()
46-
varnames = map(keys(retvals)) do k
47-
VarName{k}()
48-
end
49-
SimpleVarInfo{Float64}(Dict(zip(varnames, values(retvals))))
50-
end)
51-
)
52-
5330
# Specify the combinations to test:
5431
# (Model Name, model instance, VarInfo choice, AD backend)
5532
chosen_combinations = [
56-
("Demo1", demo1_instance, :typed, :forwarddiff),
33+
("Demo1", demo1_instance, :typed, :forwarddiff),
5734
("Demo1", demo1_instance, :simple_namedtuple, :zygote),
58-
("Demo2", demo2_instance, :untyped, :reversediff),
59-
("Demo2", demo2_instance, :simple_dict, :forwarddiff)
35+
("Demo2", demo2_instance, :untyped, :reversediff),
36+
("Demo2", demo2_instance, :simple_dict, :forwarddiff),
6037
]
6138

62-
# Store results as tuples: (Model, AD Backend, VarInfo Type, Eval Time, AD Eval Time)
63-
results_table = Tuple{String, String, String, Float64, Float64}[]
39+
results_table = Tuple{String,String,String,Float64,Float64}[]
6440

6541
for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
6642
suite = make_suite(model, varinfo_choice, adbackend)
6743
results = run(suite)
68-
eval_time = median(results["evaluation"]).time
69-
ad_eval_time = median(results["AD_Benchmarking"]["evaluation"]["standard"]).time
70-
push!(results_table, (model_name, string(adbackend), string(varinfo_choice), eval_time, ad_eval_time))
71-
end
7244

73-
# Convert results to a 2D array for PrettyTables
74-
function to_matrix(tuples::Vector{<:NTuple{5,Any}})
75-
n = length(tuples)
76-
data = Array{Any}(undef, n, 5)
77-
for i in 1:n
78-
for j in 1:5
79-
data[i, j] = tuples[i][j]
80-
end
45+
eval_time = median(results["AD_Benchmarking"]["evaluation"]["standard"]).time
46+
47+
grad_group = results["AD_Benchmarking"]["gradient"]
48+
if isempty(grad_group)
49+
ad_eval_time = NaN
50+
else
51+
grad_backend_key = first(keys(grad_group))
52+
ad_eval_time = median(grad_group[grad_backend_key]["standard"]).time
8153
end
82-
return data
54+
55+
push!(
56+
results_table,
57+
(model_name, string(adbackend), string(varinfo_choice), eval_time, ad_eval_time),
58+
)
8359
end
8460

85-
table_matrix = to_matrix(results_table)
86-
header = ["Model", "AD Backend", "VarInfo Type", "Evaluation Time (ns)", "AD Eval Time (ns)"]
61+
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
62+
header = [
63+
"Model", "AD Backend", "VarInfo Type", "Evaluation Time (ns)", "AD Eval Time (ns)"
64+
]
8765
pretty_table(table_matrix; header=header, tf=PrettyTables.tf_markdown)

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DynamicPPLBenchmarks
22

3-
using DynamicPPL
4-
using BenchmarkTools
3+
using DynamicPPL: VarInfo, SimpleVarInfo, VarName
4+
using BenchmarkTools: BenchmarkGroup
55
using TuringBenchmarking: make_turing_suite
66

77
export make_suite
@@ -13,40 +13,40 @@ Create a benchmark suite for `model` using the selected varinfo type and AD back
1313
Available varinfo choices:
1414
• `:untyped` → uses `VarInfo()`
1515
• `:typed` → uses `VarInfo(model)`
16-
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
16+
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(free_nt)`
1717
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
1818
1919
The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
2020
"""
2121
function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol)
2222
suite = BenchmarkGroup()
23-
context = DefaultContext()
24-
25-
# Create the chosen varinfo.
26-
vi = nothing
27-
if varinfo_choice == :untyped
28-
vi = VarInfo()
29-
model(vi)
23+
24+
vi = if varinfo_choice == :untyped
25+
v = VarInfo()
26+
model(v)
27+
v
3028
elseif varinfo_choice == :typed
31-
vi = VarInfo(model)
29+
VarInfo(model)
3230
elseif varinfo_choice == :simple_namedtuple
33-
vi = SimpleVarInfo{Float64}(model())
31+
free_nt = NamedTuple{(:m,)}(model()) # Extract only the free parameter(s)
32+
SimpleVarInfo{Float64}(free_nt)
3433
elseif varinfo_choice == :simple_dict
3534
retvals = model()
36-
vns = map(keys(retvals)) do k
37-
VarName{k}()
38-
end
39-
vi = SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
35+
vns = [VarName{k}() for k in keys(retvals)]
36+
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
4037
else
4138
error("Unknown varinfo choice: $varinfo_choice")
4239
end
4340

44-
# Add the evaluation benchmark.
45-
suite["evaluation"] = @benchmarkable $model($vi, $context)
46-
4741
# Add the AD benchmarking suite.
48-
suite["AD_Benchmarking"] = make_turing_suite(model; adbackends=[adbackend])
49-
42+
suite["AD_Benchmarking"] = make_turing_suite(
43+
model;
44+
adbackends=[adbackend],
45+
varinfo=vi,
46+
check_grads=true,
47+
error_on_failed_backend=true,
48+
)
49+
5050
return suite
5151
end
5252

0 commit comments

Comments
 (0)