@@ -26,24 +26,30 @@ lda_instance = begin
26
26
end
27
27
28
28
# Specify the combinations to test:
29
- # (Model Name, model instance, VarInfo choice, AD backend)
29
+ # (Model Name, model instance, VarInfo choice, AD backend, linked )
30
30
chosen_combinations = [
31
- (" Simple assume observe" , Models. simple_assume_observe (randn ()), :typed , :forwarddiff ),
32
- (" Smorgasbord" , smorgasbord_instance, :typed , :forwarddiff ),
33
- (" Smorgasbord" , smorgasbord_instance, :simple_namedtuple , :forwarddiff ),
34
- (" Smorgasbord" , smorgasbord_instance, :untyped , :forwarddiff ),
35
- (" Smorgasbord" , smorgasbord_instance, :simple_dict , :forwarddiff ),
36
- (" Smorgasbord" , smorgasbord_instance, :typed , :reversediff ),
31
+ (
32
+ " Simple assume observe" ,
33
+ Models. simple_assume_observe (randn ()),
34
+ :typed ,
35
+ :forwarddiff ,
36
+ false ,
37
+ ),
38
+ (" Smorgasbord" , smorgasbord_instance, :typed , :forwarddiff , false ),
39
+ (" Smorgasbord" , smorgasbord_instance, :simple_namedtuple , :forwarddiff , true ),
40
+ (" Smorgasbord" , smorgasbord_instance, :untyped , :forwarddiff , true ),
41
+ (" Smorgasbord" , smorgasbord_instance, :simple_dict , :forwarddiff , true ),
42
+ (" Smorgasbord" , smorgasbord_instance, :typed , :reversediff , true ),
37
43
# TODO (mhauru) Add Mooncake once TuringBenchmarking.jl supports it. Consider changing
38
44
# all the below :reversediffs to :mooncakes too.
39
- # ("Smorgasbord", smorgasbord_instance, :typed, :mooncake),
40
- (" Loop univariate 1k" , loop_univariate1k, :typed , :reversediff ),
41
- (" Multivariate 1k" , multivariate1k, :typed , :reversediff ),
42
- (" Loop univariate 10k" , loop_univariate10k, :typed , :reversediff ),
43
- (" Multivariate 10k" , multivariate10k, :typed , :reversediff ),
44
- (" Dynamic" , Models. dynamic (), :typed , :reversediff ),
45
- (" Submodel" , Models. parent (randn ()), :typed , :reversediff ),
46
- (" LDA" , lda_instance, :typed , :reversediff ),
45
+ # ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true ),
46
+ (" Loop univariate 1k" , loop_univariate1k, :typed , :reversediff , true ),
47
+ (" Multivariate 1k" , multivariate1k, :typed , :reversediff , true ),
48
+ (" Loop univariate 10k" , loop_univariate10k, :typed , :reversediff , true ),
49
+ (" Multivariate 10k" , multivariate10k, :typed , :reversediff , true ),
50
+ (" Dynamic" , Models. dynamic (), :typed , :reversediff , true ),
51
+ (" Submodel" , Models. parent (randn ()), :typed , :reversediff , true ),
52
+ (" LDA" , lda_instance, :typed , :reversediff , true ),
47
53
]
48
54
49
55
# Time running a model-like function that does not use DynamicPPL, as a reference point.
@@ -53,21 +59,22 @@ reference_time = begin
53
59
median (@benchmark Models. simple_assume_observe_non_model (obs)). time
54
60
end
55
61
56
- results_table = Tuple{String,String,String,Float64,Float64}[]
62
+ results_table = Tuple{String,String,String,Bool, Float64,Float64}[]
57
63
58
- for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
64
+ for (model_name, model, varinfo_choice, adbackend, islinked ) in chosen_combinations
59
65
suite = make_suite (model, varinfo_choice, adbackend)
60
66
results = run (suite)
67
+ result_key = islinked ? " linked" : " standard"
61
68
62
- eval_time = median (results[" evaluation" ][" standard " ]). time
69
+ eval_time = median (results[" evaluation" ][result_key ]). time
63
70
relative_eval_time = eval_time / reference_time
64
71
65
72
grad_group = results[" gradient" ]
66
73
if isempty (grad_group)
67
74
relative_ad_eval_time = NaN
68
75
else
69
76
grad_backend_key = first (keys (grad_group))
70
- ad_eval_time = median (grad_group[grad_backend_key][" standard " ]). time
77
+ ad_eval_time = median (grad_group[grad_backend_key][result_key ]). time
71
78
relative_ad_eval_time = ad_eval_time / eval_time
72
79
end
73
80
@@ -77,6 +84,7 @@ for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
77
84
model_name,
78
85
string (adbackend),
79
86
string (varinfo_choice),
87
+ islinked,
80
88
relative_eval_time,
81
89
relative_ad_eval_time,
82
90
),
@@ -88,12 +96,13 @@ header = [
88
96
" Model" ,
89
97
" AD Backend" ,
90
98
" VarInfo Type" ,
99
+ " Linked" ,
91
100
" Evaluation Time / Reference Time" ,
92
101
" AD Time / Eval Time" ,
93
102
]
94
103
PrettyTables. pretty_table (
95
104
table_matrix;
96
105
header= header,
97
106
tf= PrettyTables. tf_markdown,
98
- formatters= ft_printf (" %.1f" , [4 , 5 ]),
107
+ formatters= ft_printf (" %.1f" , [5 , 6 ]),
99
108
)
0 commit comments