Skip to content

Commit 4692526

Browse files
committed
DAG Model interface (#47)
This is a draft PR introducing a `Model` type that stores and makes use of the model graph. The main type introduced here is the `Model` struct which stores the `ModelState` and `DAG`, each of which is its own type. `ModelState` contains information about the node values, dependencies and eval functions and `DAG` contains the graph and topologically ordered vertex list. A model can be constructed in the following way: ```julia julia> nt = ( s2 = (0.0, (), () -> InverseGamma(2.0,3.0), :Stochastic), μ = (1.0, (), () -> 1.0, :Logical), y = (0.0, (:μ, :s2), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) ) (s2 = (0.0, (), var"#33#36"(), :Stochastic), μ = (1.0, (), var"#34#37"(), :Logical), y = (0.0, (:μ, :s2), var"#35#38"(), :Stochastic)) julia> Model(nt) Nodes: μ = (value = 1.0, input = (), eval = var"#16#19"(), kind = :Logical) s2 = (value = 0.0, input = (), eval = var"#15#18"(), kind = :Stochastic) y = (value = 0.0, input = (:μ, :s2), eval = var"#17#20"(), kind = :Stochastic) DAG: 3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1.0 1.0 ⋅ ``` At present, only functions needed for the constructors are implemented, as well as indexing using `@varname`. I still need to complete the integration with the AbstractPPL API. TODO: ~~- [ ] `condition`/`decondition`,~~ ~~- [ ] `sample`~~ ~~- [ ] `logdensityof`~~ - [x] pure functions for ordered dictionary, as outlined in [AbstractPPL](https://github.com/TuringLang/AbstractPPL.jl#property-interface) Feedback on `Model` structure welcome whilst I implement the remaining features!
1 parent 9b64dd8 commit 4692526

File tree

7 files changed

+332
-5
lines changed

7 files changed

+332
-5
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@ docs/site/
2222
# committed for packages, but should be committed for applications that require a static
2323
# environment.
2424
Manifest.toml
25+
26+
# vs code environment
27+
.vscode

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.4"
6+
version = "0.4.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
1010
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1111
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
12+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1213

1314
[compat]
1415
AbstractMCMC = "2, 3"

src/AbstractPPL.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ export VarName, getsym, getlens, inspace, subsumes, varname, vsym, @varname, @vs
77
# Abstract model functions
88
export AbstractProbabilisticProgram, condition, decondition, logdensityof, densityof
99

10-
1110
# Abstract traces
1211
export AbstractModelTrace
1312

@@ -17,4 +16,9 @@ include("abstractmodeltrace.jl")
1716
include("abstractprobprog.jl")
1817
include("deprecations.jl")
1918

19+
# GraphInfo
20+
# Submodule wrapping the graph-based model code; its contents come from the
# included file.
module GraphPPL

include("graphinfo.jl")

end # module GraphPPL
23+
2024
end # module

src/graphinfo.jl

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
using AbstractPPL
2+
import Base.getindex
3+
using SparseArrays
4+
using Setfield
5+
using Setfield: PropertyLens, get
6+
"""
    GraphInfo

Record the state of the model as a struct of NamedTuples, all
sharing the same key values, namely, those of the model parameters.
`value` should store the initial/current value of the parameters.
`input` stores a tuple of inputs for a given node. `eval` are the
anonymous functions associated with each node. These might typically
be either deterministic values or some distribution, but could be an
arbitrary julia program. `kind` holds a symbol per node indicating
whether the node is a logical or stochastic node. Additionally, the
adjacency matrix and topologically ordered vertex list are stored.

GraphInfo is instantiated using the `Model` constructor.
"""
struct GraphInfo{T} <: AbstractModelTrace
    # The first four fields share the same keys `T` (the node names).
    input::NamedTuple{T}
    value::NamedTuple{T}
    eval::NamedTuple{T}
    kind::NamedTuple{T}
    # Adjacency matrix of the model graph.
    A::SparseMatrixCSC
    # Node names in topological order.
    sorted_vertices::Vector{Symbol}
end
31+
"""
    Model{T}

Probabilistic program wrapping a [`GraphInfo`](@ref) in its field `g`. The type
parameter `T` is the tuple of node names of the wrapped `GraphInfo`.
"""
struct Model{T} <: AbstractProbabilisticProgram
    g::GraphInfo{T}
end

"""
    Model(;kwargs...)

`Model` type constructor that takes in named arguments for
nodes and returns a `Model`. Nodes are pairs of variable names
and tuples containing default value, an eval function
and node type. The inputs of each node are inferred from
the argument names of their anonymous functions. The returned object
has a type `Model{(sorted_vertices...)}`, and every field of the wrapped
`GraphInfo` is stored in topologically sorted node order.

Throws an `AssertionError` if a node is not a
`Tuple{Union{Array{Float64}, Float64}, Function, Symbol}`.

# Examples
```jl-doctest
julia> using AbstractPPL.GraphPPL: Model

julia> Model(
           s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
           μ = (1.0, () -> 1.0, :Logical),
           y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
       )
Nodes:
μ = (input = (), value = 1.0, eval = var"#6#9"(), kind = :Logical)
s2 = (input = (), value = 0.0, eval = var"#5#8"(), kind = :Stochastic)
y = (input = (:μ, :s2), value = 0.0, eval = var"#7#10"(), kind = :Stochastic)
```
"""
function Model(; kwargs...)
    for (i, node) in enumerate(values(kwargs))
        # Thrown explicitly instead of via `@assert` so the validation cannot be
        # compiled out; the exception type and message are unchanged.
        if !(node isa Tuple{Union{Array{Float64}, Float64}, Function, Symbol})
            throw(AssertionError("Check input order for node $(i) matches Tuple(value, function, kind)"))
        end
    end
    declared_names = keys(kwargs)
    vals = getvals(NamedTuple(kwargs))
    args = [argnames(f) for f in vals[2]]
    A, sorted_vertices = dag(NamedTuple{declared_names}(args))
    sorted_names = Tuple(sorted_vertices)
    # Field order expected by `GraphInfo`: input, value, eval, kind. Every tuple
    # here is still in declaration order.
    fields = (Tuple(Tuple.(args)), vals...)
    # Label each tuple with the declared names first, then select by name in
    # sorted order, so each node keeps its own entries after sorting.
    modelinputs = map(fields) do field
        NamedTuple{sorted_names}(NamedTuple{declared_names}(field))
    end
    return Model(GraphInfo(modelinputs..., A, sorted_vertices))
end
72+
"""
    dag(inputs)

Take a NamedTuple mapping each node to its inputs and return the implied
adjacency matrix together with the topologically ordered vertex list
(a `Vector{Symbol}`).

Rows and columns of the returned matrix are both arranged in the
topological order, so entry `(i, j)` relates the `i`-th and `j`-th
entries of the returned vertex list.
"""
function dag(inputs)
    input_names = Symbol[keys(inputs)...]
    A = adjacency_matrix(inputs)
    order = topological_sort_by_dfs(A)
    # Permute rows as well as columns: permuting columns alone leaves the rows
    # in declaration order and mismatched with the returned vertex list.
    sorted_A = permute(A, order, order)
    return sorted_A, input_names[order]
end
87+
"""
    getvals(nt::NamedTuple{T})

Transpose the node definitions passed to `Model(;kwargs...)`: given a
NamedTuple whose entries are 3-tuples, return a tuple of three tuples,
the `j`-th of which collects the `j`-th element of every entry in the
order of the keys `T`.
"""
@generated function getvals(nt::NamedTuple{T}) where T
    columns = Expr[]
    for j in 1:3
        # Tuple expression gathering position `j` across all nodes.
        column = Expr(:tuple)
        for i in 1:length(T)
            push!(column.args, :(nt[$i][$j]))
        end
        push!(columns, column)
    end
    return Expr(:tuple, columns...)
end
100+
"""
    argnames(f::Function)

Return a `Vector{Symbol}` holding the argument names of the first method
of `f` (for an anonymous function, its only method).
"""
function argnames(f::Function)
    method = first(methods(f))
    names = Base.method_argnames(method)
    # The first entry names the function itself, not an argument.
    return names[2:end]
end
107+
"""
    adjacency_matrix(inputs)

For a NamedTuple{T} with vertices `T` paired with tuples of input nodes,
`adjacency_matrix` constructs the adjacency matrix using the order
of variables given by `T`. Entry `(row, col)` is `true` when the vertex
at position `col` is an input of the vertex at position `row`.

Throws an `ErrorException` if an input is not one of the vertices `T`.

# Examples
```jl-doctest
julia> inputs = (a = (), b = (), c = (:a, :b))
(a = (), b = (), c = (:a, :b))

julia> AbstractPPL.GraphPPL.adjacency_matrix(inputs)
3×3 SparseMatrixCSC{Bool, Int64} with 2 stored entries:
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 1  1  ⋅
```
"""
function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
    N = length(inputs)
    # Maps each node name to its position in `nodes`.
    col_inds = NamedTuple{nodes}(ntuple(identity, N))
    A = spzeros(Bool, N, N)
    for (row, node) in enumerate(nodes)
        for input in inputs[node]
            if input ∉ nodes
                error("Parent node of $(input) not found in node set: $(nodes)")
            end
            col = col_inds[input]
            A[row, col] = true
        end
    end
    return A
end
142+
143+
"""
    outneighbors(A::SparseMatrixCSC, u::Integer)

Return the row indices of the entries stored in column `u` of `A`, as a
new vector.

Adapted from Graphs.jl
https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302
"""
function outneighbors(A::SparseMatrixCSC, u::Integer)
    # Read the stored row indices of column `u` directly, rather than copying
    # the whole column into a sparse vector first.
    return rowvals(A)[nzrange(A, u)]
end
148+
149+
"""
    topological_sort_by_dfs(A)

Return the vertex indices `1:size(A, 1)` in a topological order, found by an
iterative depth-first search that follows `outneighbors(A, u)`. Throws an
`ErrorException` when a cycle is encountered.

Lifted from Graphs.jl
https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44
(depth first search implementation optimized from
http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf).
"""
function topological_sort_by_dfs(A)
    nv = size(A, 1)
    # Per-vertex state: 0 = unvisited, 1 = on the current DFS stack, 2 = finished.
    state = zeros(UInt8, nv)
    finished = Int64[]
    for root in 1:nv
        state[root] == 0 || continue
        stack = Int64[root]
        state[root] = 1
        while !isempty(stack)
            current = last(stack)
            # First unvisited neighbour of `current`, or 0 if there is none.
            next = 0
            for neighbor in outneighbors(A, current)
                if state[neighbor] == 1
                    # Reaching a vertex that is still on the stack closes a cycle.
                    error("The input graph contains at least one loop.") # TODO 0.7 should we use a different error?
                elseif state[neighbor] == 0
                    next = neighbor
                    break
                end
            end
            if next == 0
                # Every neighbour is finished, so `current` is finished too.
                state[current] = 2
                push!(finished, current)
                pop!(stack)
            else
                state[next] = 1
                push!(stack, next)
            end
        end
    end
    # Vertices were recorded in finishing order; the topological order is its reverse.
    return reverse(finished)
end
182+
"""
    Base.getindex(g::GraphInfo, vn::VarName{p})

Index a `GraphInfo` with a `VarName{p}`. For each of the first four fields
of `GraphInfo` (`input`, `value`, `eval` and `kind`) the entry for node `p`
is read through a lens composed of the field, the node name `p` and the
lens carried by `vn`. The results are returned as a NamedTuple keyed by
those field names.

# Examples

```jl-doctest
julia> using AbstractPPL

julia> using AbstractPPL.GraphPPL: Model

julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
                  μ = (1.0, () -> 1.0, :Logical),
                  y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic));

julia> m[@varname y]
(input = (:μ, :s2), value = 0.0, eval = var"#45#48"(), kind = :Stochastic)
```
"""
@generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p}
    fields = fieldnames(GraphInfo)[1:4]
    node_lens = Setfield.PropertyLens{p}()
    # One lookup expression per field.
    lookups = [
        :(get(g, Setfield.compose($(Setfield.PropertyLens{f}()), $node_lens, getlens(vn))))
        for f in fields
    ]
    return :(NamedTuple{$(fields)}(($(lookups...),)))
end
215+
216+
# Indexing a `Model` forwards to the wrapped `GraphInfo`.
Base.getindex(m::Model, vn::VarName) = m.g[vn]
219+
220+
# Print a "Nodes: " header, then one `name = entry` line per node, in the
# order returned by `nodes(m)`.
function Base.show(io::IO, m::Model)
    print(io, "Nodes: \n")
    for name in nodes(m)
        entry = m[VarName{name}()]
        print(io, "$name = ", entry, "\n")
    end
end
226+
227+
228+
# Iterate over the node entries of `m`, following `sorted_vertices`; `state`
# is the position of the next node to yield.
function Base.iterate(m::Model, state=1)
    if state > length(nodes(m))
        return nothing
    end
    name = m.g.sorted_vertices[state]
    return (m[VarName{name}()], state + 1)
end
231+
232+
# Iteration traits are declared on the type; Base's generic fallbacks
# (`eltype(x) = eltype(typeof(x))`, likewise for `IteratorEltype`) make them
# available on instances too. `HasEltype` is not exported from Base, so it
# has to be qualified.
Base.eltype(::Type{<:Model}) = NamedTuple{fieldnames(GraphInfo)[1:4]}
Base.IteratorEltype(::Type{<:Model}) = Base.HasEltype()

# Dictionary-like accessors: keys are `VarName`s in sorted vertex order,
# values are the node entries produced by iterating the model.
Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices)
Base.values(m::Model) = Base.Generator(identity, m)
Base.length(m::Model) = length(nodes(m))
Base.keytype(m::Model) = eltype(keys(m))
Base.valtype(m::Model) = eltype(m)
240+
241+
"""
    get_dag(m::Model)

Return the adjacency matrix stored in the model (field `A` of the
wrapped `GraphInfo`).
"""
function get_dag(m::Model)
    return m.g.A
