|
162 | 162 | end
|
163 | 163 |
|
164 | 164 | # initial parameters
|
165 |
| - init_params = [(b=randn(), a=rand()) for _ in 1:100] |
| 165 | + nchains = 100 |
| 166 | + init_params = [(b=randn(), a=rand()) for _ in 1:nchains] |
166 | 167 | chains = sample(
|
167 | 168 | MyModel(),
|
168 | 169 | MySampler(),
|
169 | 170 | MCMCThreads(),
|
170 | 171 | 3,
|
171 |
| - 100; |
| 172 | + nchains; |
172 | 173 | progress=false,
|
173 | 174 | init_params=init_params,
|
174 | 175 | )
|
175 |
| - @test length(chains) == 100 |
| 176 | + @test length(chains) == nchains |
176 | 177 | @test all(
|
177 | 178 | chain[1].a == params.a && chain[1].b == params.b for
|
178 | 179 | (chain, params) in zip(chains, init_params)
|
|
184 | 185 | MySampler(),
|
185 | 186 | MCMCThreads(),
|
186 | 187 | 3,
|
187 |
| - 100; |
| 188 | + nchains; |
188 | 189 | progress=false,
|
189 |
| - init_params=Iterators.repeated(init_params), |
| 190 | + init_params=FillArrays.Fill(init_params, nchains), |
190 | 191 | )
|
191 |
| - @test length(chains) == 100 |
| 192 | + @test length(chains) == nchains |
192 | 193 | @test all(
|
193 | 194 | chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
|
194 | 195 | )
|
| 196 | + |
| 197 | + # Too many `init_params` |
| 198 | + @test_throws ArgumentError sample( |
| 199 | + MyModel(), |
| 200 | + MySampler(), |
| 201 | + MCMCThreads(), |
| 202 | + 3, |
| 203 | + nchains; |
| 204 | + progress=false, |
| 205 | + init_params=FillArrays.Fill(init_params, nchains + 1), |
| 206 | + ) |
| 207 | + |
| 208 | + # Too few `init_params` |
| 209 | + @test_throws ArgumentError sample( |
| 210 | + MyModel(), |
| 211 | + MySampler(), |
| 212 | + MCMCThreads(), |
| 213 | + 3, |
| 214 | + nchains; |
| 215 | + progress=false, |
| 216 | + init_params=FillArrays.Fill(init_params, nchains - 1), |
| 217 | + ) |
195 | 218 | end
|
196 | 219 |
|
197 | 220 | @testset "Multicore sampling" begin
|
|
274 | 297 | @test all(l.level > Logging.LogLevel(-1) for l in logs)
|
275 | 298 |
|
276 | 299 | # initial parameters
|
277 |
| - init_params = [(a=randn(), b=rand()) for _ in 1:100] |
| 300 | + nchains = 100 |
| 301 | + init_params = [(a=randn(), b=rand()) for _ in 1:nchains] |
278 | 302 | chains = sample(
|
279 | 303 | MyModel(),
|
280 | 304 | MySampler(),
|
281 | 305 | MCMCDistributed(),
|
282 | 306 | 3,
|
283 |
| - 100; |
| 307 | + nchains; |
284 | 308 | progress=false,
|
285 | 309 | init_params=init_params,
|
286 | 310 | )
|
287 |
| - @test length(chains) == 100 |
| 311 | + @test length(chains) == nchains |
288 | 312 | @test all(
|
289 | 313 | chain[1].a == params.a && chain[1].b == params.b for
|
290 | 314 | (chain, params) in zip(chains, init_params)
|
|
296 | 320 | MySampler(),
|
297 | 321 | MCMCDistributed(),
|
298 | 322 | 3,
|
299 |
| - 100; |
| 323 | + nchains; |
300 | 324 | progress=false,
|
301 |
| - init_params=Iterators.repeated(init_params), |
| 325 | + init_params=FillArrays.Fill(init_params, nchains), |
302 | 326 | )
|
303 |
| - @test length(chains) == 100 |
| 327 | + @test length(chains) == nchains |
304 | 328 | @test all(
|
305 | 329 | chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
|
306 | 330 | )
|
307 | 331 |
|
| 332 | + # Too many `init_params` |
| 333 | + @test_throws ArgumentError sample( |
| 334 | + MyModel(), |
| 335 | + MySampler(), |
| 336 | + MCMCDistributed(), |
| 337 | + 3, |
| 338 | + nchains; |
| 339 | + progress=false, |
| 340 | + init_params=FillArrays.Fill(init_params, nchains + 1), |
| 341 | + ) |
| 342 | + |
| 343 | + # Too few `init_params` |
| 344 | + @test_throws ArgumentError sample( |
| 345 | + MyModel(), |
| 346 | + MySampler(), |
| 347 | + MCMCDistributed(), |
| 348 | + 3, |
| 349 | + nchains; |
| 350 | + progress=false, |
| 351 | + init_params=FillArrays.Fill(init_params, nchains - 1), |
| 352 | + ) |
| 353 | + |
308 | 354 | # Remove workers
|
309 | 355 | rmprocs(pids...)
|
310 | 356 | end
|
|
360 | 406 | @test all(l.level > Logging.LogLevel(-1) for l in logs)
|
361 | 407 |
|
362 | 408 | # initial parameters
|
363 |
| - init_params = [(a=rand(), b=randn()) for _ in 1:100] |
| 409 | + nchains = 100 |
| 410 | + init_params = [(a=rand(), b=randn()) for _ in 1:nchains] |
364 | 411 | chains = sample(
|
365 | 412 | MyModel(),
|
366 | 413 | MySampler(),
|
367 | 414 | MCMCSerial(),
|
368 | 415 | 3,
|
369 |
| - 100; |
| 416 | + nchains; |
370 | 417 | progress=false,
|
371 | 418 | init_params=init_params,
|
372 | 419 | )
|
373 |
| - @test length(chains) == 100 |
| 420 | + @test length(chains) == nchains |
374 | 421 | @test all(
|
375 | 422 | chain[1].a == params.a && chain[1].b == params.b for
|
376 | 423 | (chain, params) in zip(chains, init_params)
|
|
382 | 429 | MySampler(),
|
383 | 430 | MCMCSerial(),
|
384 | 431 | 3,
|
385 |
| - 100; |
| 432 | + nchains; |
386 | 433 | progress=false,
|
387 |
| - init_params=Iterators.repeated(init_params), |
| 434 | + init_params=FillArrays.Fill(init_params, nchains), |
388 | 435 | )
|
389 |
| - @test length(chains) == 100 |
| 436 | + @test length(chains) == nchains |
390 | 437 | @test all(
|
391 | 438 | chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
|
392 | 439 | )
|
| 440 | + |
| 441 | + # Too many `init_params` |
| 442 | + @test_throws ArgumentError sample( |
| 443 | + MyModel(), |
| 444 | + MySampler(), |
| 445 | + MCMCSerial(), |
| 446 | + 3, |
| 447 | + nchains; |
| 448 | + progress=false, |
| 449 | + init_params=FillArrays.Fill(init_params, nchains + 1), |
| 450 | + ) |
| 451 | + |
| 452 | + # Too few `init_params` |
| 453 | + @test_throws ArgumentError sample( |
| 454 | + MyModel(), |
| 455 | + MySampler(), |
| 456 | + MCMCSerial(), |
| 457 | + 3, |
| 458 | + nchains; |
| 459 | + progress=false, |
| 460 | + init_params=FillArrays.Fill(init_params, nchains - 1), |
| 461 | + ) |
393 | 462 | end
|
394 | 463 |
|
395 | 464 | @testset "Ensemble sampling: Reproducibility" begin
|
|
564 | 633 | @test ismissing(chain[1].a)
|
565 | 634 | @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol = 6e-2
|
566 | 635 | @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol = 1e-2
|
567 |
| - @test mean(x.b for x in chain) ≈ 0 atol = 0.1 |
| 636 | + @test mean(x.b for x in chain) ≈ 0 atol = 0.11 |
568 | 637 | @test var(x.b for x in chain) ≈ 1 atol = 0.15
|
569 | 638 | end
|
570 | 639 |
|
|
0 commit comments