Skip to content

Commit 0f28920

Browse files
sethaxensunxd3
andcommitted
Define rand defaults for AbstractProbabilisticProgram (#79)
This PR adds a 3-arg form of `rand` (suggested by @devmotion in TuringLang/DynamicPPL.jl#466 (comment)) to the interface for `AbstractProbabilisticProgram` and implements the default 1- and 2-arg methods that dispatch to this. Currently tests fail because this breaks the fallbacks for `GraphPPL.Model`, which expects `rand` to forward to its `rand!` method. I'm not certain how we want to define the interface for this `Model`. Co-authored-by: Xianda Sun <[email protected]>
1 parent f788a0a commit 0f28920

File tree

5 files changed

+66
-4
lines changed

5 files changed

+66
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.6.2"
6+
version = "0.6.3"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/abstractprobprog.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using AbstractMCMC
22
using DensityInterface
3+
using Random
34

45

56
"""
@@ -60,3 +61,22 @@ m = decondition(condition(m, obs))
6061
should hold for generative models `m` and arbitrary `obs`.
6162
"""
6263
function condition end
64+
65+
66+
"""
67+
rand([rng=Random.default_rng()], [T=NamedTuple], model::AbstractProbabilisticProgram) -> T
68+
69+
Draw a sample from the joint distribution of the model specified by the probabilistic program.
70+
71+
The sample will be returned as format specified by `T`.
72+
"""
73+
Base.rand(rng::Random.AbstractRNG, ::Type, model::AbstractProbabilisticProgram)
74+
function Base.rand(rng::Random.AbstractRNG, model::AbstractProbabilisticProgram)
75+
return rand(rng, NamedTuple, model)
76+
end
77+
function Base.rand(::Type{T}, model::AbstractProbabilisticProgram) where {T}
78+
return rand(Random.default_rng(), T, model)
79+
end
80+
function Base.rand(model::AbstractProbabilisticProgram)
81+
return rand(Random.default_rng(), NamedTuple, model)
82+
end

src/graphinfo.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,9 @@ function Random.rand!(m::AbstractPPL.GraphPPL.Model{T}) where T
444444
end
445445

446446
"""
447-
rand!(rng::AbstractRNG, m::Model)
447+
rand(m::Model)
448448
449-
Draw random samples from the model and mutate the node values.
449+
Draw random samples from the model and return the samples as NamedTuple.
450450
451451
# Examples
452452
@@ -470,11 +470,15 @@ julia> rand(m)
470470
(μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368)
471471
```
472472
"""
473-
function Random.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind}
473+
function Base.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind}
474474
m = deepcopy(sm[])
475475
get_model_values(rand!(rng, m))
476476
end
477477

478+
function Base.rand(rng::AbstractRNG, ::Type{NamedTuple}, m::Model)
479+
rand(rng, Random.SamplerTrivial(m))
480+
end
481+
478482
"""
479483
logdensityof(m::Model)
480484

test/abstractprobprog.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using AbstractPPL
2+
using Random
3+
using Test
4+
5+
mutable struct RandModel <: AbstractProbabilisticProgram
6+
rng
7+
T
8+
end
9+
10+
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::RandModel) where {T}
11+
model.rng = rng
12+
model.T = T
13+
return nothing
14+
end
15+
16+
@testset "AbstractProbabilisticProgram" begin
17+
@testset "rand defaults" begin
18+
model = RandModel(nothing, nothing)
19+
rand(model)
20+
@test model.rng == Random.default_rng()
21+
@test model.T === NamedTuple
22+
rngs = [Random.default_rng(), Random.MersenneTwister(42)]
23+
Ts = [NamedTuple, Dict]
24+
@testset for T in Ts
25+
model = RandModel(nothing, nothing)
26+
rand(T, model)
27+
@test model.rng == Random.default_rng()
28+
@test model.T === T
29+
end
30+
@testset for rng in rngs
31+
model = RandModel(nothing, nothing)
32+
rand(rng, model)
33+
@test model.rng === rng
34+
@test model.T === NamedTuple
35+
end
36+
end
37+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using Test
1313
@testset "AbstractPPL.jl" begin
1414
include("deprecations.jl")
1515
include("varname.jl")
16+
include("abstractprobprog.jl")
1617
include("graphinfo/graphinfo.jl")
1718
@testset "doctests" begin
1819
DocMeta.setdocmeta!(

0 commit comments

Comments
 (0)