|
39 | 39 | x = rand(10_000)
|
40 | 40 |
|
41 | 41 | @model function wthreads(x)
|
| 42 | + global vi_ = _varinfo |
42 | 43 | x[1] ~ Normal(0, 1)
|
43 | 44 | Threads.@threads for i in 2:length(x)
|
44 | 45 | x[i] ~ Normal(x[i-1], 1)
|
|
48 | 49 | vi = VarInfo()
|
49 | 50 | wthreads(x)(vi)
|
50 | 51 | lp_w_threads = getlogp(vi)
|
| 52 | + if Threads.nthreads() == 1 |
| 53 | + @test vi_ isa VarInfo |
| 54 | + else |
| 55 | + @test vi_ isa DynamicPPL.ThreadSafeVarInfo |
| 56 | + end |
51 | 57 |
|
52 | 58 | println("With `@threads`:")
|
53 | 59 | println(" default:")
|
54 | 60 | @time wthreads(x)(vi)
|
55 | 61 |
|
56 |
| - # Ensure that we use `ThreadSafeVarInfo`. |
| 62 | + # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded assume statements. |
| 63 | + DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi, |
| 64 | + SampleFromPrior(), DefaultContext()) |
57 | 65 | @test getlogp(vi) ≈ lp_w_threads
|
58 |
| - DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi, |
59 |
| - SampleFromPrior(), DefaultContext()) |
| 66 | + @test vi_ isa DynamicPPL.ThreadSafeVarInfo |
60 | 67 |
|
61 |
| - println(" evaluate_multithreaded:") |
62 |
| - @time DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi, |
63 |
| - SampleFromPrior(), DefaultContext()) |
| 68 | + println(" evaluate_threadsafe:") |
| 69 | + @time DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi, |
| 70 | + SampleFromPrior(), DefaultContext()) |
64 | 71 |
|
65 | 72 | @model function wothreads(x)
|
| 73 | + global vi_ = _varinfo |
66 | 74 | x[1] ~ Normal(0, 1)
|
67 | 75 | for i in 2:length(x)
|
68 | 76 | x[i] ~ Normal(x[i-1], 1)
|
|
72 | 80 | vi = VarInfo()
|
73 | 81 | wothreads(x)(vi)
|
74 | 82 | lp_wo_threads = getlogp(vi)
|
| 83 | + if Threads.nthreads() == 1 |
| 84 | + @test vi_ isa VarInfo |
| 85 | + else |
| 86 | + @test vi_ isa DynamicPPL.ThreadSafeVarInfo |
| 87 | + end |
75 | 88 |
|
76 | 89 | println("Without `@threads`:")
|
77 | 90 | println(" default:")
|
|
80 | 93 | @test lp_w_threads ≈ lp_wo_threads
|
81 | 94 |
|
82 | 95 | # Ensure that we use `VarInfo`.
|
83 |
| - DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi, |
84 |
| - SampleFromPrior(), DefaultContext()) |
| 96 | + DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi, |
| 97 | + SampleFromPrior(), DefaultContext()) |
85 | 98 | @test getlogp(vi) ≈ lp_w_threads
|
| 99 | + @test vi_ isa VarInfo |
86 | 100 |
|
87 |
| - println(" evaluate_singlethreaded:") |
88 |
| - @time DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi, |
89 |
| - SampleFromPrior(), DefaultContext()) |
| 101 | + println(" evaluate_threadunsafe:") |
| 102 | + @time DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi, |
| 103 | + SampleFromPrior(), DefaultContext()) |
90 | 104 | end
|
91 | 105 | end
|
0 commit comments