Skip to content

Commit 8518a58

Browse files
Remove internal call
1 parent bf11785 commit 8518a58

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

ext/AdvancedPSLibtaskExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787
# PG requires keeping all randomness for the reference particle
8888
# Create new task and copy randomness
8989
function AdvancedPS.forkr(trace::LibtaskTrace)
90-
newf = AdvancedPS.reset_model(trace.model.ctask.fargs[1])
90+
newf = AdvancedPS.reset_model(trace.model.f)
9191
Random123.set_counter!(trace.rng, 1)
9292

9393
ctask = Libtask.TapedTask(trace.rng, newf)
@@ -109,7 +109,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
109109
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
110110
"""
111111
function AdvancedPS.observe(dist::Distributions.Distribution, x)
112-
return Libtask.produce(Distributions.loglikelihood(dist, x))
112+
Libtask.produce(Distributions.loglikelihood(dist, x))
113+
return nothing
113114
end
114115

115116
"""
@@ -149,7 +150,7 @@ function AbstractMCMC.step(
149150

150151
replayed = AdvancedPS.replay(newtrajectory)
151152
return AdvancedPS.PGSample(replayed.model.f, logevidence),
152-
AdvancedPS.PGState(newtrajectory)
153+
AdvancedPS.PGState(replayed)
153154
end
154155

155156
function AbstractMCMC.sample(

src/container.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,12 @@ function resample_propagate!(
206206

207207
Random.seed!(p.rng, seeds[1])
208208

209-
children[j += 1] = p
209+
children[j+=1] = p
210210
# fork additional children
211211
for k in 2:ni
212212
part = fork(p, isref)
213213
Random.seed!(part.rng, seeds[k])
214-
children[j += 1] = part
214+
children[j+=1] = part
215215
end
216216
end
217217
end

test/smc.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
sampler = AdvancedPS.SMC(15, 0.6)
88
@test sampler.nparticles == 15
99
@test sampler.resampler ===
10-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
10+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
1111

1212
sampler = AdvancedPS.SMC(20, AdvancedPS.resample_multinomial, 0.6)
1313
@test sampler.nparticles == 20
1414
@test sampler.resampler ===
15-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
15+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
1616

1717
sampler = AdvancedPS.SMC(25, AdvancedPS.resample_systematic)
1818
@test sampler.nparticles == 25
@@ -98,7 +98,7 @@
9898
return AdvancedPS.observe(Bernoulli(x / 2), 0)
9999
end
100100

101-
chains_smc = sample(TestModel(), AdvancedPS.SMC(100))
101+
chains_smc = sample(TestModel(), AdvancedPS.SMC(100); progress=false)
102102

103103
@test all(isone(particle.x) for particle in chains_smc.trajectories)
104104
@test chains_smc.logevidence -2 * log(2)
@@ -112,12 +112,12 @@
112112
sampler = AdvancedPS.PG(60, 0.6)
113113
@test sampler.nparticles == 60
114114
@test sampler.resampler ===
115-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
115+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
116116

117117
sampler = AdvancedPS.PG(80, AdvancedPS.resample_multinomial, 0.6)
118118
@test sampler.nparticles == 80
119119
@test sampler.resampler ===
120-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
120+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
121121

122122
sampler = AdvancedPS.PG(100, AdvancedPS.resample_systematic)
123123
@test sampler.nparticles == 100
@@ -152,7 +152,7 @@
152152
return AdvancedPS.observe(Bernoulli(x / 2), 0)
153153
end
154154

155-
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100)
155+
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100; progress=false)
156156

157157
@test all(isone(p.trajectory.x) for p in chains_pg)
158158
@test mean(x.logevidence for x in chains_pg) -2 * log(2) atol = 0.01
@@ -177,7 +177,7 @@
177177
end
178178

179179
pg = AdvancedPS.PG(1)
180-
first, second = sample(DummyModel(), pg, 2)
180+
first, second = sample(DummyModel(), pg, 2; progress=false)
181181

182182
first_model = first.trajectory
183183
second_model = second.trajectory

0 commit comments

Comments
 (0)