1
- using DynamicPPL
2
- using DynamicPPLBenchmarks
3
- using BenchmarkTools
4
- using TuringBenchmarking
5
- using Distributions
6
- using PrettyTables
1
+ using DynamicPPL: @model
2
+ using DynamicPPLBenchmarks: make_suite
3
+ using BenchmarkTools: median, run
4
+ using Distributions: Normal, Beta, Bernoulli
5
+ using PrettyTables: pretty_table, PrettyTables
7
6
8
7
# Define models
9
8
@model function demo1 (x)
10
9
m ~ Normal ()
11
10
x ~ Normal (m, 1 )
12
- return (m = m, x = x)
11
+ return (m= m, x= x)
13
12
end
14
13
15
14
@model function demo2 (y)
@@ -28,60 +27,39 @@ demo2_data = rand(Bool, 10)
28
27
demo1_instance = demo1 (demo1_data)
29
28
demo2_instance = demo2 (demo2_data)
30
29
31
- # Define available AD backends
32
- available_ad_backends = Dict (
33
- :forwarddiff => :forwarddiff ,
34
- :reversediff => :reversediff ,
35
- :zygote => :zygote
36
- )
37
-
38
- # Define available VarInfo types.
39
- # Each entry is (Name, function to produce the VarInfo)
40
- available_varinfo_types = Dict (
41
- :untyped => (" UntypedVarInfo" , VarInfo),
42
- :typed => (" TypedVarInfo" , m -> VarInfo (m)),
43
- :simple_namedtuple => (" SimpleVarInfo (NamedTuple)" , m -> SimpleVarInfo {Float64} (m ())),
44
- :simple_dict => (" SimpleVarInfo (Dict)" , m -> begin
45
- retvals = m ()
46
- varnames = map (keys (retvals)) do k
47
- VarName {k} ()
48
- end
49
- SimpleVarInfo {Float64} (Dict (zip (varnames, values (retvals))))
50
- end )
51
- )
52
-
53
30
# Specify the combinations to test:
54
31
# (Model Name, model instance, VarInfo choice, AD backend)
55
32
chosen_combinations = [
56
- (" Demo1" , demo1_instance, :typed , :forwarddiff ),
33
+ (" Demo1" , demo1_instance, :typed , :forwarddiff ),
57
34
(" Demo1" , demo1_instance, :simple_namedtuple , :zygote ),
58
- (" Demo2" , demo2_instance, :untyped , :reversediff ),
59
- (" Demo2" , demo2_instance, :simple_dict , :forwarddiff )
35
+ (" Demo2" , demo2_instance, :untyped , :reversediff ),
36
+ (" Demo2" , demo2_instance, :simple_dict , :forwarddiff ),
60
37
]
61
38
62
- # Store results as tuples: (Model, AD Backend, VarInfo Type, Eval Time, AD Eval Time)
63
- results_table = Tuple{String, String, String, Float64, Float64}[]
39
+ results_table = Tuple{String,String,String,Float64,Float64}[]
64
40
65
41
for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
66
42
suite = make_suite (model, varinfo_choice, adbackend)
67
43
results = run (suite)
68
- eval_time = median (results[" evaluation" ]). time
69
- ad_eval_time = median (results[" AD_Benchmarking" ][" evaluation" ][" standard" ]). time
70
- push! (results_table, (model_name, string (adbackend), string (varinfo_choice), eval_time, ad_eval_time))
71
- end
72
44
73
- # Convert results to a 2D array for PrettyTables
74
- function to_matrix (tuples :: Vector{<:NTuple{5,Any}} )
75
- n = length (tuples)
76
- data = Array {Any} (undef, n, 5 )
77
- for i in 1 : n
78
- for j in 1 : 5
79
- data[i, j] = tuples[i][j]
80
- end
45
+ eval_time = median (results[ " AD_Benchmarking " ][ " evaluation " ][ " standard " ]) . time
46
+
47
+ grad_group = results[ " AD_Benchmarking " ][ " gradient " ]
48
+ if isempty (grad_group )
49
+ ad_eval_time = NaN
50
+ else
51
+ grad_backend_key = first ( keys (grad_group))
52
+ ad_eval_time = median (grad_group[grad_backend_key][ " standard " ]) . time
81
53
end
82
- return data
54
+
55
+ push! (
56
+ results_table,
57
+ (model_name, string (adbackend), string (varinfo_choice), eval_time, ad_eval_time),
58
+ )
83
59
end
84
60
85
- table_matrix = to_matrix (results_table)
86
- header = [" Model" , " AD Backend" , " VarInfo Type" , " Evaluation Time (ns)" , " AD Eval Time (ns)" ]
61
+ table_matrix = hcat (Iterators. map (collect, zip (results_table... ))... )
62
+ header = [
63
+ " Model" , " AD Backend" , " VarInfo Type" , " Evaluation Time (ns)" , " AD Eval Time (ns)"
64
+ ]
87
65
pretty_table (table_matrix; header= header, tf= PrettyTables. tf_markdown)
0 commit comments