@@ -103,24 +103,48 @@ function observe(spl::Sampler, weight)
103
103
error (" DynamicPPL.observe: unmanaged inference algorithm: $(typeof (spl)) " )
104
104
end
105
105
106
+ # If parameters exist, they are used and not overwritten.
106
107
function assume (
107
- spl:: Union{ SampleFromPrior, SampleFromUniform} ,
108
+ spl:: SampleFromPrior ,
108
109
dist:: Distribution ,
109
110
vn:: VarName ,
110
111
vi:: VarInfo ,
111
112
)
112
113
if haskey (vi, vn)
113
114
if is_flagged (vi, vn, " del" )
114
115
unset_flag! (vi, vn, " del" )
115
- r = spl isa SampleFromUniform ? init (dist) : rand (dist)
116
+ r = rand (dist)
116
117
vi[vn] = vectorize (dist, r)
118
+ settrans! (vi, false , vn)
117
119
setorder! (vi, vn, get_num_produce (vi))
118
120
else
119
- r = vi[vn]
121
+ r = vi[vn]
120
122
end
121
123
else
122
- r = isa (spl, SampleFromUniform) ? init (dist) : rand (dist)
124
+ r = rand (dist)
125
+ push! (vi, vn, r, dist, spl)
126
+ settrans! (vi, false , vn)
127
+ end
128
+ return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn))
129
+ end
130
+
131
+ # Always overwrites the parameters with new ones.
132
+ function assume (
133
+ spl:: SampleFromUniform ,
134
+ dist:: Distribution ,
135
+ vn:: VarName ,
136
+ vi:: VarInfo ,
137
+ )
138
+ if haskey (vi, vn)
139
+ unset_flag! (vi, vn, " del" )
140
+ r = init (dist)
141
+ vi[vn] = vectorize (dist, r)
142
+ settrans! (vi, true , vn)
143
+ setorder! (vi, vn, get_num_produce (vi))
144
+ else
145
+ r = init (dist)
123
146
push! (vi, vn, r, dist, spl)
147
+ settrans! (vi, true , vn)
124
148
end
125
149
# NOTE: The importance weight is not correctly computed here because
126
150
# r is genereated from some uniform distribution which is different from the prior
0 commit comments