Skip to content

Commit c655d8c

Browse files
Carlos Paradayebaidevmotion
committed
Add rand for sampling from prior of Models (#381)
A simple+intuitively-named way to generate a prior sample as a way to initialize a sampler. Co-authored-by: Hong Ge <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 083dfa1 commit c655d8c

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.17.7"
3+
version = "0.17.8"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,27 @@ Get the name of the `model` as `Symbol`.
517517
"""
518518
Base.nameof(model::Model) = model.name
519519

520+
"""
521+
rand([rng=Random.GLOBAL_RNG], [T=NamedTuple], model::Model)
522+
523+
Generate a sample of type `T` from the prior distribution of the `model`.
524+
"""
525+
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
526+
x = last(
527+
evaluate!!(
528+
model,
529+
SimpleVarInfo{Float64}(),
530+
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
531+
),
532+
)
533+
return DynamicPPL.values_as(x, T)
534+
end
535+
536+
# Default RNG and type
537+
Base.rand(rng::Random.AbstractRNG, model::Model) = rand(rng, NamedTuple, model)
538+
Base.rand(::Type{T}, model::Model) where {T} = rand(Random.GLOBAL_RNG, T, model)
539+
Base.rand(model::Model) = rand(Random.GLOBAL_RNG, NamedTuple, model)
540+
520541
"""
521542
logjoint(model::Model, varinfo::AbstractVarInfo)
522543

test/model.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,26 @@
8181
call_retval = model()
8282
@test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval))
8383
end
84+
85+
@testset "rand" begin
86+
model = gdemo_default
87+
88+
Random.seed!(1776)
89+
s, m = model()
90+
sample_namedtuple = (; s=s, m=m)
91+
sample_dict = Dict(:s => s, :m => m)
92+
93+
# With explicit RNG
94+
@test rand(Random.seed!(1776), model) == sample_namedtuple
95+
@test rand(Random.seed!(1776), NamedTuple, model) == sample_namedtuple
96+
@test rand(Random.seed!(1776), Dict, model) == sample_dict
97+
98+
# Without explicit RNG
99+
Random.seed!(1776)
100+
@test rand(model) == sample_namedtuple
101+
Random.seed!(1776)
102+
@test rand(NamedTuple, model) == sample_namedtuple
103+
Random.seed!(1776)
104+
@test rand(Dict, model) == sample_dict
105+
end
84106
end

0 commit comments

Comments
 (0)