|
581 | 581 | @test all(iszero(model()) for _ in 1:1000)
|
582 | 582 | end
|
583 | 583 | @testset "threading" begin
|
584 |
| - @info "Peforming threading tests with $(Threads.nthreads()) threads" |
| 584 | + println("Peforming threading tests with $(Threads.nthreads()) threads") |
585 | 585 |
|
586 | 586 | x = rand(10_000)
|
587 | 587 |
|
|
596 | 596 | wthreads(x)(vi)
|
597 | 597 | lp_w_threads = getlogp(vi)
|
598 | 598 |
|
599 |
| - println("With threading:") |
| 599 | + println("With `@threads`:") |
| 600 | + println(" default:") |
600 | 601 | @time wthreads(x)(vi)
|
601 | 602 |
|
| 603 | + # Ensure that we use `ThreadSafeVarInfo`. |
| 604 | + @test getlogp(vi) ≈ lp_w_threads |
| 605 | + DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), |
| 606 | + DefaultContext()) |
| 607 | + |
| 608 | + println(" evaluate_multithreaded:") |
| 609 | + @time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), |
| 610 | + DefaultContext()) |
| 611 | + |
602 | 612 | @model function wothreads(x)
|
603 | 613 | x[1] ~ Normal(0, 1)
|
604 | 614 | for i in 2:length(x)
|
|
610 | 620 | wothreads(x)(vi)
|
611 | 621 | lp_wo_threads = getlogp(vi)
|
612 | 622 |
|
613 |
| - println("Without threading:") |
| 623 | + println("Without `@threads`:") |
| 624 | + println(" default:") |
614 | 625 | @time wothreads(x)(vi)
|
615 | 626 |
|
616 | 627 | @test lp_w_threads ≈ lp_wo_threads
|
| 628 | + |
| 629 | + # Ensure that we use `VarInfo`. |
| 630 | + DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), |
| 631 | + DefaultContext()) |
| 632 | + @test getlogp(vi) ≈ lp_w_threads |
| 633 | + |
| 634 | + println(" evaluate_singlethreaded:") |
| 635 | + @time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), |
| 636 | + DefaultContext()) |
617 | 637 | end
|
618 | 638 | end
|
0 commit comments