Skip to content

Commit 8090bdd

Browse files
committed
Fix SampleFromUniform
1 parent 80a54f1 commit 8090bdd

File tree

5 files changed

+65
-71
lines changed

5 files changed

+65
-71
lines changed

src/context_implementations.jl

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -111,56 +111,31 @@ function observe(spl::Sampler, weight)
111111
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
112112
end
113113

114-
# If parameters exist, they are used and not overwritten.
115114
function assume(
116-
spl::SampleFromPrior,
115+
spl::Union{SampleFromPrior,SampleFromUniform},
117116
dist::Distribution,
118117
vn::VarName,
119118
vi::VarInfo,
120119
)
121120
if haskey(vi, vn)
122-
if is_flagged(vi, vn, "del")
121+
# Always overwrite the parameters with new ones for `SampleFromUniform`.
122+
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
123123
unset_flag!(vi, vn, "del")
124-
r = rand(dist)
124+
r = init(dist, spl)
125125
vi[vn] = vectorize(dist, r)
126126
settrans!(vi, false, vn)
127127
setorder!(vi, vn, get_num_produce(vi))
128128
else
129129
r = vi[vn]
130130
end
131131
else
132-
r = rand(dist)
132+
r = init(dist, spl)
133133
push!(vi, vn, r, dist, spl)
134134
settrans!(vi, false, vn)
135135
end
136136
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
137137
end
138138

