Skip to content

Commit ed16f90

Browse files
committed
Fixing type stability issues for constructors (#65)
Fixes #64 by adding more type parameters to the constructors to make the `NamedTuples` concrete.
1 parent 90f0674 commit ed16f90

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

src/graphinfo.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ adjacency matrix and topologically ordered vertex list and stored.
2020
GraphInfo is instantiated using the `Model` constctor.
2121
"""
2222

23-
struct GraphInfo{T} <: AbstractModelTrace
24-
input::NamedTuple{T}
25-
value::NamedTuple{T}
26-
eval::NamedTuple{T}
27-
kind::NamedTuple{T}
23+
struct GraphInfo{Tnames, Tinput, Tvalue, Teval, Tkind} <: AbstractModelTrace
24+
input::NamedTuple{Tnames, Tinput}
25+
value::NamedTuple{Tnames, Tvalue}
26+
eval::NamedTuple{Tnames, Teval}
27+
kind::NamedTuple{Tnames, Tkind}
2828
A::SparseMatrixCSC
2929
sorted_vertices::Vector{Symbol}
3030
end
@@ -55,8 +55,8 @@ y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
5555
```
5656
"""
5757

58-
struct Model{T} <: AbstractProbabilisticProgram
59-
g::GraphInfo{T}
58+
struct Model{Tnames, Tinput, Tvalue, Teval, Tkind} <: AbstractProbabilisticProgram
59+
g::GraphInfo{Tnames, Tinput, Tvalue, Teval, Tkind}
6060
end
6161

6262
function Model(;kwargs...)

test/graphinfo.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ model = (
2424
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor
2525

2626
# test the type of the model is correct
27-
@test typeof(m) <: Model
27+
@test m isa Model
2828
sorted_vertices = get_sorted_vertices(m)
29-
@test typeof(m) == Model{Tuple(sorted_vertices)}
30-
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
31-
@test typeof(m.g) == GraphInfo{Tuple(sorted_vertices)}
29+
@test m isa Model{Tuple(sorted_vertices)}
30+
@test m.g isa GraphInfo <: AbstractModelTrace
31+
@test m.g isa GraphInfo{Tuple(sorted_vertices)}
3232

3333
# test the dag is correct
3434
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@@ -37,11 +37,18 @@ A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
3737
@test length(m) == 5
3838
@test eltype(m) == valtype(m)
3939

40+
4041
# check the values from the NamedTuple match the values in the fields of GraphInfo
4142
vals, evals, kinds = AbstractPPL.GraphPPL.getvals(NamedTuple{Tuple(sorted_vertices)}(model))
4243
inputs = (s2 = (), xmat = (), β = (), μ = (:xmat, ), y = (, :s2))
4344

4445
for (i, vn) in enumerate(keys(m))
46+
@inferred m[vn]
47+
@inferred get_node_value(m, vn)
48+
@inferred get_node_eval(m, vn)
49+
@inferred get_nodekind(m, vn)
50+
@inferred get_node_input(m, vn)
51+
4552
@test vn isa VarName
4653
@test get_node_value(m, vn) == vals[i]
4754
@test get_node_eval(m, vn) == evals[i]
@@ -50,16 +57,16 @@ for (i, vn) in enumerate(keys(m))
5057
end
5158

5259
for node in m
53-
@test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]}
60+
@test node isa NamedTuple{fieldnames(GraphInfo)[1:4]}
5461
end
5562

5663
# test Model constructor for model with single parent node
5764
single_parent_m = Model= (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
58-
@test typeof(single_parent_m) == Model{(, :y)}
59-
@test typeof(single_parent_m.g) == GraphInfo{(, :y)}
65+
@test single_parent_m isa Model{(, :y)}
66+
@test single_parent_m.g isa GraphInfo{(, :y)}
6067

61-
# test setindex
6268

69+
# test setindex
6370
@test_throws AssertionError set_node_value!(m, @varname(s2), [0.0])
6471
@test_throws AssertionError set_node_value!(m, @varname(s2), (1.0,))
6572
set_node_value!(m, @varname(s2), 1.0)

0 commit comments

Comments
 (0)