Skip to content

Commit 923105e

Browse files
committed
Add model dimension to benchmark table
1 parent 5c35238 commit 923105e

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

benchmarks/benchmarks.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Pkg
22
# To ensure we benchmark the local version of DynamicPPL, dev the folder above.
33
Pkg.develop(; path=joinpath(@__DIR__, ".."))
44

5-
using DynamicPPLBenchmarks: Models, make_suite
5+
using DynamicPPLBenchmarks: Models, make_suite, model_dimension
66
using BenchmarkTools: @benchmark, median, run
77
using PrettyTables: PrettyTables, ft_printf
88
using Random: seed!
@@ -63,7 +63,7 @@ reference_time = begin
6363
median(@benchmark Models.simple_assume_observe_non_model(obs)).time
6464
end
6565

66-
results_table = Tuple{String,String,String,Bool,Float64,Float64}[]
66+
results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
6767

6868
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
6969
suite = make_suite(model, varinfo_choice, adbackend, islinked)
@@ -76,6 +76,7 @@ for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinati
7676
results_table,
7777
(
7878
model_name,
79+
model_dimension(model, islinked),
7980
string(adbackend),
8081
string(varinfo_choice),
8182
islinked,
@@ -88,6 +89,7 @@ end
8889
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
8990
header = [
9091
"Model",
92+
"Dimension",
9193
"AD Backend",
9294
"VarInfo Type",
9395
"Linked",

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,21 @@ using ReverseDiff: ReverseDiff
1313
include("./Models.jl")
1414
using .Models: Models
1515

16-
export Models, make_suite
16+
export Models, make_suite, model_dimension
17+
18+
"""
19+
model_dimension(model, islinked)
20+
21+
Return the dimension of `model`, accounting for linking, if any.
22+
"""
23+
function model_dimension(model, islinked)
24+
vi = VarInfo()
25+
model(vi)
26+
if islinked
27+
vi = DynamicPPL.link(vi, model)
28+
end
29+
return length(vi[:])
30+
end
1731

1832
# Utility functions for representing AD backends using symbols.
1933
# Copied from TuringBenchmarking.jl.

0 commit comments

Comments
 (0)