@@ -9,6 +9,8 @@ function benchmark_model_creation()
99
1010 SUITE[" model creation" ][" state space (length)" ] = benchmark_state_space_model ()
1111 SUITE[" model creation" ][" hierarchical (depth)" ] = benchmark_hierarchical_model ()
12+ SUITE[" model creation" ][" recursive (depth)" ] = benchmark_recursive_model ()
13+ SUITE[" model creation" ][" neural net (hidden dim size)" ] = benchmark_neural_net_model ()
1214
1315 return SUITE
1416end
@@ -106,3 +108,103 @@ function create_hierarchical_model(length::Int, constraints = nothing)
106108 return (; κ = κ, ω = ω, θ = θ, x_begin = x_begin)
107109 end
108110end
111+
112+ # # Recursive model benchmarks
113+
114+ function benchmark_recursive_model ()
115+ SUITE = BenchmarkGroup ()
116+
117+ for length in 100 .* range (1 , stop = 3 )
118+ # This SUITE benchmarks how long it takes to create a recursive model with depth `n` and length `n` and default constraints
119+ SUITE[" default constraints" , length] = @benchmarkable create_recursive_model ($ length) evals = 1
120+ # This SUITE benchmarks how long it takes to create a recursive model with depth `n` and length `n` and mean field constraints
121+ SUITE[" mean field constraints" , length] = @benchmarkable create_recursive_model ($ length, $ (MeanField ())) evals = 1
122+ end
123+
124+ return SUITE
125+ end
126+
127+ @model function recursive_model (μ, y, depth)
128+ if depth == 0
129+ y ~ Normal (0 , 1 )
130+ else
131+ μ ~ Normal (y, 1 )
132+ μ ~ recursive_model (y = y, depth = depth - 1 )
133+ end
134+ end
135+
136+ function create_recursive_model (depth:: Int , constraints = nothing )
137+ plugins = if isnothing (constraints)
138+ GraphPPL. PluginsCollection ()
139+ else
140+ GraphPPL. PluginsCollection (GraphPPL. VariationalConstraintsPlugin (constraints))
141+ end
142+ return GraphPPL. create_model (GraphPPL. with_plugins (recursive_model (depth = depth), plugins)) do model, ctx
143+ y = GraphPPL. getorcreate! (model, ctx, :y , nothing )
144+ μ = GraphPPL. getorcreate! (model, ctx, :μ , nothing )
145+ return (; y = y, μ = μ)
146+ end
147+ end
148+
149+ # # Neural net model benchmarks
150+
151+ function benchmark_neural_net_model ()
152+ SUITE = BenchmarkGroup ()
153+
154+ for length in 2 .^ range (2 , stop = 7 )
155+ # This SUITE benchmarks how long it takes to create a neural network model with `8` layers and hidden dimension size `n` and default constraints
156+ SUITE[" default constraints" , length] = @benchmarkable create_neural_net_model ($ length) evals = 1
157+ # This SUITE benchmarks how long it takes to create a neural network model with `8` layers and hidden dimension size `n` and mean field constraints
158+ # SUITE["mean field constraints", length] = @benchmarkable create_neural_net_model($length, $(MeanField())) evals = 1
159+ end
160+
161+ return SUITE
162+ end
163+
164+ function dot end
165+ function relu end
166+
167+ @model function neuron (in, out)
168+ local w
169+ for i in 1 : (length (in))
170+ w[i] ~ Normal (0.0 , 1.0 )
171+ end
172+ bias ~ Normal (0.0 , 1.0 )
173+ unactivated := dot (in, w) + bias
174+ out := relu (unactivated)
175+ end
176+
177+ @model function neural_network_layer (in, out, n)
178+ for i in 1 : n
179+ out[i] ~ neuron (in = in)
180+ end
181+ end
182+
183+ @model function neural_net (in, out, h_size)
184+ local softin
185+ for i in 1 : length (in)
186+ softin[i] ~ Normal (in[i], 1.0 )
187+ end
188+ h1 ~ neural_network_layer (in = softin, n = h_size)
189+ h2 ~ neural_network_layer (in = h1, n = h_size)
190+ h3 ~ neural_network_layer (in = h2, n = h_size)
191+ h4 ~ neural_network_layer (in = h3, n = h_size)
192+ h5 ~ neural_network_layer (in = h4, n = h_size)
193+ h6 ~ neural_network_layer (in = h5, n = h_size)
194+ h7 ~ neural_network_layer (in = h6, n = h_size)
195+ h8 ~ neural_network_layer (in = h7, n = h_size)
196+ out ~ neural_network_layer (in = h8, n = 5 )
197+ end
198+
199+ function create_neural_net_model (n:: Int , constraints = nothing )
200+ plugins = if isnothing (constraints)
201+ GraphPPL. PluginsCollection ()
202+ else
203+ GraphPPL. PluginsCollection (GraphPPL. VariationalConstraintsPlugin (constraints))
204+ end
205+ return GraphPPL. create_model (GraphPPL. with_plugins (neural_net (h_size = n), plugins)) do model, ctx
206+ in = GraphPPL. datalabel (model, ctx, GraphPPL. NodeCreationOptions (kind = :data ), :in , rand (10 ))
207+ out = GraphPPL. datalabel (model, ctx, GraphPPL. NodeCreationOptions (kind = :data ), :out , rand (5 ))
208+ return (; in = in, out = out)
209+ end
210+ end
0 commit comments