Skip to content

Commit 90f0674

Browse files
PavanChaggaryebaiphipsgabler
authored
Remaking Model consturctor to turn values into Ref(values) (#57)
* Remaking Model consturctor to turn values into Ref(values) * adding setindex! * adding tests for setindex * Update src/graphinfo.jl Co-authored-by: Hong Ge <[email protected]> * adding setvalue! function * Update test/graphinfo.jl Co-authored-by: Philipp Gabler <[email protected]> * Update test/graphinfo.jl Co-authored-by: Philipp Gabler <[email protected]> * changing assertion syntax * fixing bug that the model values are not in the same order as the nodes, i.e. not sorted topologically * getvalue function for getting node values * Update Project.toml * Update Project.toml * Update CI.yml * refactoring test * refactoring tests * adding docstrings * test error * fixing bugs on 1.6.5 Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Philipp Gabler <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent dff966f commit 90f0674

File tree

4 files changed

+131
-42
lines changed

4 files changed

+131
-42
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
matrix:
2121
version:
2222
- '1'
23-
- '1.0'
23+
- '1.6'
2424
- 'nightly'
2525
os:
2626
- ubuntu-latest

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
AbstractMCMC = "2, 3, 4"
1616
DensityInterface = "0.4"
1717
Setfield = "0.7.1, 0.8"
18-
julia = "1"
18+
julia = "1.6"

src/graphinfo.jl

Lines changed: 106 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,17 @@ end
6161

6262
function Model(;kwargs...)
6363
for (i, node) in enumerate(values(kwargs))
64-
@assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)"
64+
@assert node isa Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)"
6565
end
66-
vals = getvals(NamedTuple(kwargs))
66+
node_keys = keys(kwargs)
67+
vals = [getvals(NamedTuple(kwargs))...]
68+
vals[1] = Tuple([Ref(val) for val in vals[1]])
6769
args = [argnames(f) for f in vals[2]]
68-
A, sorted_vertices = dag(NamedTuple{keys(kwargs)}(args))
69-
modelinputs = NamedTuple{Tuple(sorted_vertices)}.([Tuple.(args), vals...])
70-
Model(GraphInfo(modelinputs..., A, sorted_vertices))
70+
A, sorted_inds = dag(NamedTuple{node_keys}(args))
71+
sorted_vertices = node_keys[sorted_inds]
72+
model_inputs = NamedTuple{node_keys}.([Tuple.(args), vals...])
73+
sorted_model_inputs = [NamedTuple{sorted_vertices}(m) for m in model_inputs]
74+
Model(GraphInfo(sorted_model_inputs..., A, [sorted_vertices...]))
7175
end
7276

7377
"""
@@ -78,11 +82,10 @@ and returns the implied adjacency matrix and topologically ordered
7882
vertex list.
7983
"""
8084
function dag(inputs)
81-
input_names = Symbol[keys(inputs)...]
8285
A = adjacency_matrix(inputs)
8386
sorted_vertices = topological_sort_by_dfs(A)
8487
sorted_A = permute(A, collect(1:length(inputs)), sorted_vertices)
85-
sorted_A, input_names[sorted_vertices]
88+
sorted_A, sorted_vertices
8689
end
8790

8891
"""
@@ -95,7 +98,7 @@ input, eval and kind, as required by the GraphInfo type.
9598
@generated function getvals(nt::NamedTuple{T}) where T
9699
values = [:(nt[$i][$j]) for i in 1:length(T), j in 1:3]
97100
m = [:($(values[:,i]...), ) for i in 1:3]
98-
return Expr(:tuple, m...) # :($(m...),)
101+
return Expr(:tuple, m...)
99102
end
100103

101104
"""
@@ -180,6 +183,7 @@ function topological_sort_by_dfs(A)
180183
return reverse(verts)
181184
end
182185

