Skip to content

Commit 0f7c924

Browse files
committed
Update models to benchmark plus small style changes
1 parent 1d1b11e commit 0f7c924

File tree

3 files changed

+169
-39
lines changed

3 files changed

+169
-39
lines changed

benchmarks/benchmarks.jl

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,46 @@
1-
using DynamicPPL: @model
2-
using DynamicPPLBenchmarks: make_suite
1+
using DynamicPPLBenchmarks: Models, make_suite
32
using BenchmarkTools: median, run
4-
using Distributions: Normal, Beta, Bernoulli
5-
using PrettyTables: pretty_table, PrettyTables
3+
using PrettyTables: PrettyTables
4+
using Random: seed!
65

7-
# Define models
8-
@model function demo1(x)
9-
m ~ Normal()
10-
x ~ Normal(m, 1)
11-
return (m=m, x=x)
12-
end
6+
seed!(23)
137

14-
@model function demo2(y)
15-
p ~ Beta(1, 1)
16-
N = length(y)
17-
for n in 1:N
18-
y[n] ~ Bernoulli(p)
19-
end
20-
return (; p)
8+
smorgasbord_instance = Models.smorgasbord(randn(100), randn(100))
9+
loop_univariate1k, multivariate1k = begin
10+
data_1k = randn(1_000)
11+
loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k)
12+
multi = Models.multivariate(length(data_1k)) | (; o=data_1k)
13+
loop, multi
14+
end
15+
loop_univariate10k, multivariate10k = begin
16+
data_10k = randn(10_000)
17+
loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k)
18+
multi = Models.multivariate(length(data_10k)) | (; o=data_10k)
19+
loop, multi
20+
end
21+
lda_instance = begin
22+
w = [1, 2, 3, 2, 1, 1]
23+
d = [1, 1, 1, 2, 2, 2]
24+
Models.lda(2, d, w)
2125
end
22-
23-
demo1_data = randn()
24-
demo2_data = rand(Bool, 10)
25-
26-
# Create model instances with the data
27-
demo1_instance = demo1(demo1_data)
28-
demo2_instance = demo2(demo2_data)
2926

3027
# Specify the combinations to test:
3128
# (Model Name, model instance, VarInfo choice, AD backend)
3229
chosen_combinations = [
33-
("Demo1", demo1_instance, :typed, :forwarddiff),
34-
("Demo1", demo1_instance, :simple_namedtuple, :zygote),
35-
("Demo2", demo2_instance, :untyped, :reversediff),
36-
("Demo2", demo2_instance, :simple_dict, :forwarddiff),
30+
("Simple assume observe", Models.simple_assume_observe(randn()), :typed, :forwarddiff),
31+
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff),
32+
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff),
33+
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff),
34+
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff),
35+
("Smorgasbord", smorgasbord_instance, :typed, :reversediff),
36+
#("Smorgasbord", smorgasbord_instance, :typed, :mooncake),
37+
("Loop univariate 1k", loop_univariate1k, :typed, :reversediff),
38+
("Multivariate 1k", multivariate1k, :typed, :reversediff),
39+
("Loop univariate 10k", loop_univariate10k, :typed, :reversediff),
40+
("Multivariate 10k", multivariate10k, :typed, :reversediff),
41+
("Dynamic", Models.dynamic(), :typed, :reversediff),
42+
("Submodel", Models.parent(randn()), :typed, :reversediff),
43+
("LDA", lda_instance, :typed, :reversediff),
3744
]
3845

3946
results_table = Tuple{String,String,String,Float64,Float64}[]
@@ -42,9 +49,9 @@ for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
4249
suite = make_suite(model, varinfo_choice, adbackend)
4350
results = run(suite)
4451

45-
eval_time = median(results["AD_Benchmarking"]["evaluation"]["standard"]).time
52+
eval_time = median(results["evaluation"]["standard"]).time
4653

47-
grad_group = results["AD_Benchmarking"]["gradient"]
54+
grad_group = results["gradient"]
4855
if isempty(grad_group)
4956
ad_eval_time = NaN
5057
else
@@ -62,4 +69,4 @@ table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
6269
header = [
6370
"Model", "AD Backend", "VarInfo Type", "Evaluation Time (ns)", "AD Eval Time (ns)"
6471
]
65-
pretty_table(table_matrix; header=header, tf=PrettyTables.tf_markdown)
72+
PrettyTables.pretty_table(table_matrix; header=header, tf=PrettyTables.tf_markdown)

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ using DynamicPPL: VarInfo, SimpleVarInfo, VarName
44
using BenchmarkTools: BenchmarkGroup
55
using TuringBenchmarking: make_turing_suite
66

7-
export make_suite
7+
include("./Models.jl")
8+
using .Models: Models
9+
10+
export Models, make_suite
811

