Skip to content

Commit 06aa6a6

Browse files
authored
Merge pull request #77 from TuringLang/step
Fix SampleFromUniform and implement AbstractMCMC interface for SampleFromPrior and SampleFromUniform
2 parents 1ad116f + 8090bdd commit 06aa6a6

File tree

8 files changed

+113
-65
lines changed

8 files changed

+113
-65
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1213

1314
[compat]

src/DynamicPPL.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
44
using Distributions
55
using Bijectors
66
using MacroTools
7+
8+
import AbstractMCMC
9+
import Random
710
import ZygoteRules
811

912
import Base: Symbol,

src/context_implementations.jl

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -111,56 +111,31 @@ function observe(spl::Sampler, weight)
111111
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
112112
end
113113

114-
# If parameters exist, they are used and not overwritten.
115114
function assume(
116-
spl::SampleFromPrior,
115+
spl::Union{SampleFromPrior,SampleFromUniform},
117116
dist::Distribution,
118117
vn::VarName,
119118
vi::VarInfo,
120119
)
121120
if haskey(vi, vn)
122-
if is_flagged(vi, vn, "del")
121+
# Always overwrite the parameters with new ones for `SampleFromUniform`.
122+
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
123123
unset_flag!(vi, vn, "del")
124-
r = rand(dist)
124+
r = init(dist, spl)
125125
vi[vn] = vectorize(dist, r)
126126
settrans!(vi, false, vn)
127127
setorder!(vi, vn, get_num_produce(vi))
128128
else
129129
r = vi[vn]
130130
end
131131
else
132-
r = rand(dist)
132+
r = init(dist, spl)
133133
push!(vi, vn, r, dist, spl)
134134
settrans!(vi, false, vn)
135135
end
136136
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
137137
end
138138

