Skip to content

Commit 7cc7233

Browse files
authored
Implementing evaluation interface (#63)
* adding get_node_value function for multiple nodes * rough version of sampling from model * adding rand interface * switching from getproperty to setfield and adding new non-mutating rand * adding function to set model values * adding log density functions for model * fixing type signature for rand, changing logdensityof for Model * adding sampler things and refactoring test dir * refactoring directories and moving mh to example * removing examples folder, temporarily moving mh to test dir * deleting mh files (for this pr) * ` * adding docstrings and tests for random and logdensityof * removing some comments * removing .DS_Store, moving Distributions.jl to test env * fixing indentations * fixing identations in test * adding AbstractMCMC bound * updating julia compat * removing AbstractMCMC from test file * updating test project.toml
1 parent ed16f90 commit 7cc7233

File tree

5 files changed

+279
-19
lines changed

5 files changed

+279
-19
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ version = "0.5.2"
88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
1010
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1213
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1314

1415
[compat]
1516
AbstractMCMC = "2, 3, 4"
1617
DensityInterface = "0.4"
1718
Setfield = "0.7.1, 0.8"
18-
julia = "1.6"
19+
julia = "~1.6.6, 1.7.3"

src/graphinfo.jl

Lines changed: 222 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import Base.getindex
33
using SparseArrays
44
using Setfield
55
using Setfield: PropertyLens, get
6+
using DensityInterface
7+
using Random
68

79
"""
810
GraphInfo
@@ -222,7 +224,7 @@ function Base.getindex(m::Model, vn::VarName)
222224
end
223225

224226
"""
225-
set_node_value!(m::Model, ind::VarName, value::T) where Takes
227+
set_node_value!(m::Model, ind::VarName, value::T) where T
226228
227229
Change the value of the node.
228230
@@ -231,7 +233,7 @@ Change the value of the node.
231233
```jl-doctest
232234
julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
233235
μ = (1.0, () -> 1.0, :Logical),
234-
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
236+
y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
235237
Nodes:
236238
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#38#41"(), kind = :Logical)
237239
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#37#40"(), kind = :Stochastic)
@@ -271,31 +273,58 @@ julia> get_node_value(m, @varname s2)
271273
"""
272274

273275
function get_node_value(m::Model, ind::VarName)
274-
v = getproperty(m[ind], :value)
276+
v = get(m[ind], @lens _.value)
275277
v[]
276278
end
277-
#Base.get(m::Model, ind::VarName, field::Symbol) = field==:value ? getvalue(m, ind) : getproperty(m[ind],field)
279+
280+
function get_node_value(m::Model, ind::NTuple{N, Symbol}) where N
281+
# [get_node_value(m, VarName{S}()) for S in ind]
282+
values = Vector{Union{Float64, Array{Float64}}}()
283+
for i in ind
284+
push!(values, get_node_value(m, VarName{i}()))
285+
end
286+
values
287+
end
288+
289+
"""
290+
get_node_ref_value(m::Model, ind::VarName)
291+
get_node_ref_value(m::Model, ind::NTuple{N, Symbol})
292+
293+
Return the mutable Ref value associated with a node or tuple
294+
of nodes.
295+
"""
296+
function get_node_ref_value(m::Model, ind::VarName)
297+
get(m[ind], @lens _.value)
298+
end
299+
300+
function get_node_ref_value(m::Model, ind::NTuple{N, Symbol}) where N
301+
values = Vector{Union{Base.RefValue{Float64}, Base.RefValue{Vector{Float64}}}}()
302+
for i in ind
303+
push!(values, get_node_ref_value(m, VarName{i}()))
304+
end
305+
values
306+
end
278307

279308
"""
280309
get_node_input(m::Model, ind::VarName)
281310
282311
Retrieve the inputs/parents of a node, as given by model DAG.
283312
"""
284-
get_node_input(m::Model, ind::VarName) = getproperty(m[ind], :input)
313+
get_node_input(m::Model, ind::VarName) = get(m[ind], @lens _.input)
285314

286315
"""
287316
get_node_input(m::Model, ind::VarName)
288317
289318
Retrieve the evaluation function for a node.
290319
"""
291-
get_node_eval(m::Model, ind::VarName) = getproperty(m[ind], :eval)
320+
get_node_eval(m::Model, ind::VarName) = get(m[ind], @lens _.eval)
292321

293322
"""
294323
get_nodekind(m::Model, ind::VarName)
295324
296325
Retrieve the type of the node, i.e. stochastic or logical.
297326
"""
298-
get_nodekind(m::Model, ind::VarName) = getproperty(m[ind], :kind)
327+
get_nodekind(m::Model, ind::VarName) = get(m[ind], @lens _.kind)
299328

300329
"""
301330
get_dag(m::Model)
@@ -310,16 +339,48 @@ get_dag(m::Model) = m.g.A
310339
Returns a `Vector{Symbol}` containing the sorted vertices
311340
of the DAG.
312341
"""
313-
get_sorted_vertices(m::Model) = getproperty(m.g, :sorted_vertices)
342+
get_sorted_vertices(m::Model) = get(m.g, @lens _.sorted_vertices)
343+
344+
345+
"""
346+
get_model_values(m::Model)
347+
348+
Returns a Named Tuple of nodes and node values.
349+
"""
350+
function get_model_values(m::Model{T}) where T
351+
NamedTuple{T}(get_node_value(m, T))
352+
end
353+
354+
"""
355+
get_model_ref_values(m::Model)
356+
357+
Returns a Named Tuple of nodes and node Ref values.
358+
"""
359+
function get_model_ref_values(m::Model{T}) where T
360+
NamedTuple{T}(get_node_ref_value(m, T))
361+
end
362+
363+
"""
364+
set_model_values!(m::Model, values::NamedTuple)
314365
366+
Changes the values of the `Model` node values to those
367+
given by a Named Tuple of node symboles and new values.
368+
"""
369+
function set_model_values!(m::Model{T}, values::NamedTuple{T}) where T
370+
for vn in keys(m)
371+
if get_nodekind(m, vn) != :Observations
372+
set_node_value!(m, vn, get(values, vn))
373+
end
374+
end
375+
end
315376
# iterators
316377

317378
function Base.iterate(m::Model, state=1)
318379
state > length(get_sorted_vertices(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1)
319380
end
320381

321382
Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]}
322-
Base.IteratorEltype(m::Model) = HasEltype()
383+
Base.IteratorEltype(m::Model) = Base.HasEltype()
323384

324385
Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices)
325386
Base.values(m::Model) = Base.Generator(identity, m)
@@ -333,4 +394,156 @@ function Base.show(io::IO, m::Model)
333394
for node in get_sorted_vertices(m)
334395
print(io, "$node = ", m[VarName{node}()], "\n")
335396
end
397+
end
398+
399+
"""
400+
rand!(rng::AbstractRNG, m::Model)
401+
402+
Draw random samples from the model and mutate the node values.
403+
404+
# Examples
405+
406+
```jl-doctest
407+
julia> import AbstractPPL.GraphPPL: Model, rand!
408+
using Distributions
409+
410+
julia> using Random; Random.seed!(1234)
411+
TaskLocalRNG()
412+
413+
julia> m = Model(s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
414+
μ = (1.0, () -> 1.0, :Logical),
415+
y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
416+
Nodes:
417+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
418+
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
419+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
420+
421+
422+
julia> rand!(m)
423+
Nodes:
424+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
425+
s2 = (input = (), value = Base.RefValue{Float64}(2.7478186975593846), eval = var"#5#8"(), kind = :Stochastic)
426+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.3044653509044275), eval = var"#7#10"(), kind = :Stochastic)
427+
```
428+
"""
429+
function Random.rand!(rng::AbstractRNG, m::AbstractPPL.GraphPPL.Model{T}) where T
430+
for vn in keys(m)
431+
input, _, f, kind = m[vn]
432+
input_values = get_node_value(m, input)
433+
if kind == :Stochastic || kind == :Observations
434+
set_node_value!(m, vn, rand(rng, f(input_values...)))
435+
else
436+
set_node_value!(m, vn, f(input_values...))
437+
end
438+
end
439+
m
440+
end
441+
442+
function Random.rand!(m::AbstractPPL.GraphPPL.Model{T}) where T
443+
rand!(Random.GLOBAL_RNG, m)
444+
end
445+
446+
"""
447+
rand!(rng::AbstractRNG, m::Model)
448+
449+
Draw random samples from the model and mutate the node values.
450+
451+
# Examples
452+
453+
```jl-doctest
454+
julia> using Random; Random.seed!(1234)
455+
456+
julia> import AbstractPPL.GraphPPL: Model, rand
457+
[ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
458+
459+
julia> using Distributions
460+
461+
julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
462+
μ = (0.0, () -> 1.0, :Logical),
463+
y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
464+
Nodes:
465+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
466+
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
467+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
468+
469+
julia> rand(m)
470+
(μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368)
471+
```
472+
"""
473+
function Random.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind}
474+
m = deepcopy(sm[])
475+
get_model_values(rand!(rng, m))
476+
end
477+
478+
"""
479+
logdensityof(m::Model)
480+
481+
Evaluate the log-densinty of the model.
482+
483+
# Examples
484+
485+
```jl-doctest
486+
julia> using Random; Random.seed!(1234)
487+
MersenneTwister(1234)
488+
489+
julia> import AbstractPPL.GraphPPL: Model, logdensityof
490+
[ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
491+
492+
julia> using Distributions
493+
494+
julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
495+
μ = (0.0, () -> 1.0, :Logical),
496+
y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
497+
Nodes:
498+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
499+
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
500+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
501+
502+
julia> logdensityof(m)
503+
-1.721713955868453
504+
```
505+
"""
506+
function DensityInterface.logdensityof(m::AbstractPPL.GraphPPL.Model)
507+
logdensityof(m, get_model_values(m))
508+
end
509+
510+
"""
511+
logdensityof(m::Model{T}, v::NamedTuple{T})
512+
513+
Evaluate the log-densinty of the model.
514+
515+
# Examples
516+
517+
```jl-doctest
518+
julia> using Random; Random.seed!(1234)
519+
MersenneTwister(1234)
520+
521+
julia> import AbstractPPL.GraphPPL: Model, logdensityof, get_model_values
522+
[ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
523+
524+
julia> using Distributions
525+
526+
julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
527+
μ = (0.0, () -> 1.0, :Logical),
528+
y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
529+
Nodes:
530+
μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
531+
s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
532+
y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
533+
534+
julia> logdensityof(m, get_model_values(m))
535+
-1.721713955868453
536+
"""
537+
function DensityInterface.logdensityof(m::AbstractPPL.GraphPPL.Model{T}, v::NamedTuple{T, V}) where {T, V}
538+
lp = 0.0
539+
for vn in keys(m)
540+
input, _, f, kind = m[vn]
541+
input_values = get_node_value(m, input)
542+
value = get(v, vn)
543+
if kind == :Stochastic || kind == :Observations
544+
# check whether this is a constrained variable #TODO use bijectors.jl
545+
lp += logdensityof(f(input_values...), value)
546+
end
547+
end
548+
lp
336549
end

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
[deps]
2+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
36
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
47
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
58
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)