|
1 |
| -using DynamicPPL: @model |
2 |
| -using DynamicPPLBenchmarks: make_suite |
3 |
| -using BenchmarkTools: median, run |
4 |
| -using Distributions: Normal, Beta, Bernoulli |
5 |
| -using PrettyTables: pretty_table, PrettyTables |
| 1 | +using DynamicPPLBenchmarks: Models, make_suite |
| 2 | +using BenchmarkTools: @benchmark, median, run |
| 3 | +using PrettyTables: PrettyTables, ft_printf |
| 4 | +using Random: seed! |
6 | 5 |
|
7 |
| -# Define models |
8 |
| -@model function demo1(x) |
9 |
| - m ~ Normal() |
10 |
| - x ~ Normal(m, 1) |
11 |
| - return (m=m, x=x) |
12 |
| -end |
| 6 | +seed!(23) |
13 | 7 |
|
14 |
| -@model function demo2(y) |
15 |
| - p ~ Beta(1, 1) |
16 |
| - N = length(y) |
17 |
| - for n in 1:N |
18 |
| - y[n] ~ Bernoulli(p) |
19 |
| - end |
20 |
| - return (; p) |
| 8 | +# Create DynamicPPL.Model instances to run benchmarks on. |
| 9 | +smorgasbord_instance = Models.smorgasbord(randn(100), randn(100)) |
| 10 | +loop_univariate1k, multivariate1k = begin |
| 11 | + data_1k = randn(1_000) |
| 12 | + loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k) |
| 13 | + multi = Models.multivariate(length(data_1k)) | (; o=data_1k) |
| 14 | + loop, multi |
| 15 | +end |
| 16 | +loop_univariate10k, multivariate10k = begin |
| 17 | + data_10k = randn(10_000) |
| 18 | + loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k) |
| 19 | + multi = Models.multivariate(length(data_10k)) | (; o=data_10k) |
| 20 | + loop, multi |
| 21 | +end |
| 22 | +lda_instance = begin |
| 23 | + w = [1, 2, 3, 2, 1, 1] |
| 24 | + d = [1, 1, 1, 2, 2, 2] |
| 25 | + Models.lda(2, d, w) |
21 | 26 | end
|
22 |
| - |
23 |
| -demo1_data = randn() |
24 |
| -demo2_data = rand(Bool, 10) |
25 |
| - |
26 |
| -# Create model instances with the data |
27 |
| -demo1_instance = demo1(demo1_data) |
28 |
| -demo2_instance = demo2(demo2_data) |
29 | 27 |
|
30 | 28 | # Specify the combinations to test:
|
31 |
| -# (Model Name, model instance, VarInfo choice, AD backend) |
| 29 | +# (Model Name, model instance, VarInfo choice, AD backend, linked) |
32 | 30 | chosen_combinations = [
|
33 |
| - ("Demo1", demo1_instance, :typed, :forwarddiff), |
34 |
| - ("Demo1", demo1_instance, :simple_namedtuple, :zygote), |
35 |
| - ("Demo2", demo2_instance, :untyped, :reversediff), |
36 |
| - ("Demo2", demo2_instance, :simple_dict, :forwarddiff), |
| 31 | + ( |
| 32 | + "Simple assume observe", |
| 33 | + Models.simple_assume_observe(randn()), |
| 34 | + :typed, |
| 35 | + :forwarddiff, |
| 36 | + false, |
| 37 | + ), |
| 38 | + ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), |
| 39 | + ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), |
| 40 | + ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), |
| 41 | + ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), |
| 42 | + ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), |
| 43 | + # TODO(mhauru) Add Mooncake once TuringBenchmarking.jl supports it. Consider changing |
| 44 | + # all the below :reversediffs to :mooncakes too. |
| 45 | + #("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), |
| 46 | + ("Loop univariate 1k", loop_univariate1k, :typed, :reversediff, true), |
| 47 | + ("Multivariate 1k", multivariate1k, :typed, :reversediff, true), |
| 48 | + ("Loop univariate 10k", loop_univariate10k, :typed, :reversediff, true), |
| 49 | + ("Multivariate 10k", multivariate10k, :typed, :reversediff, true), |
| 50 | + ("Dynamic", Models.dynamic(), :typed, :reversediff, true), |
| 51 | + ("Submodel", Models.parent(randn()), :typed, :reversediff, true), |
| 52 | + ("LDA", lda_instance, :typed, :reversediff, true), |
37 | 53 | ]
|
38 | 54 |
|
39 |
| -results_table = Tuple{String,String,String,Float64,Float64}[] |
| 55 | +# Time running a model-like function that does not use DynamicPPL, as a reference point. |
| 56 | +# Eval timings will be relative to this. |
| 57 | +reference_time = begin |
| 58 | + obs = randn() |
| 59 | + median(@benchmark Models.simple_assume_observe_non_model(obs)).time |
| 60 | +end |
| 61 | + |
| 62 | +results_table = Tuple{String,String,String,Bool,Float64,Float64}[] |
40 | 63 |
|
41 |
| -for (model_name, model, varinfo_choice, adbackend) in chosen_combinations |
| 64 | +for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations |
42 | 65 | suite = make_suite(model, varinfo_choice, adbackend)
|
43 | 66 | results = run(suite)
|
| 67 | + result_key = islinked ? "linked" : "standard" |
44 | 68 |
|
45 |
| - eval_time = median(results["AD_Benchmarking"]["evaluation"]["standard"]).time |
| 69 | + eval_time = median(results["evaluation"][result_key]).time |
| 70 | + relative_eval_time = eval_time / reference_time |
46 | 71 |
|
47 |
| - grad_group = results["AD_Benchmarking"]["gradient"] |
| 72 | + grad_group = results["gradient"] |
48 | 73 | if isempty(grad_group)
|
49 |
| - ad_eval_time = NaN |
| 74 | + relative_ad_eval_time = NaN |
50 | 75 | else
|
51 | 76 | grad_backend_key = first(keys(grad_group))
|
52 |
| - ad_eval_time = median(grad_group[grad_backend_key]["standard"]).time |
| 77 | + ad_eval_time = median(grad_group[grad_backend_key][result_key]).time |
| 78 | + relative_ad_eval_time = ad_eval_time / eval_time |
53 | 79 | end
|
54 | 80 |
|
55 | 81 | push!(
|
56 | 82 | results_table,
|
57 |
| - (model_name, string(adbackend), string(varinfo_choice), eval_time, ad_eval_time), |
| 83 | + ( |
| 84 | + model_name, |
| 85 | + string(adbackend), |
| 86 | + string(varinfo_choice), |
| 87 | + islinked, |
| 88 | + relative_eval_time, |
| 89 | + relative_ad_eval_time, |
| 90 | + ), |
58 | 91 | )
|
59 | 92 | end
|
60 | 93 |
|
61 | 94 | table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
|
62 | 95 | header = [
|
63 |
| - "Model", "AD Backend", "VarInfo Type", "Evaluation Time (ns)", "AD Eval Time (ns)" |
| 96 | + "Model", |
| 97 | + "AD Backend", |
| 98 | + "VarInfo Type", |
| 99 | + "Linked", |
| 100 | + "Eval Time / Ref Time", |
| 101 | + "AD Time / Eval Time", |
64 | 102 | ]
|
65 |
| -pretty_table(table_matrix; header=header, tf=PrettyTables.tf_markdown) |
| 103 | +PrettyTables.pretty_table( |
| 104 | + table_matrix; |
| 105 | + header=header, |
| 106 | + tf=PrettyTables.tf_markdown, |
| 107 | + formatters=ft_printf("%.1f", [5, 6]), |
| 108 | +) |
0 commit comments