@@ -3,6 +3,8 @@ import Base.getindex
3
3
using SparseArrays
4
4
using Setfield
5
5
using Setfield: PropertyLens, get
6
+ using DensityInterface
7
+ using Random
6
8
7
9
"""
8
10
GraphInfo
@@ -222,7 +224,7 @@ function Base.getindex(m::Model, vn::VarName)
222
224
end
223
225
224
226
"""
225
- set_node_value!(m::Model, ind::VarName, value::T) where Takes
227
+ set_node_value!(m::Model, ind::VarName, value::T) where T
226
228
227
229
Change the value of the node.
228
230
@@ -231,7 +233,7 @@ Change the value of the node.
231
233
```jl-doctest
232
234
julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
233
235
μ = (1.0, () -> 1.0, :Logical),
234
- y = (0.0, (μ, s2) -> MvNormal (μ, sqrt(s2)), :Stochastic))
236
+ y = (0.0, (μ, s2) -> Normal (μ, sqrt(s2)), :Stochastic))
235
237
Nodes:
236
238
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#38#41"(), kind = :Logical)
237
239
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#37#40"(), kind = :Stochastic)
@@ -271,31 +273,58 @@ julia> get_node_value(m, @varname s2)
271
273
"""
272
274
273
275
function get_node_value (m:: Model , ind:: VarName )
274
- v = getproperty (m[ind], : value )
276
+ v = get (m[ind], @lens _ . value)
275
277
v[]
276
278
end
277
- # Base.get(m::Model, ind::VarName, field::Symbol) = field==:value ? getvalue(m, ind) : getproperty(m[ind],field)
279
+
280
+ function get_node_value (m:: Model , ind:: NTuple{N, Symbol} ) where N
281
+ # [get_node_value(m, VarName{S}()) for S in ind]
282
+ values = Vector {Union{Float64, Array{Float64}}} ()
283
+ for i in ind
284
+ push! (values, get_node_value (m, VarName {i} ()))
285
+ end
286
+ values
287
+ end
288
+
289
+ """
290
+ get_node_ref_value(m::Model, ind::VarName)
291
+ get_node_ref_value(m::Model, ind::NTuple{N, Symbol})
292
+
293
+ Return the mutable Ref value associated with a node or tuple
294
+ of nodes.
295
+ """
296
+ function get_node_ref_value (m:: Model , ind:: VarName )
297
+ get (m[ind], @lens _. value)
298
+ end
299
+
300
+ function get_node_ref_value (m:: Model , ind:: NTuple{N, Symbol} ) where N
301
+ values = Vector{Union{Base. RefValue{Float64}, Base. RefValue{Vector{Float64}}}}()
302
+ for i in ind
303
+ push! (values, get_node_ref_value (m, VarName {i} ()))
304
+ end
305
+ values
306
+ end
278
307
279
308
"""
280
309
get_node_input(m::Model, ind::VarName)
281
310
282
311
Retrieve the inputs/parents of a node, as given by model DAG.
283
312
"""
284
- get_node_input (m:: Model , ind:: VarName ) = getproperty (m[ind], : input )
313
+ get_node_input (m:: Model , ind:: VarName ) = get (m[ind], @lens _ . input)
285
314
286
315
"""
287
316
get_node_input(m::Model, ind::VarName)
288
317
289
318
Retrieve the evaluation function for a node.
290
319
"""
291
- get_node_eval (m:: Model , ind:: VarName ) = getproperty (m[ind], : eval )
320
+ get_node_eval (m:: Model , ind:: VarName ) = get (m[ind], @lens _ . eval)
292
321
293
322
"""
294
323
get_nodekind(m::Model, ind::VarName)
295
324
296
325
Retrieve the type of the node, i.e. stochastic or logical.
297
326
"""
298
- get_nodekind (m:: Model , ind:: VarName ) = getproperty (m[ind], : kind )
327
+ get_nodekind (m:: Model , ind:: VarName ) = get (m[ind], @lens _ . kind)
299
328
300
329
"""
301
330
get_dag(m::Model)
@@ -310,16 +339,48 @@ get_dag(m::Model) = m.g.A
310
339
Returns a `Vector{Symbol}` containing the sorted vertices
311
340
of the DAG.
312
341
"""
313
- get_sorted_vertices (m:: Model ) = getproperty (m. g, :sorted_vertices )
342
+ get_sorted_vertices (m:: Model ) = get (m. g, @lens _. sorted_vertices)
343
+
344
+
345
+ """
346
+ get_model_values(m::Model)
347
+
348
+ Returns a Named Tuple of nodes and node values.
349
+ """
350
+ function get_model_values (m:: Model{T} ) where T
351
+ NamedTuple {T} (get_node_value (m, T))
352
+ end
353
+
354
+ """
355
+ get_model_ref_values(m::Model)
356
+
357
+ Returns a Named Tuple of nodes and node Ref values.
358
+ """
359
+ function get_model_ref_values (m:: Model{T} ) where T
360
+ NamedTuple {T} (get_node_ref_value (m, T))
361
+ end
362
+
363
+ """
364
+ set_model_values!(m::Model, values::NamedTuple)
314
365
366
+ Changes the values of the `Model` node values to those
367
+ given by a Named Tuple of node symboles and new values.
368
+ """
369
+ function set_model_values! (m:: Model{T} , values:: NamedTuple{T} ) where T
370
+ for vn in keys (m)
371
+ if get_nodekind (m, vn) != :Observations
372
+ set_node_value! (m, vn, get (values, vn))
373
+ end
374
+ end
375
+ end
315
376
# iterators
316
377
317
378
function Base. iterate (m:: Model , state= 1 )
318
379
state > length (get_sorted_vertices (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
319
380
end
320
381
321
382
Base. eltype (m:: Model ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
322
- Base. IteratorEltype (m:: Model ) = HasEltype ()
383
+ Base. IteratorEltype (m:: Model ) = Base . HasEltype ()
323
384
324
385
Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
325
386
Base. values (m:: Model ) = Base. Generator (identity, m)
@@ -333,4 +394,156 @@ function Base.show(io::IO, m::Model)
333
394
for node in get_sorted_vertices (m)
334
395
print (io, " $node = " , m[VarName {node} ()], " \n " )
335
396
end
397
+ end
398
+
399
+ """
400
+ rand!(rng::AbstractRNG, m::Model)
401
+
402
+ Draw random samples from the model and mutate the node values.
403
+
404
+ # Examples
405
+
406
+ ```jl-doctest
407
+ julia> import AbstractPPL.GraphPPL: Model, rand!
408
+ using Distributions
409
+
410
+ julia> using Random; Random.seed!(1234)
411
+ TaskLocalRNG()
412
+
413
+ julia> m = Model(s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
414
+ μ = (1.0, () -> 1.0, :Logical),
415
+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
416
+ Nodes:
417
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
418
+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
419
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
420
+
421
+
422
+ julia> rand!(m)
423
+ Nodes:
424
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
425
+ s2 = (input = (), value = Base.RefValue{Float64}(2.7478186975593846), eval = var"#5#8"(), kind = :Stochastic)
426
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.3044653509044275), eval = var"#7#10"(), kind = :Stochastic)
427
+ ```
428
+ """
429
+ function Random. rand! (rng:: AbstractRNG , m:: AbstractPPL.GraphPPL.Model{T} ) where T
430
+ for vn in keys (m)
431
+ input, _, f, kind = m[vn]
432
+ input_values = get_node_value (m, input)
433
+ if kind == :Stochastic || kind == :Observations
434
+ set_node_value! (m, vn, rand (rng, f (input_values... )))
435
+ else
436
+ set_node_value! (m, vn, f (input_values... ))
437
+ end
438
+ end
439
+ m
440
+ end
441
+
442
+ function Random. rand! (m:: AbstractPPL.GraphPPL.Model{T} ) where T
443
+ rand! (Random. GLOBAL_RNG, m)
444
+ end
445
+
446
+ """
447
+ rand!(rng::AbstractRNG, m::Model)
448
+
449
+ Draw random samples from the model and mutate the node values.
450
+
451
+ # Examples
452
+
453
+ ```jl-doctest
454
+ julia> using Random; Random.seed!(1234)
455
+
456
+ julia> import AbstractPPL.GraphPPL: Model, rand
457
+ [ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
458
+
459
+ julia> using Distributions
460
+
461
+ julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
462
+ μ = (0.0, () -> 1.0, :Logical),
463
+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
464
+ Nodes:
465
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
466
+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
467
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
468
+
469
+ julia> rand(m)
470
+ (μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368)
471
+ ```
472
+ """
473
+ function Random. rand (rng:: AbstractRNG , sm:: Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}} ) where {Tnames, Tinput, Tvalue, Teval, Tkind}
474
+ m = deepcopy (sm[])
475
+ get_model_values (rand! (rng, m))
476
+ end
477
+
478
+ """
479
+ logdensityof(m::Model)
480
+
481
+ Evaluate the log-densinty of the model.
482
+
483
+ # Examples
484
+
485
+ ```jl-doctest
486
+ julia> using Random; Random.seed!(1234)
487
+ MersenneTwister(1234)
488
+
489
+ julia> import AbstractPPL.GraphPPL: Model, logdensityof
490
+ [ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
491
+
492
+ julia> using Distributions
493
+
494
+ julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
495
+ μ = (0.0, () -> 1.0, :Logical),
496
+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
497
+ Nodes:
498
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
499
+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
500
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
501
+
502
+ julia> logdensityof(m)
503
+ -1.721713955868453
504
+ ```
505
+ """
506
+ function DensityInterface. logdensityof (m:: AbstractPPL.GraphPPL.Model )
507
+ logdensityof (m, get_model_values (m))
508
+ end
509
+
510
+ """
511
+ logdensityof(m::Model{T}, v::NamedTuple{T})
512
+
513
+ Evaluate the log-densinty of the model.
514
+
515
+ # Examples
516
+
517
+ ```jl-doctest
518
+ julia> using Random; Random.seed!(1234)
519
+ MersenneTwister(1234)
520
+
521
+ julia> import AbstractPPL.GraphPPL: Model, logdensityof, get_model_values
522
+ [ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
523
+
524
+ julia> using Distributions
525
+
526
+ julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
527
+ μ = (0.0, () -> 1.0, :Logical),
528
+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
529
+ Nodes:
530
+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
531
+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
532
+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
533
+
534
+ julia> logdensityof(m, get_model_values(m))
535
+ -1.721713955868453
536
+ """
537
+ function DensityInterface. logdensityof (m:: AbstractPPL.GraphPPL.Model{T} , v:: NamedTuple{T, V} ) where {T, V}
538
+ lp = 0.0
539
+ for vn in keys (m)
540
+ input, _, f, kind = m[vn]
541
+ input_values = get_node_value (m, input)
542
+ value = get (v, vn)
543
+ if kind == :Stochastic || kind == :Observations
544
+ # check whether this is a constrained variable #TODO use bijectors.jl
545
+ lp += logdensityof (f (input_values... ), value)
546
+ end
547
+ end
548
+ lp
336
549
end
0 commit comments