61
61
62
62
function Model (;kwargs... )
63
63
for (i, node) in enumerate (values (kwargs))
64
- @assert typeof ( node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} " Check input order for node $(i) matches Tuple(value, function, kind)"
64
+ @assert node isa Tuple{Union{Array{Float64}, Float64}, Function, Symbol} " Check input order for node $(i) matches Tuple(value, function, kind)"
65
65
end
66
- vals = getvals (NamedTuple (kwargs))
66
+ node_keys = keys (kwargs)
67
+ vals = [getvals (NamedTuple (kwargs))... ]
68
+ vals[1 ] = Tuple ([Ref (val) for val in vals[1 ]])
67
69
args = [argnames (f) for f in vals[2 ]]
68
- A, sorted_vertices = dag (NamedTuple {keys(kwargs)} (args))
69
- modelinputs = NamedTuple {Tuple(sorted_vertices)} .([Tuple .(args), vals... ])
70
- Model (GraphInfo (modelinputs... , A, sorted_vertices))
70
+ A, sorted_inds = dag (NamedTuple {node_keys} (args))
71
+ sorted_vertices = node_keys[sorted_inds]
72
+ model_inputs = NamedTuple {node_keys} .([Tuple .(args), vals... ])
73
+ sorted_model_inputs = [NamedTuple {sorted_vertices} (m) for m in model_inputs]
74
+ Model (GraphInfo (sorted_model_inputs... , A, [sorted_vertices... ]))
71
75
end
72
76
73
77
"""
@@ -78,11 +82,10 @@ and returns the implied adjacency matrix and topologically ordered
78
82
vertex list.
79
83
"""
80
84
function dag (inputs)
81
- input_names = Symbol[keys (inputs)... ]
82
85
A = adjacency_matrix (inputs)
83
86
sorted_vertices = topological_sort_by_dfs (A)
84
87
sorted_A = permute (A, collect (1 : length (inputs)), sorted_vertices)
85
- sorted_A, input_names[ sorted_vertices]
88
+ sorted_A, sorted_vertices
86
89
end
87
90
88
91
"""
@@ -95,7 +98,7 @@ input, eval and kind, as required by the GraphInfo type.
95
98
@generated function getvals (nt:: NamedTuple{T} ) where T
96
99
values = [:(nt[$ i][$ j]) for i in 1 : length (T), j in 1 : 3 ]
97
100
m = [:($ (values[:,i]. .. ), ) for i in 1 : 3 ]
98
- return Expr (:tuple , m... ) # :($(m...),)
101
+ return Expr (:tuple , m... )
99
102
end
100
103
101
104
"""
@@ -180,6 +183,7 @@ function topological_sort_by_dfs(A)
180
183
return reverse (verts)
181
184
end
182
185
186
+ # getters and setters
183
187
"""
184
188
Base.getindex(m::Model, vn::VarName{p})
185
189
@@ -217,39 +221,116 @@ function Base.getindex(m::Model, vn::VarName)
217
221
return m. g[vn]
218
222
end
219
223
220
- function Base. show (io:: IO , m:: Model )
221
- print (io, " Nodes: \n " )
222
- for node in nodes (m)
223
- print (io, " $node = " , m[VarName {node} ()], " \n " )
224
- end
224
+ """
225
+ set_node_value!(m::Model, ind::VarName, value::T) where Takes
226
+
227
+ Change the value of the node.
228
+
229
+ # Examples
230
+
231
+ ```jl-doctest
232
+ julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
233
+ μ = (1.0, () -> 1.0, :Logical),
234
+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
235
+ Nodes:
236
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#38#41"(), kind = :Logical)
237
+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#37#40"(), kind = :Stochastic)
238
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#39#42"(), kind = :Stochastic)
239
+
240
+
241
+ julia> set_node_value!(m, @varname(s2), 1.0)
242
+ 1.0
243
+
244
+ julia> get_node_value(m, @varname s2)
245
+ 1.0
246
+ ```
247
+ """
248
+ function set_node_value! (m:: Model , ind:: VarName , value:: T ) where T
249
+ @assert typeof (m[ind]. value[]) == T
250
+ m[ind]. value[] = value
225
251
end
226
252
253
+ """
254
+ get_node_value(m::Model, ind::VarName)
227
255
228
- function Base. iterate (m:: Model , state= 1 )
229
- state > length (nodes (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
256
+ Retrieve the value of a particular node, indexed by a VarName.
257
+
258
+ # Examples
259
+
260
+ julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
261
+ μ = (1.0, () -> 1.0, :Logical),
262
+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
263
+ Nodes:
264
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#44#47"(), kind = :Logical)
265
+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#43#46"(), kind = :Stochastic)
266
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#45#48"(), kind = :Stochastic)
267
+
268
+
269
+ julia> get_node_value(m, @varname s2)
270
+ 0.0
271
+ """
272
+
273
+ function get_node_value (m:: Model , ind:: VarName )
274
+ v = getproperty (m[ind], :value )
275
+ v[]
230
276
end
277
+ # Base.get(m::Model, ind::VarName, field::Symbol) = field==:value ? getvalue(m, ind) : getproperty(m[ind],field)
231
278
232
- Base . eltype (m :: Model ) = NamedTuple{ fieldnames (GraphInfo)[ 1 : 4 ]}
233
- Base . IteratorEltype (m:: Model ) = HasEltype ( )
279
+ """
280
+ get_node_input (m::Model, ind::VarName )
234
281
235
- Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
236
- Base. values (m:: Model ) = Base. Generator (identity, m)
237
- Base. length (m:: Model ) = length (nodes (m))
238
- Base. keytype (m:: Model ) = eltype (keys (m))
239
- Base. valtype (m:: Model ) = eltype (m)
282
+ Retrieve the inputs/parents of a node, as given by model DAG.
283
+ """
284
+ get_node_input (m:: Model , ind:: VarName ) = getproperty (m[ind], :input )
240
285
286
+ """
287
+ get_node_input(m::Model, ind::VarName)
241
288
289
+ Retrieve the evaluation function for a node.
242
290
"""
243
- dag(m::Model)
291
+ get_node_eval (m:: Model , ind:: VarName ) = getproperty (m[ind], :eval )
292
+
293
+ """
294
+ get_nodekind(m::Model, ind::VarName)
295
+
296
+ Retrieve the type of the node, i.e. stochastic or logical.
297
+ """
298
+ get_nodekind (m:: Model , ind:: VarName ) = getproperty (m[ind], :kind )
299
+
300
+ """
301
+ get_dag(m::Model)
244
302
245
303
Returns the adjacency matrix of the model as a SparseArray.
246
304
"""
247
305
get_dag (m:: Model ) = m. g. A
248
306
249
307
"""
250
- nodes (m::Model)
308
+ get_sorted_vertices (m::Model)
251
309
252
310
Returns a `Vector{Symbol}` containing the sorted vertices
253
311
of the DAG.
254
312
"""
255
- nodes (m:: Model ) = m. g. sorted_vertices
313
+ get_sorted_vertices (m:: Model ) = getproperty (m. g, :sorted_vertices )
314
+
315
+ # iterators
316
+
317
+ function Base. iterate (m:: Model , state= 1 )
318
+ state > length (get_sorted_vertices (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
319
+ end
320
+
321
+ Base. eltype (m:: Model ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
322
+ Base. IteratorEltype (m:: Model ) = HasEltype ()
323
+
324
+ Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
325
+ Base. values (m:: Model ) = Base. Generator (identity, m)
326
+ Base. length (m:: Model ) = length (get_sorted_vertices (m))
327
+ Base. keytype (m:: Model ) = eltype (keys (m))
328
+ Base. valtype (m:: Model ) = eltype (m)
329
+
330
+ # show methods
331
+ function Base. show (io:: IO , m:: Model )
332
+ print (io, " Nodes: \n " )
333
+ for node in get_sorted_vertices (m)
334
+ print (io, " $node = " , m[VarName {node} ()], " \n " )
335
+ end
336
+ end
0 commit comments