1
+ using AbstractPPL
2
+ import Base. getindex
3
+ using SparseArrays
4
+ using Setfield
5
+ using Setfield: PropertyLens, get
6
+
7
+ """
8
+ GraphInfo
9
+
10
+ Record the state of the model as a struct of NamedTuples, all
11
+ sharing the same key values, namely, those of the model parameters.
12
+ `value` should store the initial/current value of the parameters.
13
+ `input` stores a tuple of inputs for a given node. `eval` are the
14
+ anonymous functions associated with each node. These might typically
15
+ be either deterministic values or some distribution, but could an
16
+ arbitrary julia program. `kind` is a tuple of symbols indicating
17
+ whether the node is a logical or stochastic node. Additionally, the
18
+ adjacency matrix and topologically ordered vertex list and stored.
19
+
20
+ GraphInfo is instantiated using the `Model` constctor.
21
+ """
22
+
23
+ struct GraphInfo{T} <: AbstractModelTrace
24
+ input:: NamedTuple{T}
25
+ value:: NamedTuple{T}
26
+ eval:: NamedTuple{T}
27
+ kind:: NamedTuple{T}
28
+ A:: SparseMatrixCSC
29
+ sorted_vertices:: Vector{Symbol}
30
+ end
31
+
32
+ """
33
+ Model(;kwargs...)
34
+
35
+ `Model` type constructor that takes in named arguments for
36
+ nodes and returns a `Model`. Nodes are pairs of variable names
37
+ and tuples containing default value, an eval function
38
+ and node type. The inputs of each node are inferred from
39
+ their anonymous functions. The returned object has a type
40
+ GraphInfo{(sorted_vertices...)}.
41
+
42
+ # Examples
43
+ ```jl-doctest
44
+ julia> using AbstractPPL
45
+
46
+ julia> Model(
47
+ s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
48
+ μ = (1.0, () -> 1.0, :Logical),
49
+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
50
+ )
51
+ Nodes:
52
+ μ = (value = 1.0, input = (), eval = var"#6#9"(), kind = :Logical)
53
+ s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic)
54
+ y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
55
+ ```
56
+ """
57
+
58
+ struct Model{T} <: AbstractProbabilisticProgram
59
+ g:: GraphInfo{T}
60
+ end
61
+
62
+ function Model (;kwargs... )
63
+ for (i, node) in enumerate (values (kwargs))
64
+ @assert typeof (node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} " Check input order for node $(i) matches Tuple(value, function, kind)"
65
+ end
66
+ vals = getvals (NamedTuple (kwargs))
67
+ args = [argnames (f) for f in vals[2 ]]
68
+ A, sorted_vertices = dag (NamedTuple {keys(kwargs)} (args))
69
+ modelinputs = NamedTuple {Tuple(sorted_vertices)} .([Tuple .(args), vals... ])
70
+ Model (GraphInfo (modelinputs... , A, sorted_vertices))
71
+ end
72
+
73
+ """
74
+ dag(inputs)
75
+
76
+ Function taking in a NamedTuple containing the inputs to each node
77
+ and returns the implied adjacency matrix and topologically ordered
78
+ vertex list.
79
+ """
80
+ function dag (inputs)
81
+ input_names = Symbol[keys (inputs)... ]
82
+ A = adjacency_matrix (inputs)
83
+ sorted_vertices = topological_sort_by_dfs (A)
84
+ sorted_A = permute (A, collect (1 : length (inputs)), sorted_vertices)
85
+ sorted_A, input_names[sorted_vertices]
86
+ end
87
+
88
+ """
89
+ getvals(nt::NamedTuple{T})
90
+
91
+ Takes in the arguments to Model(;kwargs...) as a NamedTuple and
92
+ reorders into a tuple of tuples each containing either of value,
93
+ input, eval and kind, as required by the GraphInfo type.
94
+ """
95
+ @generated function getvals (nt:: NamedTuple{T} ) where T
96
+ values = [:(nt[$ i][$ j]) for i in 1 : length (T), j in 1 : 3 ]
97
+ m = [:($ (values[:,i]. .. ), ) for i in 1 : 3 ]
98
+ return Expr (:tuple , m... ) # :($(m...),)
99
+ end
100
+
101
+ """
102
+ argnames(f::Function)
103
+
104
+ Returns a Vector{Symbol} of the inputs to an anonymous function `f`.
105
+ """
106
+ argnames (f:: Function ) = Base. method_argnames (first (methods (f)))[2 : end ]
107
+
108
+ """
109
+ adjacency_matrix(inputs)
110
+
111
+ For a NamedTuple{T} with vertices `T` paired with tuples of input nodes,
112
+ `adjacency_matrix` constructs the adjacency matrix using the order
113
+ of variables given by `T`.
114
+
115
+ # Examples
116
+ ```jl-doctest
117
+ julia> inputs = (a = (), b = (), c = (:a, :b))
118
+ (a = (), b = (), c = (:a, :b))
119
+
120
+ julia> AbstractPPL.adjacency_matrix(inputs)
121
+ 3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
122
+ ⋅ ⋅ ⋅
123
+ ⋅ ⋅ ⋅
124
+ 1.0 1.0 ⋅
125
+ ```
126
+ """
127
+ function adjacency_matrix (inputs:: NamedTuple{nodes} ) where {nodes}
128
+ N = length (inputs)
129
+ col_inds = NamedTuple {nodes} (ntuple (identity, N))
130
+ A = spzeros (Bool, N, N)
131
+ for (row, node) in enumerate (nodes)
132
+ for input in inputs[node]
133
+ if input ∉ nodes
134
+ error (" Parent node of $(input) not found in node set: $(nodes) " )
135
+ end
136
+ col = col_inds[input]
137
+ A[row, col] = true
138
+ end
139
+ end
140
+ return A
141
+ end
142
+
143
+ function outneighbors (A:: SparseMatrixCSC , u:: T ) where T <: Int
144
+ # adapted from Graph.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302
145
+ inds, _ = findnz (A[:, u])
146
+ inds
147
+ end
148
+
149
+ function topological_sort_by_dfs (A)
150
+ # lifted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44
151
+ # Depth first search implementation optimized from http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf
152
+ n_verts = size (A)[1 ]
153
+ vcolor = zeros (UInt8, n_verts)
154
+ verts = Vector {Int64} ()
155
+ for v in 1 : n_verts
156
+ vcolor[v] != 0 && continue
157
+ S = Vector {Int64} ([v])
158
+ vcolor[v] = 1
159
+ while ! isempty (S)
160
+ u = S[end ]
161
+ w = 0
162
+ for n in outneighbors (A, u)
163
+ if vcolor[n] == 1
164
+ error (" The input graph contains at least one loop." ) # TODO 0.7 should we use a different error?
165
+ elseif vcolor[n] == 0
166
+ w = n
167
+ break
168
+ end
169
+ end
170
+ if w != 0
171
+ vcolor[w] = 1
172
+ push! (S, w)
173
+ else
174
+ vcolor[u] = 2
175
+ push! (verts, u)
176
+ pop! (S)
177
+ end
178
+ end
179
+ end
180
+ return reverse (verts)
181
+ end
182
+
183
+ """
184
+ Base.getindex(m::Model, vn::VarName{p})
185
+
186
+ Index a Model with a `VarName{p}` lens. Retrieves the `value``, `input`,
187
+ `eval` and `kind` for node `p`.
188
+
189
+ # Examples
190
+
191
+ ```jl-doctest
192
+ julia> using AbstractPPL
193
+
194
+ julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
195
+ μ = (1.0, () -> 1.0, :Logical),
196
+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
197
+ (s2 = Symbol[], μ = Symbol[], y = [:μ, :s2])
198
+ Nodes:
199
+ μ = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic)
200
+ s2 = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical)
201
+ y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
202
+
203
+
204
+ julia> m[@varname y]
205
+ (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
206
+ ```
207
+ """
208
+ @generated function Base. getindex (g:: GraphInfo , vn:: VarName{p} ) where {p}
209
+ fns = fieldnames (GraphInfo)[1 : 4 ]
210
+ name_lens = Setfield. PropertyLens {p} ()
211
+ field_lenses = [Setfield. PropertyLens {f} () for f in fns]
212
+ values = [:(get (g, Setfield. compose ($ l, $ name_lens, getlens (vn)))) for l in field_lenses]
213
+ return :(NamedTuple {$(fns)} (($ (values... ),)))
214
+ end
215
+
216
+ function Base. getindex (m:: Model , vn:: VarName )
217
+ return m. g[vn]
218
+ end
219
+
220
+ function Base. show (io:: IO , m:: Model )
221
+ print (io, " Nodes: \n " )
222
+ for node in nodes (m)
223
+ print (io, " $node = " , m[VarName {node} ()], " \n " )
224
+ end
225
+ end
226
+
227
+
228
+ function Base. iterate (m:: Model , state= 1 )
229
+ state > length (nodes (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
230
+ end
231
+
232
+ Base. eltype (m:: Model ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
233
+ Base. IteratorEltype (m:: Model ) = HasEltype ()
234
+
235
+ Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
236
+ Base. values (m:: Model ) = Base. Generator (identity, m)
237
+ Base. length (m:: Model ) = length (nodes (m))
238
+ Base. keytype (m:: Model ) = eltype (keys (m))
239
+ Base. valtype (m:: Model ) = eltype (m)
240
+
241
+
242
+ """
243
+ dag(m::Model)
244
+
245
+ Returns the adjacency matrix of the model as a SparseArray.
246
+ """
247
+ get_dag (m:: Model ) = m. g. A
248
+
249
+ """
250
+ nodes(m::Model)
251
+
252
+ Returns a `Vector{Symbol}` containing the sorted vertices
253
+ of the DAG.
254
+ """
255
+ nodes (m:: Model ) = m. g. sorted_vertices
0 commit comments