186+
# getters and setters
183187
"""
184188
Base.getindex(m::Model, vn::VarName{p})
185189
@@ -217,39 +221,116 @@ function Base.getindex(m::Model, vn::VarName)
217221
return m.g[vn]
218222
end
219223

220-
function Base.show(io::IO, m::Model)
221-
print(io, "Nodes: \n")
222-
for node in nodes(m)
223-
print(io, "$node = ", m[VarName{node}()], "\n")
224-
end
224+
"""
225+
set_node_value!(m::Model, ind::VarName, value::T) where Takes
226+
227+
Change the value of the node.
228+
229+
# Examples
230+
231+
```jl-doctest
232+
julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
233+
μ = (1.0, () -> 1.0, :Logical),
234+
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
235+
Nodes:
236+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#38#41"(), kind = :Logical)
237+
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#37#40"(), kind = :Stochastic)
238+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#39#42"(), kind = :Stochastic)
239+
240+
241+
julia> set_node_value!(m, @varname(s2), 1.0)
242+
1.0
243+
244+
julia> get_node_value(m, @varname s2)
245+
1.0
246+
```
247+
"""
248+
function set_node_value!(m::Model, ind::VarName, value::T) where T
249+
@assert typeof(m[ind].value[]) == T
250+
m[ind].value[] = value
225251
end
226252

253+
"""
254+
get_node_value(m::Model, ind::VarName)
227255
228-
function Base.iterate(m::Model, state=1)
229-
state > length(nodes(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1)
256+
Retrieve the value of a particular node, indexed by a VarName.
257+
258+
# Examples
259+
260+
julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
261+
μ = (1.0, () -> 1.0, :Logical),
262+
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
263+
Nodes:
264+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#44#47"(), kind = :Logical)
265+
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#43#46"(), kind = :Stochastic)
266+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#45#48"(), kind = :Stochastic)
267+
268+
269+
julia> get_node_value(m, @varname s2)
270+
0.0
271+
"""
272+
273+
function get_node_value(m::Model, ind::VarName)
274+
v = getproperty(m[ind], :value)
275+
v[]
230276
end
277+
#Base.get(m::Model, ind::VarName, field::Symbol) = field==:value ? getvalue(m, ind) : getproperty(m[ind],field)
231278

232-
Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]}
233-
Base.IteratorEltype(m::Model) = HasEltype()
279+
"""
280+
get_node_input(m::Model, ind::VarName)
234281
235-
Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices)
236-
Base.values(m::Model) = Base.Generator(identity, m)
237-
Base.length(m::Model) = length(nodes(m))
238-
Base.keytype(m::Model) = eltype(keys(m))
239-
Base.valtype(m::Model) = eltype(m)
282+
Retrieve the inputs/parents of a node, as given by model DAG.
283+
"""
284+
get_node_input(m::Model, ind::VarName) = getproperty(m[ind], :input)
240285

286+
"""
287+
get_node_input(m::Model, ind::VarName)
241288
289+
Retrieve the evaluation function for a node.
242290
"""
243-
dag(m::Model)
291+
get_node_eval(m::Model, ind::VarName) = getproperty(m[ind], :eval)
292+
293+
"""
294+
get_nodekind(m::Model, ind::VarName)
295+
296+
Retrieve the type of the node, i.e. stochastic or logical.
297+
"""
298+
get_nodekind(m::Model, ind::VarName) = getproperty(m[ind], :kind)
299+
300+
"""
301+
get_dag(m::Model)
244302
245303
Returns the adjacency matrix of the model as a SparseArray.
246304
"""
247305
get_dag(m::Model) = m.g.A
248306

249307
"""
250-
nodes(m::Model)
308+
get_sorted_vertices(m::Model)
251309
252310
Returns a `Vector{Symbol}` containing the sorted vertices
253311
of the DAG.
254312
"""
255-
nodes(m::Model) = m.g.sorted_vertices
313+
get_sorted_vertices(m::Model) = getproperty(m.g, :sorted_vertices)
314+
315+
# iterators
316+
317+
function Base.iterate(m::Model, state=1)
318+
state > length(get_sorted_vertices(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1)
319+
end
320+
321+
Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]}
322+
Base.IteratorEltype(m::Model) = HasEltype()
323+
324+
Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices)
325+
Base.values(m::Model) = Base.Generator(identity, m)
326+
Base.length(m::Model) = length(get_sorted_vertices(m))
327+
Base.keytype(m::Model) = eltype(keys(m))
328+
Base.valtype(m::Model) = eltype(m)
329+
330+
# show methods
331+
function Base.show(io::IO, m::Model)
332+
print(io, "Nodes: \n")
333+
for node in get_sorted_vertices(m)
334+
print(io, "$node = ", m[VarName{node}()], "\n")
335+
end
336+
end

test/graphinfo.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using AbstractPPL
2-
import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag
2+
import AbstractPPL.GraphPPL:GraphInfo, Model, get_dag, set_node_value!,
3+
get_node_value, get_sorted_vertices, get_node_eval,
4+
get_nodekind, get_node_input
35
using SparseArrays
4-
using Test
6+
57
## Example taken from Mamba
68
line = Dict{Symbol, Any}(
79
:x => [1, 2, 3, 4, 5],
@@ -23,9 +25,10 @@ m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constr
2325

2426
# test the type of the model is correct
2527
@test typeof(m) <: Model
26-
@test typeof(m) == Model{(:s2, :xmat, , , :y)}
28+
sorted_vertices = get_sorted_vertices(m)
29+
@test typeof(m) == Model{Tuple(sorted_vertices)}
2730
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
28-
@test typeof(m.g) == GraphInfo{(:s2, :xmat, , , :y)}
31+
@test typeof(m.g) == GraphInfo{Tuple(sorted_vertices)}
2932

3033
# test the dag is correct
3134
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@@ -35,28 +38,33 @@ A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
3538
@test eltype(m) == valtype(m)
3639

3740
# check the values from the NamedTuple match the values in the fields of GraphInfo
38-
vals = AbstractPPL.GraphPPL.getvals(model)
39-
for (i, field) in enumerate([:value, :eval, :kind])
40-
@test eval( :( values(m.g.$field) == vals[$i] ) )
41+
vals, evals, kinds = AbstractPPL.GraphPPL.getvals(NamedTuple{Tuple(sorted_vertices)}(model))
42+
inputs = (s2 = (), xmat = (), β = (), μ = (:xmat, ), y = (, :s2))
43+
44+
for (i, vn) in enumerate(keys(m))
45+
@test vn isa VarName
46+
@test get_node_value(m, vn) == vals[i]
47+
@test get_node_eval(m, vn) == evals[i]
48+
@test get_nodekind(m, vn) == kinds[i]
49+
@test get_node_input(m, vn) == inputs[i]
4150
end
4251

4352
for node in m
4453
@test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]}
4554
end
4655

47-
# test the right inputs have been inferred
48-
@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, ), y = (, :s2))
49-
50-
# test keys are VarNames
51-
for key in keys(m)
52-
@test typeof(key) <: VarName
53-
end
54-
5556
# test Model constructor for model with single parent node
5657
single_parent_m = Model= (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
5758
@test typeof(single_parent_m) == Model{(, :y)}
5859
@test typeof(single_parent_m.g) == GraphInfo{(, :y)}
5960

61+
# test setindex
62+
63+
@test_throws AssertionError set_node_value!(m, @varname(s2), [0.0])
64+
@test_throws AssertionError set_node_value!(m, @varname(s2), (1.0,))
65+
set_node_value!(m, @varname(s2), 1.0)
66+
@test get_node_value(m, @varname s2) == 1.0
67+
6068
# test ErrorException for parent node not found
6169
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
6270

0 commit comments

Comments
 (0)