end
248+
"""
    nodes(m::Model)

Return the `Vector{Symbol}` of sorted vertices stored in the model
(field `sorted_vertices` of the wrapped `GraphInfo`).
"""
function nodes(m::Model)
    return m.g.sorted_vertices
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
4+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
45
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
56

67
[compat]

test/graphinfo.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using AbstractPPL
2+
import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag
3+
using SparseArrays
4+
using Test
5+
## Example taken from Mamba
line = Dict{Symbol, Any}(
    :x => [1, 2, 3, 4, 5],
    :y => [1, 3, 3, 3, 5]
)
line[:xmat] = [ones(5) line[:x]]

# Kept as a NamedTuple so the node definitions can be inspected later.
# The constructor itself is used as Model(; kwargs...).
model = (
    β = (zeros(2), () -> MvNormal(2, sqrt(1000)), :Stochastic),
    xmat = (line[:xmat], () -> line[:xmat], :Logical),
    s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
    μ = (zeros(5), (xmat, β) -> xmat * β, :Logical),
    y = (zeros(5), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
)

# construct the model!
m = Model(; model...) # uses Model(; kwargs...) constructor

# test the type of the model is correct
@test typeof(m) <: Model
@test typeof(m) == Model{(:s2, :xmat, :β, :μ, :y)}
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
@test typeof(m.g) == GraphInfo{(:s2, :xmat, :β, :μ, :y)}

# test the dag is correct
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@test get_dag(m) == A

@test length(m) == 5
@test eltype(m) == valtype(m)

# Check that the entries stored in GraphInfo are exactly the ones given in the
# node definitions. This check only tests membership, so it does not depend on
# the order in which the nodes are stored.
vals = AbstractPPL.GraphPPL.getvals(model)
for (i, field) in enumerate([:value, :eval, :kind])
    stored = values(getproperty(m.g, field))
    @test length(stored) == length(vals[i])
    @test all(v -> v in vals[i], stored)
end

for node in m
    @test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]}
end

# test the right inputs have been inferred
@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2))

# test keys are VarNames
for key in keys(m)
    @test typeof(key) <: VarName
end

# test Model constructor for model with single parent node
single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
@test typeof(single_parent_m) == Model{(:μ, :y)}
@test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)}

# each node of the single-parent model keeps its own input, value and kind
@test single_parent_m.g.input == (μ = (), y = (:μ,))
@test single_parent_m.g.value == (μ = 1.0, y = 1.0)
@test single_parent_m.g.kind == (μ = :Logical, y = :Stochastic)

# test ErrorException for parent node not found
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))

# test AssertionError thrown for kwargs with the wrong order of inputs
@test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))

test/runtests.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Test
1212

1313
@testset "AbstractPPL.jl" begin
1414
include("deprecations.jl")
15-
15+
include("graphinfo.jl")
1616
@testset "doctests" begin
1717
DocMeta.setdocmeta!(
1818
AbstractPPL,
@@ -22,5 +22,4 @@ using Test
2222
)
2323
doctest(AbstractPPL; manual=false)
2424
end
25-
end
26-
25+
end

0 commit comments

Comments
 (0)