Skip to content

Commit c6c0cbc

Browse files
authored
InitContext, part 2 - Move hasvalue and getvalue to AbstractPPL; enforce key type of AbstractDict (#980)
* point to unmerged AbstractPPL branch * Remove code that was moved to AbstractPPL * Remove Dictionaries with Any key type * Fix bad merge conflict resolution * Fix doctests * Point to [email protected] This reverts commit 709dc9e. * Fix doctests * Fix docs AbstractPPL bound * Remove stray `Pkg.update()`
1 parent 05cd886 commit c6c0cbc

File tree

12 files changed

+30
-220
lines changed

12 files changed

+30
-220
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4747
[compat]
4848
ADTypes = "1"
4949
AbstractMCMC = "5"
50-
AbstractPPL = "0.11, 0.12"
50+
AbstractPPL = "0.13"
5151
Accessors = "0.1"
5252
BangBang = "0.4.1"
5353
Bijectors = "0.13.18, 0.14, 0.15"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1414
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1515

1616
[compat]
17-
AbstractPPL = "0.11, 0.12"
17+
AbstractPPL = "0.13"
1818
Accessors = "0.1"
1919
DataStructures = "0.18"
2020
Distributions = "0.25"

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using DocStringExtensions
2323
using Random: Random
2424

2525
# For extending
26-
import AbstractPPL: predict
26+
import AbstractPPL: predict, hasvalue, getvalue
2727

2828
# TODO: Remove these when it's possible.
2929
import Bijectors: link, invlink

src/model.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
981981
Generate a sample of type `T` from the prior distribution of the `model`.
982982
"""
983983
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
984-
x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict())))
984+
x = last(
985+
evaluate_and_sample!!(
986+
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
987+
),
988+
)
985989
return values_as(x, T)
986990
end
987991

@@ -1032,7 +1036,7 @@ julia> logjoint(demo_model([1., 2.]), chain);
10321036
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
10331037
var_info = VarInfo(model) # extract variables info from the model
10341038
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1035-
argvals_dict = OrderedDict(
1039+
argvals_dict = OrderedDict{VarName,Any}(
10361040
vn_parent =>
10371041
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10381042
vn_parent in keys(var_info)
@@ -1090,7 +1094,7 @@ julia> logprior(demo_model([1., 2.]), chain);
10901094
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
10911095
var_info = VarInfo(model) # extract variables info from the model
10921096
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1093-
argvals_dict = OrderedDict(
1097+
argvals_dict = OrderedDict{VarName,Any}(
10941098
vn_parent =>
10951099
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10961100
vn_parent in keys(var_info)
@@ -1144,7 +1148,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain);
11441148
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
11451149
var_info = VarInfo(model) # extract variables info from the model
11461150
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1147-
argvals_dict = OrderedDict(
1151+
argvals_dict = OrderedDict{VarName,Any}(
11481152
vn_parent =>
11491153
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
11501154
vn_parent in keys(var_info)

src/simple_varinfo.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@ ERROR: type NamedTuple has no field x
6262
[...]
6363
6464
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict()));
65+
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
6666
6767
julia> # (✓) Sort of fast, but only possible at runtime.
6868
vi[@varname(x[1])]
6969
-1.019202452456547
7070
7171
julia> # In addtion, we can only access varnames as they appear in the model!
7272
vi[@varname(x)]
73-
ERROR: KeyError: key x not found
73+
ERROR: x was not found in the dictionary provided
7474
[...]
7575
7676
julia> vi[@varname(x[1:2])]
77-
ERROR: KeyError: key x[1:2] not found
77+
ERROR: x[1:2] was not found in the dictionary provided
7878
[...]
7979
```
8080
@@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107
true
108108
109109
julia> # And with `OrderedDict` of course!
110-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true));
110+
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113113
0.6225185067787314
@@ -177,11 +177,11 @@ julia> svi_dict[@varname(m.a[1])]
177177
1.0
178178
179179
julia> svi_dict[@varname(m.a[2])]
180-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
180+
ERROR: m.a[2] was not found in the dictionary provided
181181
[...]
182182
183183
julia> svi_dict[@varname(m.b)]
184-
ERROR: type NamedTuple has no field b
184+
ERROR: m.b was not found in the dictionary provided
185185
[...]
186186
```
187187
"""
@@ -212,7 +212,7 @@ end
212212
function SimpleVarInfo(values)
213213
return SimpleVarInfo{LogProbType}(values)
214214
end
215-
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict})
215+
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}})
216216
return if isempty(values)
217217
# Can't infer from values, so we just use default.
218218
SimpleVarInfo{LogProbType}(values)
@@ -264,7 +264,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
264264
end
265265

