1
1
using AbstractMCMC
2
- using AbstractMCMC: sample, psample, steps!
2
+ using AbstractMCMC: steps!
3
3
using Atom. Progress: JunoProgressLogger
4
4
using ConsoleProgressMonitor: ProgressLogger
5
5
using IJulia
6
6
using LoggingExtras: TeeLogger, EarlyFilteredLogger
7
7
using TerminalLoggers: TerminalLogger
8
8
9
+ using Distributed
9
10
import Logging
10
11
using Random
11
12
using Statistics
@@ -99,14 +100,26 @@ include("interface.jl")
99
100
@test first (LOGGERS) === logger
100
101
@test Logging. current_logger () === CURRENT_LOGGER
101
102
end
103
+
104
+ @testset " Suppress output" begin
105
+ logs, _ = collect_test_logs (; min_level= Logging. LogLevel (- 1 )) do
106
+ sample (MyModel (), MySampler (), 100 ; progress = false , sleepy = true )
107
+ end
108
+ @test all (l. level > Logging. LogLevel (- 1 ) for l in logs)
109
+ end
102
110
end
103
111
104
112
if VERSION ≥ v " 1.3"
105
- @testset " Parallel sampling" begin
106
- println (" testing parallel sampling with " , Threads. nthreads (), " thread(s)..." )
113
+ @testset " Multithreaded sampling" begin
114
+ if Threads. nthreads () == 1
115
+ warnregex = r" ^Only a single thread available"
116
+ @test_logs (:warn , warnregex) sample (MyModel (), MySampler (), MCMCThreads (),
117
+ 10 , 10 ; chain_type = MyChain)
118
+ end
107
119
108
120
Random. seed! (1234 )
109
- chains = psample (MyModel (), MySampler (), 10_000 , 1000 ; chain_type = MyChain)
121
+ chains = sample (MyModel (), MySampler (), MCMCThreads (), 10_000 , 1000 ;
122
+ chain_type = MyChain)
110
123
111
124
# test output type and size
112
125
@test chains isa Vector{MyChain}
@@ -121,12 +134,69 @@ include("interface.jl")
121
134
122
135
# test reproducibility
123
136
Random. seed! (1234 )
124
- chains2 = psample (MyModel (), MySampler (), 10_000 , 1000 ; chain_type = MyChain)
137
+ chains2 = sample (MyModel (), MySampler (), MCMCThreads (), 10_000 , 1000 ;
138
+ chain_type = MyChain)
125
139
126
140
@test all (((x, y),) -> x. as == y. as && x. bs == y. bs, zip (chains, chains2))
141
+
142
+ # Suppress output.
143
+ logs, _ = collect_test_logs (; min_level= Logging. LogLevel (- 1 )) do
144
+ sample (MyModel (), MySampler (), MCMCThreads (), 10_000 , 1000 ;
145
+ progress = false , chain_type = MyChain)
146
+ end
147
+ @test all (l. level > Logging. LogLevel (- 1 ) for l in logs)
127
148
end
128
149
end
129
150
151
+ @testset " Multicore sampling" begin
152
+ if nworkers () == 1
153
+ warnregex = r" ^Only a single process available"
154
+ @test_logs (:warn , warnregex) sample (MyModel (), MySampler (), MCMCDistributed (),
155
+ 10 , 10 ; chain_type = MyChain)
156
+ end
157
+
158
+ # Add worker processes.
159
+ addprocs ()
160
+
161
+ # Load all required packages (`interface.jl` needs Random).
162
+ @everywhere begin
163
+ using AbstractMCMC
164
+ using AbstractMCMC: sample
165
+
166
+ using Random
167
+ include (" interface.jl" )
168
+ end
169
+
170
+ Random. seed! (1234 )
171
+ chains = sample (MyModel (), MySampler (), MCMCDistributed (), 10_000 , 1000 ;
172
+ chain_type = MyChain)
173
+
174
+ # Test output type and size.
175
+ @test chains isa Vector{MyChain}
176
+ @test length (chains) == 1000
177
+ @test all (x -> length (x. as) == length (x. bs) == 10_000 , chains)
178
+
179
+ # Test some statistical properties.
180
+ @test all (x -> isapprox (mean (x. as), 0.5 ; atol= 1e-2 ), chains)
181
+ @test all (x -> isapprox (var (x. as), 1 / 12 ; atol= 5e-3 ), chains)
182
+ @test all (x -> isapprox (mean (x. bs), 0 ; atol= 5e-2 ), chains)
183
+ @test all (x -> isapprox (var (x. bs), 1 ; atol= 5e-2 ), chains)
184
+
185
+ # Test reproducibility.
186
+ Random. seed! (1234 )
187
+ chains2 = sample (MyModel (), MySampler (), MCMCDistributed (), 10_000 , 1000 ;
188
+ chain_type = MyChain)
189
+
190
+ @test all (((x, y),) -> x. as == y. as && x. bs == y. bs, zip (chains, chains2))
191
+
192
+ # Suppress output.
193
+ logs, _ = collect_test_logs (; min_level= Logging. LogLevel (- 1 )) do
194
+ sample (MyModel (), MySampler (), MCMCDistributed (), 10_000 , 100 ;
195
+ progress = false , chain_type = MyChain)
196
+ end
197
+ @test all (l. level > Logging. LogLevel (- 1 ) for l in logs)
198
+ end
199
+
130
200
@testset " Chain constructors" begin
131
201
chain1 = sample (MyModel (), MySampler (), 100 ; sleepy = true )
132
202
chain2 = sample (MyModel (), MySampler (), 100 ; sleepy = true , chain_type = MyChain)
@@ -135,21 +205,6 @@ include("interface.jl")
135
205
@test chain2 isa MyChain
136
206
end
137
207
138
- @testset " Suppress output" begin
139
- logs, _ = collect_test_logs (; min_level= Logging. LogLevel (- 1 )) do
140
- sample (MyModel (), MySampler (), 100 ; progress = false , sleepy = true )
141
- end
142
- @test isempty (logs)
143
-
144
- if VERSION ≥ v " 1.3"
145
- logs, _ = collect_test_logs (; min_level= Logging. LogLevel (- 1 )) do
146
- psample (MyModel (), MySampler (), 10_000 , 1000 ;
147
- progress = false , chain_type = MyChain)
148
- end
149
- @test isempty (logs)
150
- end
151
- end
152
-
153
208
@testset " Iterator sampling" begin
154
209
Random. seed! (1234 )
155
210
as = []
@@ -182,4 +237,15 @@ include("interface.jl")
182
237
bmean = mean (x. b for x in chain)
183
238
@test abs (bmean) <= 0.001 && length (chain) < 10_000
184
239
end
240
+
241
+ @testset " Deprecations" begin
242
+ @test_deprecated AbstractMCMC. psample (MyModel (), MySampler (), 10 , 10 ;
243
+ chain_type = MyChain)
244
+ @test_deprecated AbstractMCMC. psample (Random. GLOBAL_RNG, MyModel (), MySampler (),
245
+ 10 , 10 ;
246
+ chain_type = MyChain)
247
+ @test_deprecated AbstractMCMC. mcmcpsample (Random. GLOBAL_RNG, MyModel (),
248
+ MySampler (), 10 , 10 ;
249
+ chain_type = MyChain)
250
+ end
185
251
end
0 commit comments