Skip to content

Commit f4e1ccb

Browse files
committed
Add benchmarks
1 parent 88ca4ac commit f4e1ccb

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

benchmark/model_creation.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ function benchmark_model_creation()
99

1010
SUITE["model creation"]["state space (length)"] = benchmark_state_space_model()
1111
SUITE["model creation"]["hierarchical (depth)"] = benchmark_hierarchical_model()
12+
SUITE["model creation"]["recursive (depth)"] = benchmark_recursive_model()
13+
SUITE["model creation"]["neural net (hidden dim size)"] = benchmark_neural_net_model()
1214

1315
return SUITE
1416
end
@@ -106,3 +108,103 @@ function create_hierarchical_model(length::Int, constraints = nothing)
106108
return (; κ = κ, ω = ω, θ = θ, x_begin = x_begin)
107109
end
108110
end
111+
112+
## Recursive model benchmarks
113+
114+
function benchmark_recursive_model()
115+
SUITE = BenchmarkGroup()
116+
117+
for length in 100 .* range(1, stop = 3)
118+
# This SUITE benchmarks how long it takes to create a recursive model with depth `n` and length `n` and default constraints
119+
SUITE["default constraints", length] = @benchmarkable create_recursive_model($length) evals = 1
120+
# This SUITE benchmarks how long it takes to create a recursive model with depth `n` and length `n` and mean field constraints
121+
SUITE["mean field constraints", length] = @benchmarkable create_recursive_model($length, $(MeanField())) evals = 1
122+
end
123+
124+
return SUITE
125+
end
126+
127+
@model function recursive_model(μ, y, depth)
128+
if depth == 0
129+
y ~ Normal(0, 1)
130+
else
131+
μ ~ Normal(y, 1)
132+
μ ~ recursive_model(y = y, depth = depth - 1)
133+
end
134+
end
135+
136+
function create_recursive_model(depth::Int, constraints = nothing)
137+
plugins = if isnothing(constraints)
138+
GraphPPL.PluginsCollection()
139+
else
140+
GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints))
141+
end
142+
return GraphPPL.create_model(GraphPPL.with_plugins(recursive_model(depth = depth), plugins)) do model, ctx
143+
y = GraphPPL.getorcreate!(model, ctx, :y, nothing)
144+
μ = GraphPPL.getorcreate!(model, ctx, , nothing)
145+
return (; y = y, μ = μ)
146+
end
147+
end
148+
149+
## Neural net model benchmarks
150+
151+
function benchmark_neural_net_model()
152+
SUITE = BenchmarkGroup()
153+
154+
for length in 2 .^ range(2, stop = 7)
155+
# This SUITE benchmarks how long it takes to create a neural network model with `8` layers and hidden dimension size `n` and default constraints
156+
SUITE["default constraints", length] = @benchmarkable create_neural_net_model($length) evals = 1
157+
# This SUITE benchmarks how long it takes to create a neural network model with `8` layers and hidden dimension size `n` and mean field constraints
158+
# SUITE["mean field constraints", length] = @benchmarkable create_neural_net_model($length, $(MeanField())) evals = 1
159+
end
160+
161+
return SUITE
162+
end
163+
164+
function dot end
165+
function relu end
166+
167+
@model function neuron(in, out)
168+
local w
169+
for i in 1:(length(in))
170+
w[i] ~ Normal(0.0, 1.0)
171+
end
172+
bias ~ Normal(0.0, 1.0)
173+
unactivated := dot(in, w) + bias
174+
out := relu(unactivated)
175+
end
176+
177+
@model function neural_network_layer(in, out, n)
178+
for i in 1:n
179+
out[i] ~ neuron(in = in)
180+
end
181+
end
182+
183+
@model function neural_net(in, out, h_size)
184+
local softin
185+
for i in 1:length(in)
186+
softin[i] ~ Normal(in[i], 1.0)
187+
end
188+
h1 ~ neural_network_layer(in = softin, n = h_size)
189+
h2 ~ neural_network_layer(in = h1, n = h_size)
190+
h3 ~ neural_network_layer(in = h2, n = h_size)
191+
h4 ~ neural_network_layer(in = h3, n = h_size)
192+
h5 ~ neural_network_layer(in = h4, n = h_size)
193+
h6 ~ neural_network_layer(in = h5, n = h_size)
194+
h7 ~ neural_network_layer(in = h6, n = h_size)
195+
h8 ~ neural_network_layer(in = h7, n = h_size)
196+
out ~ neural_network_layer(in = h8, n = 5)
197+
end
198+
199+
function create_neural_net_model(n::Int, constraints = nothing)
200+
plugins = if isnothing(constraints)
201+
GraphPPL.PluginsCollection()
202+
else
203+
GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints))
204+
end
205+
return GraphPPL.create_model(GraphPPL.with_plugins(neural_net(h_size = n), plugins)) do model, ctx
206+
in = GraphPPL.datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :in, rand(10))
207+
out = GraphPPL.datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :out, rand(5))
208+
return (; in = in, out = out)
209+
end
210+
end

0 commit comments

Comments
 (0)