1
- const DISTMSG = " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
2
- " Distributions."
3
-
4
1
const INTERNALNAMES = (:__model__ , :__sampler__ , :__context__ , :__varinfo__ , :__rng__ )
5
2
const DEPRECATED_INTERNALNAMES = (:_model , :_sampler , :_context , :_varinfo , :_rng )
6
3
38
35
# failsafe: a literal is never an assumption
39
36
isassumption (expr) = :(false )
40
37
38
+ """
39
+ check_tilde_rhs(x)
40
+
41
+ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of
42
+ `Distributions`, then return `x`.
43
+ """
44
+ function check_tilde_rhs (@nospecialize (x))
45
+ return throw (ArgumentError (
46
+ " the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s"
47
+ ))
48
+ end
49
+ check_tilde_rhs (x:: Distribution ) = x
50
+ check_tilde_rhs (x:: AbstractArray{<:Distribution} ) = x
51
+
41
52
# ################
42
53
# Main Compiler #
43
54
# ################
@@ -225,34 +236,47 @@ Generate an `observe` expression for data variables and `assume` expression for
225
236
variables.
226
237
"""
227
238
function generate_tilde (left, right)
228
- @gensym tmpright
229
- top = [:($ tmpright = $ right),
230
- :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
231
- || throw (ArgumentError ($ DISTMSG)))]
232
-
233
- if left isa Symbol || left isa Expr
234
- @gensym out vn inds isassumption
235
- push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
236
-
239
+ # If the LHS is a literal, it is always an observation
240
+ if ! (left isa Symbol || left isa Expr)
237
241
return quote
238
- $ (top... )
239
- $ isassumption = $ (DynamicPPL. isassumption (left))
240
- if $ isassumption
241
- $ left = $ (DynamicPPL. tilde_assume)(
242
- __rng__, __context__, __sampler__, $ tmpright, $ vn, $ inds, __varinfo__
243
- )
244
- else
245
- $ (DynamicPPL. tilde_observe)(
246
- __context__, __sampler__, $ tmpright, $ left, $ vn, $ inds, __varinfo__
247
- )
248
- end
242
+ $ (DynamicPPL. tilde_observe)(
243
+ __context__,
244
+ __sampler__,
245
+ $ (DynamicPPL. check_tilde_rhs)($ right),
246
+ $ left,
247
+ __varinfo__,
248
+ )
249
249
end
250
250
end
251
251
252
- # If the LHS is a literal, it is always an observation
252
+ # Otherwise it is determined by the model or its value,
253
+ # if the LHS represents an observation
254
+ @gensym vn inds isassumption
253
255
return quote
254
- $ (top... )
255
- $ (DynamicPPL. tilde_observe)(__context__, __sampler__, $ tmpright, $ left, __varinfo__)
256
+ $ vn = $ (varname (left))
257
+ $ inds = $ (vinds (left))
258
+ $ isassumption = $ (DynamicPPL. isassumption (left))
259
+ if $ isassumption
260
+ $ left = $ (DynamicPPL. tilde_assume)(
261
+ __rng__,
262
+ __context__,
263
+ __sampler__,
264
+ $ (DynamicPPL. check_tilde_rhs)($ right),
265
+ $ vn,
266
+ $ inds,
267
+ __varinfo__,
268
+ )
269
+ else
270
+ $ (DynamicPPL. tilde_observe)(
271
+ __context__,
272
+ __sampler__,
273
+ $ (DynamicPPL. check_tilde_rhs)($ right),
274
+ $ left,
275
+ $ vn,
276
+ $ inds,
277
+ __varinfo__,
278
+ )
279
+ end
256
280
end
257
281
end
258
282
@@ -262,34 +286,48 @@ end
262
286
Generate the expression that replaces `left .~ right` in the model body.
263
287
"""
264
288
function generate_dot_tilde (left, right)
265
- @gensym tmpright
266
- top = [:($ tmpright = $ right),
267
- :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
268
- || throw (ArgumentError ($ DISTMSG)))]
269
-
270
- if left isa Symbol || left isa Expr
271
- @gensym out vn inds isassumption
272
- push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
273
-
289
+ # If the LHS is a literal, it is always an observation
290
+ if ! (left isa Symbol || left isa Expr)
274
291
return quote
275
- $ (top... )
276
- $ isassumption = $ (DynamicPPL. isassumption (left)) || $ left === missing
277
- if $ isassumption
278
- $ left .= $ (DynamicPPL. dot_tilde_assume)(
279
- __rng__, __context__, __sampler__, $ tmpright, $ left, $ vn, $ inds, __varinfo__
280
- )
281
- else
282
- $ (DynamicPPL. dot_tilde_observe)(
283
- __context__, __sampler__, $ tmpright, $ left, $ vn, $ inds, __varinfo__
284
- )
285
- end
292
+ $ (DynamicPPL. dot_tilde_observe)(
293
+ __context__,
294
+ __sampler__,
295
+ $ (DynamicPPL. check_tilde_rhs)($ right),
296
+ $ left,
297
+ __varinfo__,
298
+ )
286
299
end
287
300
end
288
301
289
- # If the LHS is a literal, it is always an observation
302
+ # Otherwise it is determined by the model or its value,
303
+ # if the LHS represents an observation
304
+ @gensym vn inds isassumption
290
305
return quote
291
- $ (top... )
292
- $ (DynamicPPL. dot_tilde_observe)(__context__, __sampler__, $ tmpright, $ left, __varinfo__)
306
+ $ vn = $ (varname (left))
307
+ $ inds = $ (vinds (left))
308
+ $ isassumption = $ (DynamicPPL. isassumption (left))
309
+ if $ isassumption
310
+ $ left .= $ (DynamicPPL. dot_tilde_assume)(
311
+ __rng__,
312
+ __context__,
313
+ __sampler__,
314
+ $ (DynamicPPL. check_tilde_rhs)($ right),
315
+ $ left,
316
+ $ vn,
317
+ $ inds,
318
+ __varinfo__,
319
+ )
320
+ else
321
+ $ (DynamicPPL. dot_tilde_observe)(
322
+ __context__,
323
+ __sampler__,
324
+ $ (DynamicPPL. check_tilde_rhs)($ right),
325
+ $ left,
326
+ $ vn,
327
+ $ inds,
328
+ __varinfo__,
329
+ )
330
+ end
293
331
end
294
332
end
295
333
0 commit comments