Skip to content

Commit 7c64297

Browse files
committed
Ensure that both evaluate_singlethreaded and evaluate_multithreaded are tested
1 parent 462c32f commit 7c64297

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

test/compiler.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ end
581581
@test all(iszero(model()) for _ in 1:1000)
582582
end
583583
@testset "threading" begin
584-
@info "Peforming threading tests with $(Threads.nthreads()) threads"
584+
println("Peforming threading tests with $(Threads.nthreads()) threads")
585585

586586
x = rand(10_000)
587587

@@ -596,9 +596,19 @@ end
596596
wthreads(x)(vi)
597597
lp_w_threads = getlogp(vi)
598598

599-
println("With threading:")
599+
println("With `@threads`:")
600+
println(" default:")
600601
@time wthreads(x)(vi)
601602

603+
# Ensure that we use `ThreadSafeVarInfo`.
604+
@test getlogp(vi) lp_w_threads
605+
DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
606+
DefaultContext())
607+
608+
println(" evaluate_multithreaded:")
609+
@time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
610+
DefaultContext())
611+
602612
@model function wothreads(x)
603613
x[1] ~ Normal(0, 1)
604614
for i in 2:length(x)
@@ -610,9 +620,19 @@ end
610620
wothreads(x)(vi)
611621
lp_wo_threads = getlogp(vi)
612622

613-
println("Without threading:")
623+
println("Without `@threads`:")
624+
println(" default:")
614625
@time wothreads(x)(vi)
615626

616627
@test lp_w_threads lp_wo_threads
628+
629+
# Ensure that we use `VarInfo`.
630+
DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
631+
DefaultContext())
632+
@test getlogp(vi) lp_w_threads
633+
634+
println(" evaluate_singlethreaded:")
635+
@time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
636+
DefaultContext())
617637
end
618638
end

0 commit comments

Comments
 (0)