139-
# Always overwrites the parameters with new ones.
140-
function assume(
141-
spl::SampleFromUniform,
142-
dist::Distribution,
143-
vn::VarName,
144-
vi::VarInfo,
145-
)
146-
if haskey(vi, vn)
147-
unset_flag!(vi, vn, "del")
148-
r = init(dist)
149-
vi[vn] = vectorize(dist, r)
150-
settrans!(vi, true, vn)
151-
setorder!(vi, vn, get_num_produce(vi))
152-
else
153-
r = init(dist)
154-
push!(vi, vn, r, dist, spl)
155-
settrans!(vi, true, vn)
156-
end
157-
# NOTE: The importance weight is not correctly computed here because
158-
# r is genereated from some uniform distribution which is different from the prior
159-
# acclogp!(vi, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)))
160-
161-
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
162-
end
163-
164139
function observe(
165140
spl::Union{SampleFromPrior, SampleFromUniform},
166141
dist::Distribution,
@@ -307,53 +282,60 @@ function get_and_set_val!(
307282
vi::VarInfo,
308283
vns::AbstractVector{<:VarName},
309284
dist::MultivariateDistribution,
310-
spl::AbstractSampler,
285+
spl::Union{SampleFromPrior,SampleFromUniform},
311286
)
312287
n = length(vns)
313288
if haskey(vi, vns[1])
314-
if is_flagged(vi, vns[1], "del")
289+
# Always overwrite the parameters with new ones for `SampleFromUniform`.
290+
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
315291
unset_flag!(vi, vns[1], "del")
316-
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
292+
r = init(dist, spl, n)
317293
for i in 1:n
318294
vn = vns[i]
319295
vi[vn] = vectorize(dist, r[:, i])
296+
settrans!(vi, false, vn)
320297
setorder!(vi, vn, get_num_produce(vi))
321298
end
322299
else
323-
r = vi[vns]
300+
r = vi[vns]
324301
end
325302
else
326-
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
303+
r = init(dist, spl, n)
327304
for i in 1:n
328305
push!(vi, vns[i], r[:,i], dist, spl)
306+
settrans!(vi, false, vn)
329307
end
330308
end
331309
return r
332310
end
311+
333312
function get_and_set_val!(
334313
vi::VarInfo,
335314
vns::AbstractArray{<:VarName},
336315
dists::Union{Distribution, AbstractArray{<:Distribution}},
337-
spl::AbstractSampler,
316+
spl::Union{SampleFromPrior,SampleFromUniform},
338317
)
339318
if haskey(vi, vns[1])
340-
if is_flagged(vi, vns[1], "del")
319+
# Always overwrite the parameters with new ones for `SampleFromUniform`.
320+
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
341321
unset_flag!(vi, vns[1], "del")
342-
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
322+
f = (vn, dist) -> init(dist, spl)
343323
r = f.(vns, dists)
344324
for i in eachindex(vns)
345325
vn = vns[i]
346326
dist = dists isa AbstractArray ? dists[i] : dists
347327
vi[vn] = vectorize(dist, r[i])
328+
settrans!(vi, false, vn)
348329
setorder!(vi, vn, get_num_produce(vi))
349330
end
350331
else
351-
r = reshape(vi[vec(vns)], size(vns))
332+
r = reshape(vi[vec(vns)], size(vns))
352333
end
353334
else
354-
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
335+
f = (vn, dist) -> init(dist, spl)
355336
r = f.(vns, dists)
356337
push!.(Ref(vi), vns, r, dists, Ref(spl))
338+
settrans!.(Ref(vi), false, vns)
357339
end
358340
return r
359341
end

src/sampler.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ struct SampleFromPrior <: AbstractSampler end
66

77
getspace(::Union{SampleFromPrior, SampleFromUniform}) = ()
88

9+
# Initializations.
10+
init(dist, ::SampleFromPrior) = rand(dist)
11+
init(dist, ::SampleFromUniform) = istransformable(dist) ? inittrans(dist) : rand(dist)
12+
13+
init(dist, ::SampleFromPrior, n::Int) = rand(dist, n)
14+
function init(dist, ::SampleFromUniform, n::Int)
15+
return istransformable(dist) ? inittrans(dist, n) : rand(dist, n)
16+
end
17+
918
"""
1019
has_eval_num(spl::AbstractSampler)
1120

src/utils.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -114,37 +114,31 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector, n::In
114114
return r
115115
end
116116

117-
118-
# ROBUST INITIALISATIONS
119-
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
117+
# Uniform random numbers with range 4 for robust initializations
118+
# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
120119
randrealuni() = 4 * rand() - 2
121120
randrealuni(args...) = 4 .* rand(args...) .- 2
122121

123-
const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}
124-
122+
const Transformable = Union{PositiveDistribution,UnitDistribution,TransformDistribution,
123+
SimplexDistribution,PDMatDistribution}
124+
istransformable(dist) = false
125+
istransformable(::Transformable) = true
125126

126127
#################################
127128
# Single-sample initialisations #
128129
#################################
129130

130-
init(dist::Transformable) = inittrans(dist)
131-
init(dist::Distribution) = rand(dist)
132-
133131
inittrans(dist::UnivariateDistribution) = invlink(dist, randrealuni())
134-
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist)[1]))
132+
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist, 1)))
135133
inittrans(dist::MatrixDistribution) = invlink(dist, randrealuni(size(dist)...))
136134

137-
138135
################################
139136
# Multi-sample initialisations #
140137
################################
141138

142-
init(dist::Transformable, n::Int) = inittrans(dist, n)
143-
init(dist::Distribution, n::Int) = rand(dist, n)
144-
145139
inittrans(dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(n))
146140
function inittrans(dist::MultivariateDistribution, n::Int)
147-
return invlink(dist, randrealuni(size(dist)[1], n))
141+
return invlink(dist, randrealuni(size(dist, 1), n))
148142
end
149143
function inittrans(dist::MatrixDistribution, n::Int)
150144
return invlink(dist, [randrealuni(size(dist)...) for _ in 1:n])

test/sampler.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,37 @@ using Random
66
using Statistics
77
using Test
88

9-
Random.seed!(1234)
9+
Random.seed!(100)
1010

1111
@testset "AbstractMCMC interface" begin
1212
@model gdemo(x, y) = begin
1313
s ~ InverseGamma(2, 3)
14-
m ~ Normal(0.0, sqrt(s))
14+
m ~ Normal(2.0, sqrt(s))
1515
x ~ Normal(m, sqrt(s))
1616
y ~ Normal(m, sqrt(s))
1717
end
1818

1919
model = gdemo(1.0, 2.0)
20-
N = 10_000
20+
N = 1_000
2121

2222
chains = sample(model, SampleFromPrior(), N; progress = false)
2323
@test chains isa Vector{<:VarInfo}
2424
@test length(chains) == N
25-
@test mean(vi[@varname(m)] for vi in chains) 0 atol = 0.1
25+
26+
# Expected value of ``X`` where ``X ~ N(2, ...)`` is 2.
27+
@test mean(vi[@varname(m)] for vi in chains) 2 atol = 0.1
28+
29+
# Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3.
2630
@test mean(vi[@varname(s)] for vi in chains) 3 atol = 0.1
2731

2832
chains = sample(model, SampleFromUniform(), N; progress = false)
2933
@test chains isa Vector{<:VarInfo}
3034
@test length(chains) == N
31-
@test mean(vi[@varname(m)] for vi in chains) 1 atol = 0.1
32-
@test mean(vi[@varname(s)] for vi in chains) 3.3 atol = 0.1
35+
36+
# Expected value of ``X`` where ``X ~ U[-2, 2]`` is ≈ 0.
37+
@test mean(vi[@varname(m)] for vi in chains) 0 atol = 0.1
38+
39+
# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
40+
@test mean(vi[@varname(s)] for vi in chains) 1.8 atol = 0.1
3341
end
3442

test/varinfo.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,35 @@ include(dir*"/test/test_utils/AllUtils.jl")
165165

166166
vi = VarInfo()
167167
meta = vi.metadata
168+
168169
model(vi, SampleFromUniform())
170+
@test all(x -> !istrans(vi, x), meta.vns)
169171

170-
@test all(x -> istrans(vi, x), meta.vns)
171172
alg = HMC(0.1, 5)
172173
spl = Sampler(alg, model)
173174
v = copy(meta.vals)
174-
invlink!(vi, spl)
175-
@test all(x -> ~istrans(vi, x), meta.vns)
176175
link!(vi, spl)
177176
@test all(x -> istrans(vi, x), meta.vns)
178-
@test norm(meta.vals - v) <= 1e-6
177+
invlink!(vi, spl)
178+
@test all(x -> !istrans(vi, x), meta.vns)
179+
@test meta.vals == v
179180

180181
vi = TypedVarInfo(vi)
181182
meta = vi.metadata
182183
alg = HMC(0.1, 5)
183184
spl = Sampler(alg, model)
184-
@test all(x -> istrans(vi, x), meta.s.vns)
185-
@test all(x -> istrans(vi, x), meta.m.vns)
185+
@test all(x -> !istrans(vi, x), meta.s.vns)
186+
@test all(x -> !istrans(vi, x), meta.m.vns)
186187
v_s = copy(meta.s.vals)
187188
v_m = copy(meta.m.vals)
188-
invlink!(vi, spl)
189-
@test all(x -> ~istrans(vi, x), meta.s.vns)
190-
@test all(x -> ~istrans(vi, x), meta.m.vns)
191189
link!(vi, spl)
192190
@test all(x -> istrans(vi, x), meta.s.vns)
193191
@test all(x -> istrans(vi, x), meta.m.vns)
194-
@test norm(meta.s.vals - v_s) <= 1e-6
195-
@test norm(meta.m.vals - v_m) <= 1e-6
192+
invlink!(vi, spl)
193+
@test all(x -> ~istrans(vi, x), meta.s.vns)
194+
@test all(x -> ~istrans(vi, x), meta.m.vns)
195+
@test meta.s.vals == v_s
196+
@test meta.m.vals == v_m
196197
end
197198
@testset "setgid!" begin
198199
vi = VarInfo()

0 commit comments

Comments
 (0)