69
69
70
70
To generate a `Model`, call `model_generator(x_value)`.
71
71
"""
72
- macro model (input_expr)
73
- Base. replace_ref_end! (input_expr) |> build_model_info |> replace_tilde! |> replace_vi! |>
74
- replace_logpdf! |> replace_sampler! |> build_output
72
+ macro model (expr)
73
+ esc (model (expr))
74
+ end
75
+
76
+ function model (expr)
77
+ modelinfo = build_model_info (expr)
78
+
79
+ # Generate main body
80
+ modelinfo[:main_body ] = generate_mainbody (modelinfo)
81
+
82
+ return build_output (modelinfo)
75
83
end
76
84
77
85
"""
@@ -170,62 +178,55 @@ function build_model_info(input_expr)
170
178
return model_info
171
179
end
172
180
173
-
174
181
"""
175
- replace_vi!(model_info )
182
+ generate_mainbody([expr, ]modelinfo )
176
183
177
- Replaces `@varinfo()` expressions with a handle to the `VarInfo` struct .
184
+ Generate the body of the main evaluation function .
178
185
"""
179
- function replace_vi! (model_info)
180
- ex = model_info[:main_body ]
181
- vi = model_info[:main_body_names ][:vi ]
182
- ex = MacroTools. postwalk (ex) do x
183
- if @capture (x, @varinfo ())
184
- vi
185
- else
186
- x
187
- end
188
- end
189
- model_info[:main_body ] = ex
190
- return model_info
191
- end
186
+ generate_mainbody (modelinfo) = generate_mainbody (modelinfo[:main_body ], modelinfo)
192
187
193
- """
194
- replace_logpdf!(model_info)
188
+ generate_mainbody (x, modelinfo) = x
189
+ function generate_mainbody (expr:: Expr , modelinfo)
190
+ # Do not touch interpolated expressions
191
+ expr. head === :$ && return expr. args[1 ]
195
192
196
- Replaces `@logpdf()` expressions with the value of the accumulated `logpdf` in the `VarInfo` struct.
197
- """
198
- function replace_logpdf! (model_info)
199
- ex = model_info[:main_body ]
200
- vi = model_info[:main_body_names ][:vi ]
201
- ex = MacroTools. postwalk (ex) do x
202
- if @capture (x, @logpdf ())
203
- :($ (vi). logp[])
204
- else
205
- x
193
+ # Apply the `@.` macro first.
194
+ if Meta. isexpr (expr, :macrocall ) && length (expr. args) > 1 &&
195
+ expr. args[1 ] === Symbol (" @__dot__" )
196
+ return generate_mainbody (Base. Broadcast. __dot__ (expr. args[end ]), modelinfo)
197
+ end
198
+
199
+ # Modify macro calls.
200
+ if Meta. isexpr (expr, :macrocall ) && ! isempty (expr. args)
201
+ name = expr. args[1 ]
202
+ if name === Symbol (" @varinfo" )
203
+ return modelinfo[:main_body_names ][:vi ]
204
+ elseif name === Symbol (" @logpdf" )
205
+ return :($ (modelinfo[:main_body_names ][:vi ]). logp[])
206
+ elseif name === Symbol (" @sampler" )
207
+ return :($ (modelinfo[:main_body_names ][:sampler ]))
206
208
end
207
209
end
208
- model_info[:main_body ] = ex
209
- return model_info
210
- end
211
210
212
- """
213
- replace_sampler!(model_info)
211
+ # Modify dotted tilde operators.
212
+ args_dottilde = getargs_dottilde (expr)
213
+ if args_dottilde != = nothing
214
+ L, R = args_dottilde
215
+ return Base. remove_linenums! (generate_dot_tilde (generate_mainbody (L, modelinfo),
216
+ generate_mainbody (R, modelinfo),
217
+ modelinfo))
218
+ end
214
219
215
- Replaces `@sampler()` expressions with a handle to the sampler struct.
216
- """
217
- function replace_sampler! (model_info)
218
- ex = model_info[:main_body ]
219
- spl = model_info[:main_body_names ][:sampler ]
220
- ex = MacroTools. postwalk (ex) do x
221
- if @capture (x, @sampler ())
222
- spl
223
- else
224
- x
225
- end
220
+ # Modify tilde operators.
221
+ args_tilde = getargs_tilde (expr)
222
+ if args_tilde != = nothing
223
+ L, R = args_tilde
224
+ return Base. remove_linenums! (generate_tilde (generate_mainbody (L, modelinfo),
225
+ generate_mainbody (R, modelinfo),
226
+ modelinfo))
226
227
end
227
- model_info[ :main_body ] = ex
228
- return model_info
228
+
229
+ return Expr (expr . head, map (x -> generate_mainbody (x, modelinfo), expr . args) ... )
229
230
end
230
231
231
232
"""
@@ -443,7 +444,7 @@ function build_output(model_info)
443
444
generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
444
445
model_gen_constructor = :($ (DynamicPPL. ModelGen){$ (Tuple (arg_syms))}($ generator, $ defaults_nt))
445
446
446
- ex = quote
447
+ return quote
447
448
function $evaluator (
448
449
$ model:: $ (DynamicPPL. Model),
449
450
$ vi:: $ (DynamicPPL. VarInfo),
@@ -459,8 +460,6 @@ function build_output(model_info)
459
460
460
461
$ model_gen = $ model_gen_constructor
461
462
end
462
-
463
- return esc (ex)
464
463
end
465
464
466
465
0 commit comments