Skip to content

Commit 65173b6

Browse files
authored
AbstractPPL 0.14 - new VarName optics & modifications to VNT (#1203)
I also split up `varnamedtuple.jl` because it's just massive. Also reorganised a bunch of code and deduplicated method names. Anything that takes an optic is called `_foo_optic` (where `foo` is one of `getindex`, `setindex` and `haskey`). That frees us up to use the original `Base.getindex`, `BangBang.setindex!!`, and `Base.haskey` for methods that take a plain old list of indices.' Note that this breaks a few tests. For example, colon indices are broken, because colons are always unconcretised now. Unfortunately, there's no way of fixing this _properly_ without #1194. However, we _could_ replicate the old concretisation mechanism of colons in DynamicPPL, which would make them continue to work (at least for now). This PR disables colon tests for the time being.
1 parent ea9b07a commit 65173b6

File tree

20 files changed

+1665
-1593
lines changed

20 files changed

+1665
-1593
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4646
[compat]
4747
ADTypes = "1"
4848
AbstractMCMC = "5.10"
49-
AbstractPPL = "0.13.1"
49+
AbstractPPL = "0.14"
5050
Accessors = "0.1"
5151
BangBang = "0.4.1"
5252
Bijectors = "0.15.11"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1515

1616
[compat]
1717
AbstractMCMC = "5"
18-
AbstractPPL = "0.13"
18+
AbstractPPL = "0.14"
1919
Accessors = "0.1"
2020
Distributions = "0.25"
2121
Documenter = "1"

src/compiler.jl

Lines changed: 62 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
const INTERNALNAMES = (:__model__, :__varinfo__)
22

3+
drop_escape(x) = x
4+
function drop_escape(expr::Expr)
5+
Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1])
6+
return Expr(expr.head, map(x -> drop_escape(x), expr.args)...)
7+
end
8+
9+
get_top_level_symbol(expr::Symbol) = expr
10+
function get_top_level_symbol(expr::Expr)
11+
# TODO(penelopeysm): what about Meta.isexpr(expr, :$)?
12+
if Meta.isexpr(expr, :ref)
13+
return get_top_level_symbol(expr.args[1])
14+
elseif Meta.isexpr(expr, :.)
15+
return get_top_level_symbol(expr.args[1])
16+
else
17+
error("unreachable")
18+
end
19+
end
20+
321
"""
422
need_concretize(expr)
523
6-
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
7-
requires a dynamic optic.
24+
Determine whether `expr` defines a VarName that needs to be concretised.
825
9-
# Examples
26+
Note that, although we parse VarNames using our own lenses, Accessors.need_dynamic_optic is
27+
actually still 'good enough' to determine whether we need to concretise or not.
1028
11-
```jldoctest; setup=:(using Accessors)
12-
julia> DynamicPPL.need_concretize(:(x[1, :]))
13-
true
14-
15-
julia> DynamicPPL.need_concretize(:(x[1, end]))
16-
true
17-
18-
julia> DynamicPPL.need_concretize(:(x[1, 1]))
19-
false
29+
Eventually, we can hopefully never concretise anything.
2030
"""
2131
function need_concretize(expr)
2232
return Accessors.need_dynamic_optic(expr) || begin
@@ -32,13 +42,17 @@ end
3242
"""
3343
make_varname_expression(expr)
3444
35-
Return a `VarName` based on `expr`, concretizing it if necessary.
45+
Return a `VarName` based on `expr`.
3646
"""
3747
function make_varname_expression(expr)
38-
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
39-
# that in DynamicPPL we the entire function body. Instead we should be
40-
# more selective with our escape. Until that's the case, we remove them all.
41-
return AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))
48+
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact that in
49+
# DynamicPPL we the entire function body. Instead we should be more selective with our
50+
# escape. Until that's the case, we remove them all.
51+
# TODO(penelopeysm): We still concretise things, because PartialArray does not
52+
# understand dynamic indices. This is not necessarily a bad thing for performance, but
53+
# it would be nice to not NEED to have to do it. That would require shadow arrays. See
54+
# #1194.
55+
return drop_escape(AbstractPPL.varname(expr, need_concretize(expr)))
4256
end
4357