912
"""
1013
make_suite(model, varinfo_choice::Symbol, adbackend::Symbol)
@@ -13,7 +16,7 @@ Create a benchmark suite for `model` using the selected varinfo type and AD back
1316
Available varinfo choices:
1417
• `:untyped` → uses `VarInfo()`
1518
• `:typed` → uses `VarInfo(model)`
16-
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(free_nt)`
19+
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
1720
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
1821
1922
The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
@@ -22,14 +25,13 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol)
2225
suite = BenchmarkGroup()
2326

2427
vi = if varinfo_choice == :untyped
25-
v = VarInfo()
26-
model(v)
27-
v
28+
vi = VarInfo()
29+
model(vi)
30+
vi
2831
elseif varinfo_choice == :typed
2932
VarInfo(model)
3033
elseif varinfo_choice == :simple_namedtuple
31-
free_nt = NamedTuple{(:m,)}(model()) # Extract only the free parameter(s)
32-
SimpleVarInfo{Float64}(free_nt)
34+
SimpleVarInfo{Float64}(model())
3335
elseif varinfo_choice == :simple_dict
3436
retvals = model()
3537
vns = [VarName{k}() for k in keys(retvals)]
@@ -39,7 +41,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol)
3941
end
4042

4143
# Add the AD benchmarking suite.
42-
suite["AD_Benchmarking"] = make_turing_suite(
44+
suite = make_turing_suite(
4345
model;
4446
adbackends=[adbackend],
4547
varinfo=vi,

benchmarks/src/Models.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Models for benchmarking Turing.jl.
3+
4+
Each model returns a NamedTuple of all the random variables in the model that are not
5+
observed (this is used for constructing SimpleVarInfos).
6+
"""
7+
module Models
8+
9+
using DynamicPPL: @model, to_submodel
10+
using Distributions:
11+
Categorical,
12+
Dirichlet,
13+
Exponential,
14+
Gamma,
15+
LKJCholesky,
16+
MatrixBeta,
17+
Normal,
18+
logpdf,
19+
product_distribution,
20+
truncated
21+
22+
export simple_assume_observe_non_model,
23+
simple_assume_observe, smorgasbord, loop_univariate, multivariate, parent, dynamic, lda
24+
25+
# This one is like simple_assume_observe, but explicitly does not use DynamicPPL.
26+
# Other runtimes are normalised by this one's runtime.
27+
function simple_assume_observe_non_model(x, obs)
28+
logp = logdf(x, Normal())
29+
logp += logpdf(obs, Normal(x, 1))
30+
return logp
31+
end
32+
33+
@model function simple_assume_observe(obs)
34+
x ~ Normal()
35+
obs ~ Normal(x, 1)
36+
return (; x=x)
37+
end
38+
39+
@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
40+
@assert length(x) == length(y)
41+
m ~ truncated(Normal(); lower=0)
42+
means ~ product_distribution(fill(Exponential(m), length(x)))
43+
stds = TV(undef, length(x))
44+
stds .~ Gamma(1, 1)
45+
for i in 1:length(x)
46+
x[i] ~ Normal(means[i], stds[i])
47+
end
48+
y ~ product_distribution([Normal(means[i], stds[i]) for i in 1:length(x)])
49+
0.0 ~ Normal(sum(y), 1)
50+
return (; m=m, means=means, stds=stds)
51+
end
52+
53+
@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
54+
a = TV(undef, num_dims)
55+
o = TV(undef, num_dims)
56+
for i in 1:num_dims
57+
a[i] ~ Normal(0, 1)
58+
end
59+
m = sum(a)
60+
for i in 1:num_dims
61+
o[i] ~ Normal(m, 1)
62+
end
63+
return (; a=a)
64+
end
65+
66+
@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
67+
a = TV(undef, num_dims)
68+
o = TV(undef, num_dims)
69+
a ~ product_distribution(fill(Normal(0, 1), num_dims))
70+
m = sum(a)
71+
o ~ product_distribution(fill(Normal(m, 1), num_dims))
72+
return (; a=a)
73+
end
74+
75+
@model function sub()
76+
x ~ Normal()
77+
return x
78+
end
79+
80+
@model function parent(y)
81+
x ~ to_submodel(sub())
82+
y ~ Normal(x, 1)
83+
return (; x=x)
84+
end
85+
86+
@model function dynamic(::Type{T}=Vector{Float64}) where {T}
87+
eta ~ truncated(Normal(); lower=0.0)
88+
mat1 ~ LKJCholesky(4, eta)
89+
mat2 ~ MatrixBeta(5, 6.0, 8.0)
90+
dim = eta > 0.2 ? 2 : 3
91+
vec = T(undef, dim)
92+
vec .~ truncated(Exponential(0.5); lower=0.0, upper=1.0)
93+
return (; eta=eta, mat1=mat1, mat2=mat2, vec=vec)
94+
end
95+
96+
@model function lda(K, d, w)
97+
V = length(unique(w))
98+
D = length(unique(d))
99+
N = length(d)
100+
@assert length(w) == N
101+
102+
ϕ = Vector{Vector{Real}}(undef, K)
103+
for i in 1:K
104+
ϕ[i] ~ Dirichlet(ones(V) / V)
105+
end
106+
107+
θ = Vector{Vector{Real}}(undef, D)
108+
for i in 1:D
109+
θ[i] ~ Dirichlet(ones(K) / K)
110+
end
111+
112+
z = zeros(Int, N)
113+
114+
for i in 1:N
115+
z[i] ~ Categorical(θ[d[i]])
116+
w[i] ~ Categorical(ϕ[d[i]])
117+
end
118+
return (; ϕ=ϕ, θ=θ, z=z)
119+
end
120+
121+
end

0 commit comments

Comments
 (0)