Skip to content

Commit 610e676

Browse files
committed
Fix tests
1 parent e21ee63 commit 610e676

File tree

6 files changed

+31
-24
lines changed

6 files changed

+31
-24
lines changed

test/Turing/inference/AdvancedSMC.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ function AbstractMCMC.sample_end!(
308308
end
309309

310310
function DynamicPPL.assume(
311+
rng,
311312
spl::Sampler{<:Union{PG,SMC}},
312313
dist::Distribution,
313314
vn::VarName,
@@ -316,11 +317,11 @@ function DynamicPPL.assume(
316317
vi = current_trace().vi
317318
if inspace(vn, spl)
318319
if ~haskey(vi, vn)
319-
r = rand(dist)
320+
r = rand(rng, dist)
320321
push!(vi, vn, r, dist, spl)
321322
elseif is_flagged(vi, vn, "del")
322323
unset_flag!(vi, vn, "del")
323-
r = rand(dist)
324+
r = rand(rng, dist)
324325
vi[vn] = vectorize(dist, r)
325326
setgid!(vi, spl.selector, vn)
326327
setorder!(vi, vn, get_num_produce(vi))
@@ -332,7 +333,7 @@ function DynamicPPL.assume(
332333
if haskey(vi, vn)
333334
r = vi[vn]
334335
else
335-
r = rand(dist)
336+
r = rand(rng, dist)
336337
push!(vi, vn, r, dist, Selector(:invalid))
337338
end
338339
lp = logpdf_with_trans(dist, r, istrans(vi, vn))

test/Turing/inference/ess.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ES
117117
vi = spl.state.vi
118118
vns = _getvns(vi, spl)
119119
set_flag!(vi, vns[1][1], "del")
120-
model.model(vi, spl)
120+
model.model(rng, vi, spl)
121121
return vi[spl]
122122
end
123123

@@ -144,26 +144,26 @@ function Distributions.loglikelihood(model::ESSModel, f)
144144
getlogp(vi)
145145
end
146146

147-
function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
147+
function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
148148
if inspace(vn, sampler)
149-
return DynamicPPL.tilde(LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
149+
return DynamicPPL.tilde(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
150150
else
151-
return DynamicPPL.tilde(ctx, SampleFromPrior(), right, vn, inds, vi)
151+
return DynamicPPL.tilde(rng, ctx, SampleFromPrior(), right, vn, inds, vi)
152152
end
153153
end
154154

155155
function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
156156
return DynamicPPL.tilde(ctx, SampleFromPrior(), right, left, vi)
157157
end
158158

159-
function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi)
159+
function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi)
160160
if inspace(vn, sampler)
161-
return DynamicPPL.dot_tilde(LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi)
161+
return DynamicPPL.dot_tilde(rng, LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi)
162162
else
163-
return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vn, inds, vi)
163+
return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vn, inds, vi)
164164
end
165165
end
166166

167-
function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
168-
return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi)
167+
function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
168+
return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi)
169169
end

test/Turing/inference/hmc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ gen_traj(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max
507507
#### Compiler interface, i.e. tilde operators.
508508
####
509509
function DynamicPPL.assume(
510+
rng,
510511
spl::Sampler{<:Hamiltonian},
511512
dist::Distribution,
512513
vn::VarName,
@@ -524,6 +525,7 @@ function DynamicPPL.assume(
524525
end
525526

526527
function DynamicPPL.dot_assume(
528+
rng,
527529
spl::Sampler{<:Hamiltonian},
528530
dist::MultivariateDistribution,
529531
vns::AbstractArray{<:VarName},
@@ -537,6 +539,7 @@ function DynamicPPL.dot_assume(
537539
return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
538540
end
539541
function DynamicPPL.dot_assume(
542+
rng,
540543
spl::Sampler{<:Hamiltonian},
541544
dists::Union{Distribution, AbstractArray{<:Distribution}},
542545
vns::AbstractArray{<:VarName},

test/Turing/inference/is.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ function Sampler(alg::IS, model::Model, s::Selector)
4545
end
4646

4747
function AbstractMCMC.step!(
48-
::AbstractRNG,
48+
rng::AbstractRNG,
4949
model::Model,
5050
spl::Sampler{<:IS},
5151
::Integer,
5252
transition;
5353
kwargs...
5454
)
5555
empty!(spl.state.vi)
56-
model(spl.state.vi, spl)
56+
model(rng, spl.state.vi, spl)
5757

5858
return Transition(spl)
5959
end
@@ -70,8 +70,8 @@ function AbstractMCMC.sample_end!(
7070
spl.state.final_logevidence = logsumexp(map(x->x.lp, ts)) - log(N)
7171
end
7272

73-
function DynamicPPL.assume(spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
74-
r = rand(dist)
73+
function DynamicPPL.assume(rng, spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
74+
r = rand(rng, dist)
7575
push!(vi, vn, r, dist, spl)
7676
return r, 0
7777
end

test/Turing/inference/mh.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ end
223223
#### Compiler interface, i.e. tilde operators.
224224
####
225225
function DynamicPPL.assume(
226+
rng,
226227
spl::Sampler{<:MH},
227228
dist::Distribution,
228229
vn::VarName,
@@ -234,6 +235,7 @@ function DynamicPPL.assume(
234235
end
235236

236237
function DynamicPPL.dot_assume(
238+
rng,
237239
spl::Sampler{<:MH},
238240
dist::MultivariateDistribution,
239241
vn::VarName,
@@ -249,6 +251,7 @@ function DynamicPPL.dot_assume(
249251
return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
250252
end
251253
function DynamicPPL.dot_assume(
254+
rng,
252255
spl::Sampler{<:MH},
253256
dists::Union{Distribution, AbstractArray{<:Distribution}},
254257
vn::VarName,

test/threadsafe.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@
5555

5656
# Ensure that we use `ThreadSafeVarInfo`.
5757
@test getlogp(vi) lp_w_threads
58-
DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
59-
DefaultContext())
58+
DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi,
59+
SampleFromPrior(), DefaultContext())
6060

6161
println(" evaluate_multithreaded:")
62-
@time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
63-
DefaultContext())
62+
@time DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi,
63+
SampleFromPrior(), DefaultContext())
6464

6565
@model function wothreads(x)
6666
x[1] ~ Normal(0, 1)
@@ -80,12 +80,12 @@
8080
@test lp_w_threads lp_wo_threads
8181

8282
# Ensure that we use `VarInfo`.
83-
DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
84-
DefaultContext())
83+
DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi,
84+
SampleFromPrior(), DefaultContext())
8585
@test getlogp(vi) lp_w_threads
8686

8787
println(" evaluate_singlethreaded:")
88-
@time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
89-
DefaultContext())
88+
@time DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi,
89+
SampleFromPrior(), DefaultContext())
9090
end
9191
end

0 commit comments

Comments
 (0)