Skip to content

Commit ddbb14c

Browse files
authored
Merge pull request #157 from TuringLang/pmap
Try to avoid `pmap` test issues
2 parents 18ba8db + c67a11f commit ddbb14c

File tree

1 file changed

+37
-32
lines changed

1 file changed

+37
-32
lines changed

test/serialization.jl

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,43 @@
1717
@test mean(samples_m) 0 atol=0.1
1818
end
1919

20-
@testset "pmap" begin
21-
# Add worker processes.
22-
addprocs()
23-
@info "serialization test: using $(nworkers()) processes"
24-
25-
# Load packages on all processes.
26-
@everywhere begin
27-
using DynamicPPL
28-
using Distributions
29-
end
30-
31-
# Define model on all proceses.
32-
@everywhere @model function model()
33-
m ~ Normal(0, 1)
34-
end
35-
36-
# Generate `Model` objects on all processes.
37-
models = pmap(_ -> model(), 1:100)
38-
@test models isa Vector{<:Model}
39-
@test length(models) == 100
40-
41-
# Sample from model on all processes.
42-
n = 1_000
43-
samples1 = pmap(_ -> model()(), 1:n)
44-
m = model()
45-
samples2 = pmap(_ -> m(), 1:n)
46-
47-
for samples in (samples1, samples2)
48-
@test samples isa Vector{Float64}
49-
@test length(samples) == n
50-
@test mean(samples) 0 atol=0.1
51-
@test std(samples) 1 atol=0.1
20+
# Does not work reliably on Travis
21+
if haskey(ENV, "TRAVIS")
22+
@info "Skip `pmap` serialization test"
23+
else
24+
@testset "pmap" begin
25+
# Add worker processes.
26+
addprocs()
27+
@info "serialization test: using $(nworkers()) processes"
28+
29+
# Load packages on all processes.
30+
@everywhere begin
31+
using DynamicPPL
32+
using Distributions
33+
end
34+
35+
# Define model on all proceses.
36+
@everywhere @model function model()
37+
m ~ Normal(0, 1)
38+
end
39+
40+
# Generate `Model` objects on all processes.
41+
models = pmap(_ -> model(), 1:100)
42+
@test models isa Vector{<:Model}
43+
@test length(models) == 100
44+
45+
# Sample from model on all processes.
46+
n = 1_000
47+
samples1 = pmap(_ -> model()(), 1:n)
48+
m = model()
49+
samples2 = pmap(_ -> m(), 1:n)
50+
51+
for samples in (samples1, samples2)
52+
@test samples isa Vector{Float64}
53+
@test length(samples) == n
54+
@test mean(samples) 0 atol=0.1
55+
@test std(samples) 1 atol=0.1
56+
end
5257
end
5358
end
5459
end

0 commit comments

Comments
 (0)