11const INTERNALNAMES = (:__model__ , :__varinfo__ )
22
3+ drop_escape (x) = x
4+ function drop_escape (expr:: Expr )
5+ Meta. isexpr (expr, :escape ) && return drop_escape (expr. args[1 ])
6+ return Expr (expr. head, map (x -> drop_escape (x), expr. args)... )
7+ end
8+
9+ get_top_level_symbol (expr:: Symbol ) = expr
10+ function get_top_level_symbol (expr:: Expr )
11+ # TODO (penelopeysm): what about Meta.isexpr(expr, :$)?
12+ if Meta. isexpr (expr, :ref )
13+ return get_top_level_symbol (expr. args[1 ])
14+ elseif Meta. isexpr (expr, :.)
15+ return get_top_level_symbol (expr. args[1 ])
16+ else
17+ error (" unreachable" )
18+ end
19+ end
20+
321"""
422 need_concretize(expr)
523
6- Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
7- requires a dynamic optic.
24+ Determine whether `expr` defines a VarName that needs to be concretised.
825
9- # Examples
26+ Note that, although we parse VarNames using our own lenses, Accessors.need_dynamic_optic is
27+ actually still 'good enough' to determine whether we need to concretise or not.
1028
11- ```jldoctest; setup=:(using Accessors)
12- julia> DynamicPPL.need_concretize(:(x[1, :]))
13- true
14-
15- julia> DynamicPPL.need_concretize(:(x[1, end]))
16- true
17-
18- julia> DynamicPPL.need_concretize(:(x[1, 1]))
19- false
29+ Eventually, we can hopefully never concretise anything.
2030"""
2131function need_concretize (expr)
2232 return Accessors. need_dynamic_optic (expr) || begin
3242"""
3343 make_varname_expression(expr)
3444
35- Return a `VarName` based on `expr`, concretizing it if necessary .
45+ Return a `VarName` based on `expr`.
3646"""
3747function make_varname_expression (expr)
38- # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
39- # that in DynamicPPL we the entire function body. Instead we should be
40- # more selective with our escape. Until that's the case, we remove them all.
41- return AbstractPPL. drop_escape (varname (expr, need_concretize (expr)))
48+ # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact that in
49+ # DynamicPPL we the entire function body. Instead we should be more selective with our
50+ # escape. Until that's the case, we remove them all.
51+ # TODO (penelopeysm): We still concretise things, because PartialArray does not
52+ # understand dynamic indices. This is not necessarily a bad thing for performance, but
53+ # it would be nice to not NEED to have to do it. That would require shadow arrays. See
54+ # #1194.
55+ return drop_escape (AbstractPPL. varname (expr, need_concretize (expr)))
4256end
4357
4458"""
@@ -55,10 +69,9 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:
5569
5670When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
5771
58- If `vn` is specified, it will be assumed to refer to a expression which
59- evaluates to a `VarName`, and this will be used in the subsequent checks.
60- If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
61- used in its place.
72+ If `vn` is specified, it will be assumed to refer to a expression which evaluates to a
73+ `VarName`, and this will be used in the subsequent checks. If `vn` is not specified,
74+ `(@varname \$ expr)` will be used in its place.
6275"""
6376function isassumption (expr:: Union{Expr,Symbol} , vn= make_varname_expression (expr))
6477 return quote
@@ -221,9 +234,6 @@ variables.
221234
222235# Example
223236```jldoctest; setup=:(using Distributions, LinearAlgebra)
224- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
225- x[:, 2]
226-
227237julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
228238x[1, 2]
229239
@@ -241,31 +251,20 @@ end
241251function unwrap_right_left_vns (right:: NamedDist , left:: AbstractMatrix , :: VarName )
242252 return unwrap_right_left_vns (right. dist, left, right. name)
243253end
244- function unwrap_right_left_vns (
245- right:: MultivariateDistribution , left:: AbstractMatrix , vn:: VarName
246- )
247- # This an expression such as `x .~ MvNormal()` which we interpret as
248- # x[:, i] ~ MvNormal()
249- # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
250- # and we therefore add the `Colon()` below.
251- vns = map (axes (left, 2 )) do i
252- return AbstractPPL. concretize (Accessors. IndexLens ((Colon (), i)) ∘ vn, left)
253- end
254- return unwrap_right_left_vns (right, left, vns)
255- end
256254function unwrap_right_left_vns (
257255 right:: Union{Distribution,AbstractArray{<:Distribution}} ,
258256 left:: AbstractArray ,
259257 vn:: VarName ,
260258)
261259 vns = map (CartesianIndices (left)) do i
262- return Accessors. IndexLens (Tuple (i)) ∘ vn
260+ sym, optic = getsym (vn), getoptic (vn)
261+ return VarName {sym} (AbstractPPL. Index (Tuple (i), (;), AbstractPPL. Iden ()) ∘ optic)
263262 end
264263 return unwrap_right_left_vns (right, left, vns)
265264end
266265
267266resolve_varnames (vn:: VarName , _) = vn
268- resolve_varnames (vn :: VarName , dist:: NamedDist ) = dist. name
267+ resolve_varnames (:: VarName , dist:: NamedDist ) = dist. name
269268
270269# ################
271270# Main Compiler #
@@ -463,9 +462,18 @@ function generate_tilde_literal(left, right)
463462 end
464463end
465464
466- assign_or_set!! (lhs:: Symbol , rhs) = AbstractPPL. drop_escape (:($ lhs = $ rhs))
467- function assign_or_set!! (lhs:: Expr , rhs)
468- return AbstractPPL. drop_escape (:($ BangBang. @set!! $ lhs = $ rhs))
465+ assign_or_set!! (lhs:: Symbol , rhs, vn) = drop_escape (:($ lhs = $ rhs))
466+ function assign_or_set!! (lhs:: Expr , rhs, vn)
467+ left_top_sym = get_top_level_symbol (lhs)
468+ return drop_escape (
469+ :(
470+ $ left_top_sym = $ (Accessors. set)(
471+ $ left_top_sym,
472+ $ (AbstractPPL. with_mutation)($ (AbstractPPL. getoptic)($ vn)),
473+ $ rhs,
474+ )
475+ ),
476+ )
469477end
470478
471479"""
@@ -487,12 +495,13 @@ function generate_tilde(left, right)
487495 $ isassumption = $ (DynamicPPL. isassumption (left, vn))
488496 if $ (DynamicPPL. isfixed (left, vn))
489497 # $left may not be a simple varname, it might be x.a or x[1], in which case we
490- # need to use BangBang.@ set!! to safely set it.
498+ # need to use Accessors. set to safely set it.
491499 $ (assign_or_set!! (
492500 left,
493501 :($ (DynamicPPL. getfixed_nested)(
494502 __model__. context, $ (DynamicPPL. prefix)(__model__. context, $ vn)
495503 )),
504+ vn,
496505 ))
497506 elseif $ isassumption
498507 $ (generate_tilde_assume (left, dist, vn))
@@ -520,7 +529,7 @@ function generate_tilde(left, right)
520529 $ vn,
521530 __varinfo__,
522531 )
523- $ (assign_or_set!! (left, value))
532+ $ (assign_or_set!! (left, value, vn ))
524533 $ value
525534 end
526535 end
@@ -531,11 +540,17 @@ function generate_tilde_assume(left, right, vn)
531540 # with multiple arguments on the LHS, we need to capture the return-values
532541 # and then update the LHS variables one by one.
533542 @gensym value
534- expr = :($ left = $ value)
535- if left isa Expr
536- expr = AbstractPPL. drop_escape (
537- Accessors. setmacro (BangBang. prefermutation, expr; overwrite= true )
543+ expr = if left isa Expr # as opposed to Symbol
544+ left_top_sym = get_top_level_symbol (left)
545+ :(
546+ $ left_top_sym = $ (Accessors. set)(
547+ $ left_top_sym,
548+ $ (AbstractPPL. with_mutation)($ (AbstractPPL. getoptic)($ vn)),
549+ $ value,
550+ )
538551 )
552+ else
553+ :($ left = $ value)
539554 end
540555
541556 return quote
0 commit comments