4458
"""
@@ -55,10 +69,9 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:
5569
5670
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
5771
58-
If `vn` is specified, it will be assumed to refer to a expression which
59-
evaluates to a `VarName`, and this will be used in the subsequent checks.
60-
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
61-
used in its place.
72+
If `vn` is specified, it will be assumed to refer to a expression which evaluates to a
73+
`VarName`, and this will be used in the subsequent checks. If `vn` is not specified,
74+
`(@varname \$expr)` will be used in its place.
6275
"""
6376
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
6477
return quote
@@ -221,9 +234,6 @@ variables.
221234
222235
# Example
223236
```jldoctest; setup=:(using Distributions, LinearAlgebra)
224-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
225-
x[:, 2]
226-
227237
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
228238
x[1, 2]
229239
@@ -241,31 +251,20 @@ end
241251
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)
242252
return unwrap_right_left_vns(right.dist, left, right.name)
243253
end
244-
function unwrap_right_left_vns(
245-
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
246-
)
247-
# This an expression such as `x .~ MvNormal()` which we interpret as
248-
# x[:, i] ~ MvNormal()
249-
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
250-
# and we therefore add the `Colon()` below.
251-
vns = map(axes(left, 2)) do i
252-
return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) vn, left)
253-
end
254-
return unwrap_right_left_vns(right, left, vns)
255-
end
256254
function unwrap_right_left_vns(
257255
right::Union{Distribution,AbstractArray{<:Distribution}},
258256
left::AbstractArray,
259257
vn::VarName,
260258
)
261259
vns = map(CartesianIndices(left)) do i
262-
return Accessors.IndexLens(Tuple(i)) vn
260+
sym, optic = getsym(vn), getoptic(vn)
261+
return VarName{sym}(AbstractPPL.Index(Tuple(i), (;), AbstractPPL.Iden()) optic)
263262
end
264263
return unwrap_right_left_vns(right, left, vns)
265264
end
266265

267266
resolve_varnames(vn::VarName, _) = vn
268-
resolve_varnames(vn::VarName, dist::NamedDist) = dist.name
267+
resolve_varnames(::VarName, dist::NamedDist) = dist.name
269268

270269
#################
271270
# Main Compiler #
@@ -463,9 +462,18 @@ function generate_tilde_literal(left, right)
463462
end
464463
end
465464

466-
assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs))
467-
function assign_or_set!!(lhs::Expr, rhs)
468-
return AbstractPPL.drop_escape(:($BangBang.@set!! $lhs = $rhs))
465+
assign_or_set!!(lhs::Symbol, rhs, vn) = drop_escape(:($lhs = $rhs))
466+
function assign_or_set!!(lhs::Expr, rhs, vn)
467+
left_top_sym = get_top_level_symbol(lhs)
468+
return drop_escape(
469+
:(
470+
$left_top_sym = $(Accessors.set)(
471+
$left_top_sym,
472+
$(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)),
473+
$rhs,
474+
)
475+
),
476+
)
469477
end
470478

471479
"""
@@ -487,12 +495,13 @@ function generate_tilde(left, right)
487495
$isassumption = $(DynamicPPL.isassumption(left, vn))
488496
if $(DynamicPPL.isfixed(left, vn))
489497
# $left may not be a simple varname, it might be x.a or x[1], in which case we
490-
# need to use BangBang.@set!! to safely set it.
498+
# need to use Accessors.set to safely set it.
491499
$(assign_or_set!!(
492500
left,
493501
:($(DynamicPPL.getfixed_nested)(
494502
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
495503
)),
504+
vn,
496505
))
497506
elseif $isassumption
498507
$(generate_tilde_assume(left, dist, vn))
@@ -520,7 +529,7 @@ function generate_tilde(left, right)
520529
$vn,
521530
__varinfo__,
522531
)
523-
$(assign_or_set!!(left, value))
532+
$(assign_or_set!!(left, value, vn))
524533
$value
525534
end
526535
end
@@ -531,11 +540,17 @@ function generate_tilde_assume(left, right, vn)
531540
# with multiple arguments on the LHS, we need to capture the return-values
532541
# and then update the LHS variables one by one.
533542
@gensym value
534-
expr = :($left = $value)
535-
if left isa Expr
536-
expr = AbstractPPL.drop_escape(
537-
Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true)
543+
expr = if left isa Expr # as opposed to Symbol
544+
left_top_sym = get_top_level_symbol(left)
545+
:(
546+
$left_top_sym = $(Accessors.set)(
547+
$left_top_sym,
548+
$(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)),
549+
$value,
550+
)
538551
)
552+
else
553+
:($left = $value)
539554
end
540555

541556
return quote

src/contexts/init.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ in `DynamicPPL.unflatten!!` for more details. For non-threadsafe evaluation, Jul
6060
capable of automatically promoting the types on its own. Secondly, the promotion only
6161
matters if you are trying to directly assign into a `Vector{Float64}` with a
6262
`ForwardDiff.Dual` or similar tracer type, for example using `xs[i] = MyDual`. This doesn't
63-
actually apply to tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set`
63+
actually apply to tilde-statements like `xs[i] ~ ...` because those use `Accessors.set`
6464
under the hood, which also does the promotion for you. For the gory details, see the
6565
following issues:
6666

