@@ -18,7 +18,7 @@ function isassumption(expr::Union{Symbol,Expr})
18
18
vn = gensym (:vn )
19
19
20
20
return quote
21
- let $ vn = $ (varname (expr))
21
+ let $ vn = $ (AbstractPPL . drop_escape ( varname (expr) ))
22
22
if $ (DynamicPPL. contextual_isassumption)(__context__, $ vn)
23
23
# Considered an assumption by `__context__` which means either:
24
24
# 1. We hit the default implementation, e.g. using `DefaultContext`,
@@ -133,17 +133,17 @@ variables.
133
133
134
134
# Example
135
135
```jldoctest; setup=:(using Distributions, LinearAlgebra)
136
- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string( vns[end])
137
- " x[:,2]"
136
+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
137
+ x[:,2]
138
138
139
- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:] )); string( vns[end])
140
- "x[:][ 1,2]"
139
+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
140
+ x[ 1,2]
141
141
142
- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3 ), @varname(x[1 ])); string( vns[end])
143
- "x[1][3]"
142
+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2 ), @varname(x[: ])); vns[end]
143
+ x[:][1,2]
144
144
145
- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); string( vns[end])
146
- " x[1,2,3]"
145
+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1] )); vns[end]
146
+ x[1][3]
147
147
```
148
148
"""
149
149
unwrap_right_left_vns (right, left, vns) = right, left, vns
@@ -158,7 +158,7 @@ function unwrap_right_left_vns(
158
158
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
159
159
# and we therefore add the `Colon()` below.
160
160
vns = map (axes (left, 2 )) do i
161
- return VarName (vn, (vn . indexing ... , ( Colon (), i) ))
161
+ return vn ∘ Setfield . IndexLens (( Colon (), i))
162
162
end
163
163
return unwrap_right_left_vns (right, left, vns)
164
164
end
@@ -168,7 +168,7 @@ function unwrap_right_left_vns(
168
168
vn:: VarName ,
169
169
)
170
170
vns = map (CartesianIndices (left)) do i
171
- return VarName (vn, (vn . indexing ... , Tuple (i) ))
171
+ return vn ∘ Setfield . IndexLens ( Tuple (i))
172
172
end
173
173
return unwrap_right_left_vns (right, left, vns)
174
174
end
@@ -317,6 +317,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
317
317
# Do not touch interpolated expressions
318
318
expr. head === :$ && return expr. args[1 ]
319
319
320
+ # Do we don't want escaped expressions because we unfortunately
321
+ # escape the entire body afterwards.
322
+ Meta. isexpr (expr, :escape ) && return generate_mainbody (mod, found, expr. args[1 ], warn)
323
+
320
324
# If it's a macro, we expand it
321
325
if Meta. isexpr (expr, :macrocall )
322
326
return generate_mainbody! (mod, found, macroexpand (mod, expr; recursive= true ), warn)
@@ -349,38 +353,36 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
349
353
return Expr (expr. head, map (x -> generate_mainbody! (mod, found, x, warn), expr. args)... )
350
354
end
351
355
356
+ function generate_tilde_literal (left, right)
357
+ # If the LHS is a literal, it is always an observation
358
+ return quote
359
+ $ (DynamicPPL. tilde_observe!)(
360
+ __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
361
+ )
362
+ end
363
+ end
364
+
352
365
"""
353
366
generate_tilde(left, right)
354
367
355
368
Generate an `observe` expression for data variables and `assume` expression for parameter
356
369
variables.
357
370
"""
358
371
function generate_tilde (left, right)
359
- # If the LHS is a literal, it is always an observation
360
- if isliteral (left)
361
- return quote
362
- $ (DynamicPPL. tilde_observe!)(
363
- __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
364
- )
365
- end
366
- end
372
+ isliteral (left) && return generate_tilde_literal (left, right)
367
373
368
374
# Otherwise it is determined by the model or its value,
369
375
# if the LHS represents an observation
370
- @gensym vn inds isassumption
376
+ @gensym vn isassumption
377
+
378
+ # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379
+ # that in DynamicPPL we the entire function body. Instead we should be
380
+ # more selective with our escape. Until that's the case, we remove them all.
371
381
return quote
372
- $ vn = $ (varname (left))
373
- $ inds = $ (vinds (left))
382
+ $ vn = $ (AbstractPPL. drop_escape (varname (left)))
374
383
$ isassumption = $ (DynamicPPL. isassumption (left))
375
384
if $ isassumption
376
- $ left = $ (DynamicPPL. tilde_assume!)(
377
- __context__,
378
- $ (DynamicPPL. unwrap_right_vn)(
379
- $ (DynamicPPL. check_tilde_rhs)($ right), $ vn
380
- ). .. ,
381
- $ inds,
382
- __varinfo__,
383
- )
385
+ $ (generate_tilde_assume (left, right, vn))
384
386
else
385
387
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
386
388
if ! $ (DynamicPPL. inargnames)($ vn, __model__)
@@ -392,44 +394,46 @@ function generate_tilde(left, right)
392
394
$ (DynamicPPL. check_tilde_rhs)($ right),
393
395
$ (maybe_view (left)),
394
396
$ vn,
395
- $ inds,
396
397
__varinfo__,
397
398
)
398
399
end
399
400
end
400
401
end
401
402
403
+ function generate_tilde_assume (left, right, vn)
404
+ expr = :(
405
+ $ left = $ (DynamicPPL. tilde_assume!)(
406
+ __context__,
407
+ $ (DynamicPPL. unwrap_right_vn)($ (DynamicPPL. check_tilde_rhs)($ right), $ vn). .. ,
408
+ __varinfo__,
409
+ )
410
+ )
411
+
412
+ return if left isa Expr
413
+ AbstractPPL. drop_escape (
414
+ Setfield. setmacro (BangBang. prefermutation, expr; overwrite= true )
415
+ )
416
+ else
417
+ return expr
418
+ end
419
+ end
420
+
402
421
"""
403
422
generate_dot_tilde(left, right)
404
423
405
424
Generate the expression that replaces `left .~ right` in the model body.
406
425
"""
407
426
function generate_dot_tilde (left, right)
408
- # If the LHS is a literal, it is always an observation
409
- if isliteral (left)
410
- return quote
411
- $ (DynamicPPL. dot_tilde_observe!)(
412
- __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
413
- )
414
- end
415
- end
427
+ isliteral (left) && return generate_tilde_literal (left, right)
416
428
417
429
# Otherwise it is determined by the model or its value,
418
430
# if the LHS represents an observation
419
- @gensym vn inds isassumption
431
+ @gensym vn isassumption
420
432
return quote
421
- $ vn = $ (varname (left))
422
- $ inds = $ (vinds (left))
433
+ $ vn = $ (AbstractPPL. drop_escape (varname (left)))
423
434
$ isassumption = $ (DynamicPPL. isassumption (left))
424
435
if $ isassumption
425
- $ left .= $ (DynamicPPL. dot_tilde_assume!)(
426
- __context__,
427
- $ (DynamicPPL. unwrap_right_left_vns)(
428
- $ (DynamicPPL. check_tilde_rhs)($ right), $ (maybe_view (left)), $ vn
429
- ). .. ,
430
- $ inds,
431
- __varinfo__,
432
- )
436
+ $ (generate_dot_tilde_assume (left, right, vn))
433
437
else
434
438
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
435
439
if ! $ (DynamicPPL. inargnames)($ vn, __model__)
@@ -441,13 +445,27 @@ function generate_dot_tilde(left, right)
441
445
$ (DynamicPPL. check_tilde_rhs)($ right),
442
446
$ (maybe_view (left)),
443
447
$ vn,
444
- $ inds,
445
448
__varinfo__,
446
449
)
447
450
end
448
451
end
449
452
end
450
453
454
+ function generate_dot_tilde_assume (left, right, vn)
455
+ # We don't need to use `Setfield.@set` here since
456
+ # `.=` is always going to be inplace + needs `left` to
457
+ # be something that supports `.=`.
458
+ return :(
459
+ $ left .= $ (DynamicPPL. dot_tilde_assume!)(
460
+ __context__,
461
+ $ (DynamicPPL. unwrap_right_left_vns)(
462
+ $ (DynamicPPL. check_tilde_rhs)($ right), $ (maybe_view (left)), $ vn
463
+ ). .. ,
464
+ __varinfo__,
465
+ )
466
+ )
467
+ end
468
+
451
469
const FloatOrArrayType = Type{<: Union{AbstractFloat,AbstractArray} }
452
470
hasmissing (T:: Type{<:AbstractArray{TA}} ) where {TA<: AbstractArray } = hasmissing (TA)
453
471
hasmissing (T:: Type{<:AbstractArray{>:Missing}} ) = true
0 commit comments