Skip to content

Commit 26d90d7

Browse files
authored
Merge pull request #78 from TuringLang/mt/fix_samplefromuniform
Fix SampleFromUniform
2 parents c685f7d + a31e523 commit 26d90d7

File tree

4 files changed

+43
-18
lines changed

4 files changed

+43
-18
lines changed

src/context_implementations.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,48 @@ function observe(spl::Sampler, weight)
111111
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
112112
end
113113

114+
# If parameters exist, they are used and not overwritten.
114115
function assume(
115-
spl::Union{SampleFromPrior, SampleFromUniform},
116+
spl::SampleFromPrior,
116117
dist::Distribution,
117118
vn::VarName,
118119
vi::VarInfo,
119120
)
120121
if haskey(vi, vn)
121122
if is_flagged(vi, vn, "del")
122123
unset_flag!(vi, vn, "del")
123-
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
124+
r = rand(dist)
124125
vi[vn] = vectorize(dist, r)
126+
settrans!(vi, false, vn)
125127
setorder!(vi, vn, get_num_produce(vi))
126128
else
127-
r = vi[vn]
129+
r = vi[vn]
128130
end
129131
else
130-
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
132+
r = rand(dist)
133+
push!(vi, vn, r, dist, spl)
134+
settrans!(vi, false, vn)
135+
end
136+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
137+
end
138+
139+
# Always overwrites the parameters with new ones.
140+
function assume(
141+
spl::SampleFromUniform,
142+
dist::Distribution,
143+
vn::VarName,
144+
vi::VarInfo,
145+
)
146+
if haskey(vi, vn)
147+
unset_flag!(vi, vn, "del")
148+
r = init(dist)
149+
vi[vn] = vectorize(dist, r)
150+
settrans!(vi, true, vn)
151+
setorder!(vi, vn, get_num_produce(vi))
152+
else
153+
r = init(dist)
131154
push!(vi, vn, r, dist, spl)
155+
settrans!(vi, true, vn)
132156
end
133157
# NOTE: The importance weight is not correctly computed here because
134158
# r is genereated from some uniform distribution which is different from the prior

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ end
117117

118118
# ROBUST INITIALISATIONS
119119
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
120-
randrealuni() = Real(2rand())
121-
randrealuni(args...) = map(Real, 2rand(args...))
120+
randrealuni() = 4 * rand() - 2
121+
randrealuni(args...) = 4 .* rand(args...) .- 2
122122

123123
const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}
124124

src/varinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`.
821821
The value(s) may or may not be transformed to Euclidean space.
822822
"""
823823
getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi))
824+
getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi))
824825
getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl)))
825826
function getindex(vi::TypedVarInfo, spl::Sampler)
826827
# Gets the ranges as a NamedTuple

test/varinfo.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DynamicPPL: Selector, reconstruct, invlink, CACHERESET,
66
set_flag!, unset_flag!, VarInfo, TypedVarInfo,
77
getlogp, setlogp!, resetlogp!, acclogp!, vectorize,
88
setorder!, updategid!
9-
using DynamicPPL
9+
using DynamicPPL, LinearAlgebra
1010
using Distributions
1111
using ForwardDiff: Dual
1212
using Test
@@ -167,32 +167,32 @@ include(dir*"/test/test_utils/AllUtils.jl")
167167
meta = vi.metadata
168168
model(vi, SampleFromUniform())
169169

170-
@test all(x -> ~istrans(vi, x), meta.vns)
170+
@test all(x -> istrans(vi, x), meta.vns)
171171
alg = HMC(0.1, 5)
172172
spl = Sampler(alg, model)
173173
v = copy(meta.vals)
174-
link!(vi, spl)
175-
@test all(x -> istrans(vi, x), meta.vns)
176174
invlink!(vi, spl)
177175
@test all(x -> ~istrans(vi, x), meta.vns)
178-
@test meta.vals == v
176+
link!(vi, spl)
177+
@test all(x -> istrans(vi, x), meta.vns)
178+
@test norm(meta.vals - v) <= 1e-6
179179

180180
vi = TypedVarInfo(vi)
181181
meta = vi.metadata
182182
alg = HMC(0.1, 5)
183183
spl = Sampler(alg, model)
184-
@test all(x -> ~istrans(vi, x), meta.s.vns)
185-
@test all(x -> ~istrans(vi, x), meta.m.vns)
186-
v_s = copy(meta.s.vals)
187-
v_m = copy(meta.m.vals)
188-
link!(vi, spl)
189184
@test all(x -> istrans(vi, x), meta.s.vns)
190185
@test all(x -> istrans(vi, x), meta.m.vns)
186+
v_s = copy(meta.s.vals)
187+
v_m = copy(meta.m.vals)
191188
invlink!(vi, spl)
192189
@test all(x -> ~istrans(vi, x), meta.s.vns)
193190
@test all(x -> ~istrans(vi, x), meta.m.vns)
194-
@test meta.s.vals == v_s
195-
@test meta.m.vals == v_m
191+
link!(vi, spl)
192+
@test all(x -> istrans(vi, x), meta.s.vns)
193+
@test all(x -> istrans(vi, x), meta.m.vns)
194+
@test norm(meta.s.vals - v_s) <= 1e-6
195+
@test norm(meta.m.vals - v_m) <= 1e-6
196196
end
197197
@testset "setgid!" begin
198198
vi = VarInfo()

0 commit comments

Comments
 (0)