src/debug_utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module DebugUtils
33
using ..DynamicPPL
44

55
using Random: Random
6-
using Accessors: Accessors
76
using InteractiveUtils: InteractiveUtils
87

98
using DocStringExtensions

src/model.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
508508
cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
509509
510510
julia> conditioned(cm)
511-
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
511+
Dict{VarName{:a, AbstractPPL.Property{:m, AbstractPPL.Iden}}, Float64} with 1 entry:
512512
a.m => 1.0
513513
514514
julia> # Now `a.x` will be sampled.
@@ -833,24 +833,24 @@ julia> # Returns all the variables we have fixed on + their values.
833833
(x = 100.0, m = 1.0)
834834
835835
julia> # The rest of this is the same as the `condition` example above.
836-
cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0);
836+
fm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0);
837837
838-
julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)])
838+
julia> Set(keys(fixed(fm))) == Set([@varname(a.m), @varname(x)])
839839
true
840840
841-
julia> keys(VarInfo(cm))
841+
julia> keys(VarInfo(fm))
842842
1-element Vector{VarName}:
843843
a.x
844844
845-
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
846-
cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
845+
julia> # We can also fix `a.m` _outside_ of the PrefixContext:
846+
fm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
847847
848-
julia> fixed(cm)
849-
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
848+
julia> fixed(fm)
849+
Dict{VarName{:a, AbstractPPL.Property{:m, AbstractPPL.Iden}}, Float64} with 1 entry:
850850
a.m => 1.0
851851
852852
julia> # Now `a.x` will be sampled.
853-
keys(VarInfo(cm))
853+
keys(VarInfo(fm))
854854
1-element Vector{VarName}:
855855
a.x
856856
```

src/test_utils/model_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ Return a `NamedTuple` compatible with `varnames(model)` where the values represe
104104
the posterior mean under `model`.
105105
106106
"Compatible" means that a `varname` from `varnames(model)` can be used to extract the
107-
corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
107+
corresponding value using e.g. `AbstractPPL.getvalue(posterior_mean(model), varname)`.
108108
"""
109109
function posterior_mean end
110110

src/test_utils/models.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -639,16 +639,7 @@ function logprior_true_with_logabsdet_jacobian(
639639
return _demo_logprior_true_with_logabsdet_jacobian(model, s.params[1].subparams, m)
640640
end
641641
function varnames(::Model{typeof(demo_nested_colons)})
642-
return [
643-
@varname(
644-
s.params[1].subparams[
645-
AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(1))),
646-
1,
647-
AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(2))),
648-
]
649-
),
650-
@varname(m),
651-
]
642+
return [@varname(s.params[1].subparams[:, 1, :]), @varname(m)]
652643
end
653644
function varnames_split(::Model{typeof(demo_nested_colons)})
654645
return [
@@ -887,7 +878,7 @@ const DEMO_MODELS = (
887878
demo_dot_assume_observe_submodel(),
888879
demo_dot_assume_observe_matrix_index(),
889880
demo_assume_matrix_observe_matrix_index(),
890-
demo_nested_colons(),
881+
# demo_nested_colons(),
891882
)
892883

893884
"""

src/test_utils/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in
1010
"""
1111
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
1212
for vn in vns
13-
val = get(vals, vn)
13+
val = AbstractPPL.getvalue(vals, vn)
1414
# TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
1515
# Remove once the fix is all Julia versions we support.
1616
if val isa Cholesky

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ Return instance similar to `vi` but with `vns` set to values from `vals`.
545545
"""
546546
function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
547547
for vn in vns
548-
vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
548+
vi = DynamicPPL.setindex!!(vi, AbstractPPL.getvalue(vals, vn), vn)
549549
end
550550
return vi
551551
end

0 commit comments

Comments
 (0)