Skip to content

Commit 35af34b

Browse files
committed
Remove Dictionaries with Any key type
1 parent 48cf5b7 commit 35af34b

File tree

9 files changed

+26
-20
lines changed

9 files changed

+26
-20
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using DocStringExtensions
2323
using Random: Random
2424

2525
# For extending
26-
import AbstractPPL: predict
26+
import AbstractPPL: predict, hasvalue, getvalue
2727

2828
# TODO: Remove these when it's possible.
2929
import Bijectors: link, invlink

src/model.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
981981
Generate a sample of type `T` from the prior distribution of the `model`.
982982
"""
983983
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
984-
x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict())))
984+
x = last(
985+
evaluate_and_sample!!(
986+
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
987+
),
988+
)
985989
return values_as(x, T)
986990
end
987991

@@ -1032,7 +1036,7 @@ julia> logjoint(demo_model([1., 2.]), chain);
10321036
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
10331037
var_info = VarInfo(model) # extract variables info from the model
10341038
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1035-
argvals_dict = OrderedDict(
1039+
argvals_dict = OrderedDict{VarName,Any}(
10361040
vn_parent =>
10371041
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10381042
vn_parent in keys(var_info)
@@ -1090,7 +1094,7 @@ julia> logprior(demo_model([1., 2.]), chain);
10901094
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
10911095
var_info = VarInfo(model) # extract variables info from the model
10921096
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1093-
argvals_dict = OrderedDict(
1097+
argvals_dict = OrderedDict{VarName,Any}(
10941098
vn_parent =>
10951099
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10961100
vn_parent in keys(var_info)
@@ -1144,7 +1148,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain);
11441148
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
11451149
var_info = VarInfo(model) # extract variables info from the model
11461150
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1147-
argvals_dict = OrderedDict(
1151+
argvals_dict = OrderedDict{VarName,Any}(
11481152
vn_parent =>
11491153
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
11501154
vn_parent in keys(var_info)

src/simple_varinfo.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ ERROR: type NamedTuple has no field x
6262
[...]
6363
6464
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict()));
65+
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
6666
6767
julia> # (✓) Sort of fast, but only possible at runtime.
6868
vi[@varname(x[1])]
@@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107
true
108108
109109
julia> # And with `OrderedDict` of course!
110-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true));
110+
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113113
0.6225185067787314
@@ -212,7 +212,7 @@ end
212212
function SimpleVarInfo(values)
213213
return SimpleVarInfo{LogProbType}(values)
214214
end
215-
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict})
215+
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}})
216216
return if isempty(values)
217217
# Can't infer from values, so we just use default.
218218
SimpleVarInfo{LogProbType}(values)
@@ -264,7 +264,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
264264
end
265265

266266
function untyped_simple_varinfo(model::Model)
267-
varinfo = SimpleVarInfo(OrderedDict())
267+
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
268268
return last(evaluate_and_sample!!(model, varinfo))
269269
end
270270

src/test_utils/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function setup_varinfos(
3434

3535
# SimpleVarInfo
3636
svi_typed = SimpleVarInfo(example_values)
37-
svi_untyped = SimpleVarInfo(OrderedDict())
37+
svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
3838
svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())
3939

4040
varinfos = map((

src/values_as_in_model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ $(TYPEDFIELDS)
1212
"""
1313
struct ValuesAsInModelAccumulator <: AbstractAccumulator
1414
"values that are extracted from the model"
15-
values::OrderedDict
15+
values::OrderedDict{<:VarName}
1616
"whether to extract variables on the LHS of :="
1717
include_colon_eq::Bool
1818
end
1919
function ValuesAsInModelAccumulator(include_colon_eq)
20-
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
20+
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
2121
end
2222

2323
function Base.copy(acc::ValuesAsInModelAccumulator)

src/varnamedvector.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict}
14821482
end
14831483

14841484
# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how
1485-
# they differ from `haskey` and `getindex`. They can be found in src/utils.jl.
1485+
# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl.
14861486

14871487
# TODO(mhauru) This is tricky to implement in the general case, and the below implementation
14881488
# only covers some simple cases. It's probably sufficient in most situations though.

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ using LinearAlgebra # Diagonal
2727

2828
using JET: JET
2929

30+
# need to call this to get the AbstractPPL I think
31+
Pkg.update()
32+
3033
using Combinatorics: combinations
3134
using OrderedCollections: OrderedSet
3235

test/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
DynamicPPL.TestUtils.DEMO_MODELS
9191
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
9292
@testset "$name" for (name, vi) in (
93-
("SVI{Dict}", SimpleVarInfo(Dict())),
93+
("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())),
9494
("SVI{NamedTuple}", SimpleVarInfo(values_constrained)),
9595
("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())),
9696
("TypedVarInfo", DynamicPPL.typed_varinfo(model)),

test/varinfo.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ end
110110
test_base(VarInfo())
111111
test_base(DynamicPPL.typed_varinfo(VarInfo()))
112112
test_base(SimpleVarInfo())
113-
test_base(SimpleVarInfo(Dict()))
113+
test_base(SimpleVarInfo(Dict{VarName,Any}()))
114114
test_base(SimpleVarInfo(DynamicPPL.VarNamedVector()))
115115
end
116116

@@ -597,7 +597,7 @@ end
597597
test_linked_varinfo(model, vi)
598598

599599
## `SimpleVarInfo{<:Dict}`
600-
vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true)
600+
vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true)
601601
test_linked_varinfo(model, vi)
602602

603603
## `SimpleVarInfo{<:VarNamedVector}`
@@ -737,11 +737,10 @@ end
737737
model, (; x=1.0), (@varname(x),); include_threadsafe=true
738738
)
739739
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
740-
# Skip the severely inconcrete `SimpleVarInfo` types, since checking for type
740+
# Skip the inconcrete `SimpleVarInfo` types, since checking for type
741741
# stability for them doesn't make much sense anyway.
742-
if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} ||
743-
varinfo isa
744-
DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}}
742+
if varinfo isa SimpleVarInfo{<:AbstractDict} ||
743+
varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}}
745744
continue
746745
end
747746
@inferred DynamicPPL.unflatten(varinfo, varinfo[:])

0 commit comments

Comments
 (0)