Skip to content
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12"
AbstractPPL = "0.13"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
AbstractPPL = "0.11, 0.12"
AbstractPPL = "0.13"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using DocStringExtensions
using Random: Random

# For extending
import AbstractPPL: predict
import AbstractPPL: predict, hasvalue, getvalue

# TODO: Remove these when it's possible.
import Bijectors: link, invlink
Expand Down
12 changes: 8 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
Generate a sample of type `T` from the prior distribution of the `model`.
"""
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict())))
x = last(
evaluate_and_sample!!(
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
),
)
return values_as(x, T)
end

Expand Down Expand Up @@ -1032,7 +1036,7 @@ julia> logjoint(demo_model([1., 2.]), chain);
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
Expand Down Expand Up @@ -1090,7 +1094,7 @@ julia> logprior(demo_model([1., 2.]), chain);
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
Expand Down Expand Up @@ -1144,7 +1148,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain);
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
Expand Down
16 changes: 8 additions & 8 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,19 @@ ERROR: type NamedTuple has no field x
[...]

julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict()));
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));

julia> # (✓) Sort of fast, but only possible at runtime.
vi[@varname(x[1])]
-1.019202452456547

julia> # In addtion, we can only access varnames as they appear in the model!
vi[@varname(x)]
ERROR: KeyError: key x not found
ERROR: x was not found in the dictionary provided
[...]

julia> vi[@varname(x[1:2])]
ERROR: KeyError: key x[1:2] not found
ERROR: x[1:2] was not found in the dictionary provided
[...]
```

Expand Down Expand Up @@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
true

julia> # And with `OrderedDict` of course!
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true));
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));

julia> vi[@varname(x)] # (✓) -∞ < x < ∞
0.6225185067787314
Expand Down Expand Up @@ -177,11 +177,11 @@ julia> svi_dict[@varname(m.a[1])]
1.0

julia> svi_dict[@varname(m.a[2])]
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
ERROR: m.a[2] was not found in the dictionary provided
[...]

julia> svi_dict[@varname(m.b)]
ERROR: type NamedTuple has no field b
ERROR: m.b was not found in the dictionary provided
[...]
```
"""
Expand Down Expand Up @@ -212,7 +212,7 @@ end
function SimpleVarInfo(values)
return SimpleVarInfo{LogProbType}(values)
end
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict})
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}})
return if isempty(values)
# Can't infer from values, so we just use default.
SimpleVarInfo{LogProbType}(values)
Expand Down Expand Up @@ -264,7 +264,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
end

function untyped_simple_varinfo(model::Model)
varinfo = SimpleVarInfo(OrderedDict())
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
return last(evaluate_and_sample!!(model, varinfo))
end

Expand Down
2 changes: 1 addition & 1 deletion src/test_utils/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function setup_varinfos(

# SimpleVarInfo
svi_typed = SimpleVarInfo(example_values)
svi_untyped = SimpleVarInfo(OrderedDict())
svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())

varinfos = map((
Expand Down
193 changes: 0 additions & 193 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,199 +751,6 @@ function unflatten(original::AbstractDict, x::AbstractVector)
return D(zip(keys(original), unflatten(collect(values(original)), x)))
end

# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl.
"""
getvalue(vals, vn::VarName)

Return the value(s) in `vals` represented by `vn`.

Note that this method is different from `getindex`. See examples below.

# Examples

For `NamedTuple`:

```jldoctest
julia> vals = (x = [1.0],);

julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex`
1-element Vector{Float64}:
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex`
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[2]))
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
[...]
```

For `AbstractDict`:

```jldoctest
julia> vals = Dict(@varname(x) => [1.0]);

julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex`
1-element Vector{Float64}:
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex`
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[2]))
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
[...]
```

In the `AbstractDict` case we can also have keys such as `v[1]`:

```jldoctest
julia> vals = Dict(@varname(x[1]) => [1.0,]);

julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex`
1-element Vector{Float64}:
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex`
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1][2]))
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
[...]

julia> DynamicPPL.getvalue(vals, @varname(x[2][1]))
ERROR: KeyError: key x[2][1] not found
[...]
```
"""
getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn)
getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn)

"""
hasvalue(vals, vn::VarName)

Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref).

# Examples
With `x` as a `NamedTuple`:

```jldoctest
julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x))
true

julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1]))
false

julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x))
true

julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1]))
true

julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2]))
false
```

With `x` as a `AbstractDict`:

```jldoctest
julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x))
true

julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1]))
false

julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x))
true

julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1]))
true

julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2]))
false
```

In the `AbstractDict` case we can also have keys such as `v[1]`:

```jldoctest
julia> vals = Dict(@varname(x[1]) => [1.0,]);

julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey`
true

julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey`
true

julia> DynamicPPL.hasvalue(vals, @varname(x[1][2]))
false

julia> DynamicPPL.hasvalue(vals, @varname(x[2][1]))
false
```
"""
function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
# LHS: Ensure that `nt` indeed has the property we want.
# RHS: Ensure that the optic can view into `nt`.
return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym))
end

# For `dictlike` we need to check wether `vn` is "immediately" present, or
# if some ancestor of `vn` is present in `dictlike`.
function hasvalue(vals::AbstractDict, vn::VarName)
# First we check if `vn` is present as is.
haskey(vals, vn) && return true

# If `vn` is not present, we check any parent-varnames by attempting
# to split the optic into the key / `parent` and the extraction optic / `child`.
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(vals, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent

# Return early if no such split could be found.
issuccess || return false

# At this point we just need to check that we `canview` the value.
value = vals[VarName{getsym(vn)}(keyoptic)]

return canview(child, value)
end

"""
nested_getindex(values::AbstractDict, vn::VarName)

Return value corresponding to `vn` in `values` by also looking
in the the actual values of the dict.
"""
function nested_getindex(values::AbstractDict, vn::VarName)
maybeval = get(values, vn, nothing)
if maybeval !== nothing
return maybeval
end

# Split the optic into the key / `parent` and the extraction optic / `child`.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(values, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent

# If we found a valid split, then we can extract the value.
if !issuccess
# At this point we just throw an error since the key could not be found.
throw(KeyError(vn))
end

# TODO: Should we also check that we `canview` the extracted `value`
# rather than just let it fail upon `get` call?
value = values[VarName{getsym(vn)}(keyoptic)]
return child(value)
end

"""
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)

Expand Down
4 changes: 2 additions & 2 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ $(TYPEDFIELDS)
"""
struct ValuesAsInModelAccumulator <: AbstractAccumulator
"values that are extracted from the model"
values::OrderedDict
values::OrderedDict{<:VarName}
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
end
function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
end

function Base.copy(acc::ValuesAsInModelAccumulator)
Expand Down
2 changes: 1 addition & 1 deletion src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict}
end

# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how
# they differ from `haskey` and `getindex`. They can be found in src/utils.jl.
# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl.

# TODO(mhauru) This is tricky to implement in the general case, and the below implementation
# only covers some simple cases. It's probably sufficient in most situations though.
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12"
AbstractPPL = "0.13"
Accessors = "0.1"
Aqua = "0.8"
Bijectors = "0.15.1"
Expand Down
2 changes: 1 addition & 1 deletion test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
DynamicPPL.TestUtils.DEMO_MODELS
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@testset "$name" for (name, vi) in (
("SVI{Dict}", SimpleVarInfo(Dict())),
("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())),
("SVI{NamedTuple}", SimpleVarInfo(values_constrained)),
("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())),
("TypedVarInfo", DynamicPPL.typed_varinfo(model)),
Expand Down
Loading
Loading