266266
function untyped_simple_varinfo(model::Model)
267-
varinfo = SimpleVarInfo(OrderedDict())
267+
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
268268
return last(evaluate_and_sample!!(model, varinfo))
269269
end
270270

src/test_utils/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function setup_varinfos(
3434

3535
# SimpleVarInfo
3636
svi_typed = SimpleVarInfo(example_values)
37-
svi_untyped = SimpleVarInfo(OrderedDict())
37+
svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
3838
svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())
3939

4040
varinfos = map((

src/utils.jl

Lines changed: 0 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -751,199 +751,6 @@ function unflatten(original::AbstractDict, x::AbstractVector)
751751
return D(zip(keys(original), unflatten(collect(values(original)), x)))
752752
end
753753

754-
# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl.
755-
"""
756-
getvalue(vals, vn::VarName)
757-
758-
Return the value(s) in `vals` represented by `vn`.
759-
760-
Note that this method is different from `getindex`. See examples below.
761-
762-
# Examples
763-
764-
For `NamedTuple`:
765-
766-
```jldoctest
767-
julia> vals = (x = [1.0],);
768-
769-
julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex`
770-
1-element Vector{Float64}:
771-
1.0
772-
773-
julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex`
774-
1.0
775-
776-
julia> DynamicPPL.getvalue(vals, @varname(x[2]))
777-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
778-
[...]
779-
```
780-
781-
For `AbstractDict`:
782-
783-
```jldoctest
784-
julia> vals = Dict(@varname(x) => [1.0]);
785-
786-
julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex`
787-
1-element Vector{Float64}:
788-
1.0
789-
790-
julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex`
791-
1.0
792-
793-
julia> DynamicPPL.getvalue(vals, @varname(x[2]))
794-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
795-
[...]
796-
```
797-
798-
In the `AbstractDict` case we can also have keys such as `v[1]`:
799-
800-
```jldoctest
801-
julia> vals = Dict(@varname(x[1]) => [1.0,]);
802-
803-
julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex`
804-
1-element Vector{Float64}:
805-
1.0
806-
807-
julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex`
808-
1.0
809-
810-
julia> DynamicPPL.getvalue(vals, @varname(x[1][2]))
811-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
812-
[...]
813-
814-
julia> DynamicPPL.getvalue(vals, @varname(x[2][1]))
815-
ERROR: KeyError: key x[2][1] not found
816-
[...]
817-
```
818-
"""
819-
getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn)
820-
getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn)
821-
822-
"""
823-
hasvalue(vals, vn::VarName)
824-
825-
Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref).
826-
827-
# Examples
828-
With `x` as a `NamedTuple`:
829-
830-
```jldoctest
831-
julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x))
832-
true
833-
834-
julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1]))
835-
false
836-
837-
julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x))
838-
true
839-
840-
julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1]))
841-
true
842-
843-
julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2]))
844-
false
845-
```
846-
847-
With `x` as a `AbstractDict`:
848-
849-
```jldoctest
850-
julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x))
851-
true
852-
853-
julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1]))
854-
false
855-
856-
julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x))
857-
true
858-
859-
julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1]))
860-
true
861-
862-
julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2]))
863-
false
864-
```
865-
866-
In the `AbstractDict` case we can also have keys such as `v[1]`:
867-
868-
```jldoctest
869-
julia> vals = Dict(@varname(x[1]) => [1.0,]);
870-
871-
julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey`
872-
true
873-
874-
julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey`
875-
true
876-
877-
julia> DynamicPPL.hasvalue(vals, @varname(x[1][2]))
878-
false
879-
880-
julia> DynamicPPL.hasvalue(vals, @varname(x[2][1]))
881-
false
882-
```
883-
"""
884-
function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
885-
# LHS: Ensure that `nt` indeed has the property we want.
886-
# RHS: Ensure that the optic can view into `nt`.
887-
return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym))
888-
end
889-
890-
# For `dictlike` we need to check wether `vn` is "immediately" present, or
891-
# if some ancestor of `vn` is present in `dictlike`.
892-
function hasvalue(vals::AbstractDict, vn::VarName)
893-
# First we check if `vn` is present as is.
894-
haskey(vals, vn) && return true
895-
896-
# If `vn` is not present, we check any parent-varnames by attempting
897-
# to split the optic into the key / `parent` and the extraction optic / `child`.
898-
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
899-
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
900-
o = optic === nothing ? identity : optic
901-
haskey(vals, VarName{getsym(vn)}(o))
902-
end
903-
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
904-
keyoptic = parent === nothing ? identity : parent
905-
906-
# Return early if no such split could be found.
907-
issuccess || return false
908-
909-
# At this point we just need to check that we `canview` the value.
910-
value = vals[VarName{getsym(vn)}(keyoptic)]
911-
912-
return canview(child, value)
913-
end
914-
915-
"""
916-
nested_getindex(values::AbstractDict, vn::VarName)
917-
918-
Return value corresponding to `vn` in `values` by also looking
919-
in the the actual values of the dict.
920-
"""
921-
function nested_getindex(values::AbstractDict, vn::VarName)
922-
maybeval = get(values, vn, nothing)
923-
if maybeval !== nothing
924-
return maybeval
925-
end
926-
927-
# Split the optic into the key / `parent` and the extraction optic / `child`.
928-
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
929-
o = optic === nothing ? identity : optic
930-
haskey(values, VarName{getsym(vn)}(o))
931-
end
932-
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
933-
keyoptic = parent === nothing ? identity : parent
934-
935-
# If we found a valid split, then we can extract the value.
936-
if !issuccess
937-
# At this point we just throw an error since the key could not be found.
938-
throw(KeyError(vn))
939-
end
940-
941-
# TODO: Should we also check that we `canview` the extracted `value`
942-
# rather than just let it fail upon `get` call?
943-
value = values[VarName{getsym(vn)}(keyoptic)]
944-
return child(value)
945-
end
946-
947754
"""
948755
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
949756

src/values_as_in_model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ $(TYPEDFIELDS)
1212
"""
1313
struct ValuesAsInModelAccumulator <: AbstractAccumulator
1414
"values that are extracted from the model"
15-
values::OrderedDict
15+
values::OrderedDict{<:VarName}
1616
"whether to extract variables on the LHS of :="
1717
include_colon_eq::Bool
1818
end
1919
function ValuesAsInModelAccumulator(include_colon_eq)
20-
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
20+
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
2121
end
2222

2323
function Base.copy(acc::ValuesAsInModelAccumulator)

src/varnamedvector.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict}
14821482
end
14831483

14841484
# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how
1485-
# they differ from `haskey` and `getindex`. They can be found in src/utils.jl.
1485+
# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl.
14861486

14871487
# TODO(mhauru) This is tricky to implement in the general case, and the below implementation
14881488
# only covers some simple cases. It's probably sufficient in most situations though.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3030
[compat]
3131
ADTypes = "1"
3232
AbstractMCMC = "5"
33-
AbstractPPL = "0.11, 0.12"
33+
AbstractPPL = "0.13"
3434
Accessors = "0.1"
3535
Aqua = "0.8"
3636
Bijectors = "0.15.1"

0 commit comments

Comments
 (0)