139-
# Always overwrites the parameters with new ones.
140-
function assume(
141-
spl::SampleFromUniform,
142-
dist::Distribution,
143-
vn::VarName,
144-
vi::VarInfo,
145-
)
146-
if haskey(vi, vn)
147-
unset_flag!(vi, vn, "del")
148-
r = init(dist)
149-
vi[vn] = vectorize(dist, r)
150-
settrans!(vi, true, vn)
151-
setorder!(vi, vn, get_num_produce(vi))
152-
else
153-
r = init(dist)
154-
push!(vi, vn, r, dist, spl)
155-
settrans!(vi, true, vn)
156-
end
157-
# NOTE: The importance weight is not correctly computed here because
158-
# r is genereated from some uniform distribution which is different from the prior
159-
# acclogp!(vi, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)))
160-
161-
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
162-
end
163-
164139
function observe(
165140
spl::Union{SampleFromPrior, SampleFromUniform},
166141
dist::Distribution,
@@ -307,53 +282,60 @@ function get_and_set_val!(
307282
vi::VarInfo,
308283
vns::AbstractVector{<:VarName},
309284
dist::MultivariateDistribution,
310-
spl::AbstractSampler,
285+
spl::Union{SampleFromPrior,SampleFromUniform},
311286
)
312287
n = length(vns)
313288
if haskey(vi, vns[1])
314-
if is_flagged(vi, vns[1], "del")
289+
# Always overwrite the parameters with new ones for `SampleFromUniform`.
290+
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
315291
unset_flag!(vi, vns[1], "del")
316-
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
292+
r = init(dist, spl, n)
317293
for i in 1:n
318294
vn = vns[i]
319295
vi[vn] = vectorize(dist, r[:, i])
296+
settrans!(vi, false, vn)
320297
setorder!(vi, vn, get_num_produce(vi))
321298
end
322299
else
323-
r = vi[vns]
300+
r = vi[vns]
324301
end
325302
else
326-
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
303+
r = init(dist, spl, n)
327304
for i in 1:n
328305
push!(vi, vns[i], r[:,i], dist, spl)
306+
settrans!(vi, false, vn)
329307
end
330308
end
331309
return r
332310
end
311+
333312
function get_and_set_val!(
334313
vi::VarInfo,
335314
vns::AbstractArray{<:VarName},
336315
dists::Union{Distribution, AbstractArray{<:Distribution}},
337-
spl::AbstractSampler,
316+
spl::Union{SampleFromPrior,SampleFromUniform},
338317
)
339318
if haskey(vi, vns[1])
340-
if is_flagged(vi, vns[1], "del")
319+
# Always overwrite the parameters with new ones for `SampleFromUniform`.
320+
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
341321
unset_flag!(vi, vns[1], "del")
342-
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
322+
f = (vn, dist) -> init(dist, spl)
343323
r = f.(vns, dists)
344324
for i in eachindex(vns)
345325
vn = vns[i]
346326
dist = dists isa AbstractArray ? dists[i] : dists
347327
vi[vn] = vectorize(dist, r[i])
328+
settrans!(vi, false, vn)
348329
setorder!(vi, vn, get_num_produce(vi))
349330
end
350331
else
351-
r = reshape(vi[vec(vns)], size(vns))
332+
r = reshape(vi[vec(vns)], size(vns))
352333
end
353334
else
354-
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
335+
f = (vn, dist) -> init(dist, spl)
355336
r = f.(vns, dists)
356337
push!.(Ref(vi), vns, r, dists, Ref(spl))
338+
settrans!.(Ref(vi), false, vns)
357339
end
358340
return r
359341
end

src/sampler.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ struct SampleFromPrior <: AbstractSampler end
66

77
getspace(::Union{SampleFromPrior, SampleFromUniform}) = ()
88

9+
# Initializations.
10+
init(dist, ::SampleFromPrior) = rand(dist)
11+
init(dist, ::SampleFromUniform) = istransformable(dist) ? inittrans(dist) : rand(dist)
12+
13+
init(dist, ::SampleFromPrior, n::Int) = rand(dist, n)
14+
function init(dist, ::SampleFromUniform, n::Int)
15+
return istransformable(dist) ? inittrans(dist, n) : rand(dist, n)
16+
end
17+
918
"""
1019
has_eval_num(spl::AbstractSampler)
1120
@@ -43,3 +52,18 @@ end
4352
Sampler(alg) = Sampler(alg, Selector())
4453
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
4554
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
55+
56+
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
57+
58+
function AbstractMCMC.step!(
59+
rng::Random.AbstractRNG,
60+
model::Model,
61+
sampler::Union{SampleFromUniform,SampleFromPrior},
62+
::Integer,
63+
transition;
64+
kwargs...
65+
)
66+
vi = VarInfo()
67+
model(vi, sampler)
68+
return vi
69+
end

src/utils.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -114,37 +114,31 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector, n::In
114114
return r
115115
end
116116

117-
118-
# ROBUST INITIALISATIONS
119-
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
117+
# Uniform random numbers with range 4 for robust initializations
118+
# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
120119
randrealuni() = 4 * rand() - 2
121120
randrealuni(args...) = 4 .* rand(args...) .- 2
122121

123-
const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}
124-
122+
const Transformable = Union{PositiveDistribution,UnitDistribution,TransformDistribution,
123+
SimplexDistribution,PDMatDistribution}
124+
istransformable(dist) = false
125+
istransformable(::Transformable) = true
125126

126127
#################################
127128
# Single-sample initialisations #
128129
#################################
129130

130-
init(dist::Transformable) = inittrans(dist)
131-
init(dist::Distribution) = rand(dist)
132-
133131
inittrans(dist::UnivariateDistribution) = invlink(dist, randrealuni())
134-
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist)[1]))
132+
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist, 1)))
135133
inittrans(dist::MatrixDistribution) = invlink(dist, randrealuni(size(dist)...))
136134

137-
138135
################################
139136
# Multi-sample initialisations #
140137
################################
141138

142-
init(dist::Transformable, n::Int) = inittrans(dist, n)
143-
init(dist::Distribution, n::Int) = rand(dist, n)
144-
145139
inittrans(dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(n))
146140
function inittrans(dist::MultivariateDistribution, n::Int)
147-
return invlink(dist, randrealuni(size(dist)[1], n))
141+
return invlink(dist, randrealuni(size(dist, 1), n))
148142
end
149143
function inittrans(dist::MatrixDistribution, n::Int)
150144
return invlink(dist, [randrealuni(size(dist)...) for _ in 1:n])

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ turnprogress(false)
99
include("utils.jl")
1010
include("compiler.jl")
1111
include("varinfo.jl")
12+
include("sampler.jl")
1213
include("prob_macro.jl")
1314
include("independence.jl")
1415
end

test/sampler.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using DynamicPPL
2+
using Distributions
3+
using AbstractMCMC: sample
4+
5+
using Random
6+
using Statistics
7+
using Test
8+
9+
Random.seed!(100)
10+
11+
@testset "AbstractMCMC interface" begin
12+
@model gdemo(x, y) = begin
13+
s ~ InverseGamma(2, 3)
14+
m ~ Normal(2.0, sqrt(s))
15+
x ~ Normal(m, sqrt(s))
16+
y ~ Normal(m, sqrt(s))
17+
end
18+
19+
model = gdemo(1.0, 2.0)
20+
N = 1_000
21+
22+
chains = sample(model, SampleFromPrior(), N; progress = false)
23+
@test chains isa Vector{<:VarInfo}
24+
@test length(chains) == N
25+
26+
# Expected value of ``X`` where ``X ~ N(2, ...)`` is 2.
27+
@test mean(vi[@varname(m)] for vi in chains) 2 atol = 0.1
28+
29+
# Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3.
30+
@test mean(vi[@varname(s)] for vi in chains) 3 atol = 0.1
31+
32+
chains = sample(model, SampleFromUniform(), N; progress = false)
33+
@test chains isa Vector{<:VarInfo}
34+
@test length(chains) == N
35+
36+
# Expected value of ``X`` where ``X ~ U[-2, 2]`` is ≈ 0.
37+
@test mean(vi[@varname(m)] for vi in chains) 0 atol = 0.1
38+
39+
# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
40+
@test mean(vi[@varname(s)] for vi in chains) 1.8 atol = 0.1
41+
end
42+

test/varinfo.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,35 @@ include(dir*"/test/test_utils/AllUtils.jl")
165165

166166
vi = VarInfo()
167167
meta = vi.metadata
168+
168169
model(vi, SampleFromUniform())
170+
@test all(x -> !istrans(vi, x), meta.vns)
169171

170-
@test all(x -> istrans(vi, x), meta.vns)
171172
alg = HMC(0.1, 5)
172173
spl = Sampler(alg, model)
173174
v = copy(meta.vals)
174-
invlink!(vi, spl)
175-
@test all(x -> ~istrans(vi, x), meta.vns)
176175
link!(vi, spl)
177176
@test all(x -> istrans(vi, x), meta.vns)
178-
@test norm(meta.vals - v) <= 1e-6
177+
invlink!(vi, spl)
178+
@test all(x -> !istrans(vi, x), meta.vns)
179+
@test meta.vals == v
179180

180181
vi = TypedVarInfo(vi)
181182
meta = vi.metadata
182183
alg = HMC(0.1, 5)
183184
spl = Sampler(alg, model)
184-
@test all(x -> istrans(vi, x), meta.s.vns)
185-
@test all(x -> istrans(vi, x), meta.m.vns)
185+
@test all(x -> !istrans(vi, x), meta.s.vns)
186+
@test all(x -> !istrans(vi, x), meta.m.vns)
186187
v_s = copy(meta.s.vals)
187188
v_m = copy(meta.m.vals)
188-
invlink!(vi, spl)
189-
@test all(x -> ~istrans(vi, x), meta.s.vns)
190-
@test all(x -> ~istrans(vi, x), meta.m.vns)
191189
link!(vi, spl)
192190
@test all(x -> istrans(vi, x), meta.s.vns)
193191
@test all(x -> istrans(vi, x), meta.m.vns)
194-
@test norm(meta.s.vals - v_s) <= 1e-6
195-
@test norm(meta.m.vals - v_m) <= 1e-6
192+
invlink!(vi, spl)
193+
@test all(x -> ~istrans(vi, x), meta.s.vns)
194+
@test all(x -> ~istrans(vi, x), meta.m.vns)
195+
@test meta.s.vals == v_s
196+
@test meta.m.vals == v_m
196197
end
197198
@testset "setgid!" begin
198199
vi = VarInfo()

0 commit comments

Comments
 (0)