Skip to content

Commit e0840a9

Browse files
committed
fix SampleFromUniform
1 parent b7159cb commit e0840a9

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

src/context_implementations.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,48 @@ function observe(spl::Sampler, weight)
103103
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
104104
end
105105

106+
# If parameters exist, they are used and not overwritten.
106107
function assume(
107-
spl::Union{SampleFromPrior, SampleFromUniform},
108+
spl::SampleFromPrior,
108109
dist::Distribution,
109110
vn::VarName,
110111
vi::VarInfo,
111112
)
112113
if haskey(vi, vn)
113114
if is_flagged(vi, vn, "del")
114115
unset_flag!(vi, vn, "del")
115-
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
116+
r = rand(dist)
116117
vi[vn] = vectorize(dist, r)
118+
settrans!(vi, false, vn)
117119
setorder!(vi, vn, get_num_produce(vi))
118120
else
119-
r = vi[vn]
121+
r = vi[vn]
120122
end
121123
else
122-
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
124+
r = rand(dist)
125+
push!(vi, vn, r, dist, spl)
126+
settrans!(vi, false, vn)
127+
end
128+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
129+
end
130+
131+
# Always overwrites the parameters with new ones.
132+
function assume(
133+
spl::SampleFromUniform,
134+
dist::Distribution,
135+
vn::VarName,
136+
vi::VarInfo,
137+
)
138+
if haskey(vi, vn)
139+
unset_flag!(vi, vn, "del")
140+
r = init(dist)
141+
vi[vn] = vectorize(dist, r)
142+
settrans!(vi, true, vn)
143+
setorder!(vi, vn, get_num_produce(vi))
144+
else
145+
r = init(dist)
123146
push!(vi, vn, r, dist, spl)
147+
settrans!(vi, true, vn)
124148
end
125149
# NOTE: The importance weight is not correctly computed here because
126150
# r is genereated from some uniform distribution which is different from the prior

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ end
117117

118118
# ROBUST INITIALISATIONS
119119
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
120-
randrealuni() = Real(2rand())
121-
randrealuni(args...) = map(Real, 2rand(args...))
120+
randrealuni() = Real(4rand()-2)
121+
randrealuni(args...) = map(Real, 4rand(args...)-2)
122122

123123
const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}
124124

0 commit comments

Comments
 (0)