Skip to content

Commit 444cfd5

Browse files
committed
Implement AbstractMCMC interface for SampleFromPrior and SampleFromUniform
1 parent b7159cb commit 444cfd5

File tree

5 files changed

+54
-0
lines changed

5 files changed

+54
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1213

1314
[compat]

src/DynamicPPL.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
44
using Distributions
55
using Bijectors
66
using MacroTools
7+
8+
import AbstractMCMC
9+
import Random
710
import ZygoteRules
811

912
import Base: Symbol,

src/sampler.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,18 @@ end
4343
Sampler(alg) = Sampler(alg, Selector())
4444
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
4545
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
46+
47+
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
48+
49+
function AbstractMCMC.step!(
50+
rng::Random.AbstractRNG,
51+
model::Model,
52+
sampler::Union{SampleFromUniform,SampleFromPrior},
53+
::Integer,
54+
transition;
55+
kwargs...
56+
)
57+
vi = VarInfo()
58+
model(vi, sampler)
59+
return vi
60+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ turnprogress(false)
99
include("utils.jl")
1010
include("compiler.jl")
1111
include("varinfo.jl")
12+
include("sampler.jl")
1213
include("prob_macro.jl")
1314
include("independence.jl")
1415
end

test/sampler.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using DynamicPPL
2+
using Distributions
3+
using AbstractMCMC: sample
4+
5+
using Random
6+
using Statistics
7+
using Test
8+
9+
Random.seed!(1234)
10+
11+
@testset "AbstractMCMC interface" begin
12+
@model gdemo(x, y) = begin
13+
s ~ InverseGamma(2, 3)
14+
m ~ Normal(0.0, sqrt(s))
15+
x ~ Normal(m, sqrt(s))
16+
y ~ Normal(m, sqrt(s))
17+
end
18+
19+
model = gdemo(1.0, 2.0)
20+
N = 10_000
21+
22+
chains = sample(model, SampleFromPrior(), N; progress = false)
23+
@test chains isa Vector{<:VarInfo}
24+
@test length(chains) == N
25+
@test mean(vi[@varname(m)] for vi in chains) 0 atol = 0.1
26+
@test mean(vi[@varname(s)] for vi in chains) 3 atol = 0.1
27+
28+
chains = sample(model, SampleFromUniform(), N; progress = false)
29+
@test chains isa Vector{<:VarInfo}
30+
@test length(chains) == N
31+
@test mean(vi[@varname(m)] for vi in chains) 1 atol = 0.1
32+
@test mean(vi[@varname(s)] for vi in chains) 3.3 atol = 0.1
33+
end
34+

0 commit comments

Comments
 (0)