@@ -312,8 +312,8 @@ function mcmcsample(
312
312
# Create a seed for each chain using the provided random number generator.
313
313
seeds = rand (rng, UInt, nchains)
314
314
315
- # Ensure that initial parameters are `nothing` or indexable
316
- _init_params = _first_or_nothing (init_params, nchains)
315
+ # Ensure that initial parameters are `nothing` or of the correct length
316
+ check_initial_params (init_params, nchains)
317
317
318
318
# Set up a chains vector.
319
319
chains = Vector {Any} (undef, nchains)
@@ -364,10 +364,10 @@ function mcmcsample(
364
364
_sampler,
365
365
N;
366
366
progress= false ,
367
- init_params= if _init_params === nothing
367
+ init_params= if init_params === nothing
368
368
nothing
369
369
else
370
- _init_params [chainidx]
370
+ init_params [chainidx]
371
371
end ,
372
372
kwargs... ,
373
373
)
@@ -410,8 +410,8 @@ function mcmcsample(
410
410
@warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
411
411
end
412
412
413
- # Ensure that initial parameters are `nothing` or indexable
414
- _init_params = _first_or_nothing (init_params, nchains)
413
+ # Ensure that initial parameters are `nothing` or of the correct length
414
+ check_initial_params (init_params, nchains)
415
415
416
416
# Create a seed for each chain using the provided random number generator.
417
417
seeds = rand (rng, UInt, nchains)
@@ -469,10 +469,10 @@ function mcmcsample(
469
469
# Return the new chain.
470
470
return chain
471
471
end
472
- chains = if _init_params === nothing
472
+ chains = if init_params === nothing
473
473
Distributed. pmap (sample_chain, pool, seeds)
474
474
else
475
- Distributed. pmap (sample_chain, pool, seeds, _init_params )
475
+ Distributed. pmap (sample_chain, pool, seeds, init_params )
476
476
end
477
477
finally
478
478
# Stop updating the progress bar.
@@ -502,8 +502,8 @@ function mcmcsample(
502
502
@warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
503
503
end
504
504
505
- # Ensure that initial parameters are `nothing` or indexable
506
- _init_params = _first_or_nothing (init_params, nchains)
505
+ # Ensure that initial parameters are `nothing` or of the correct length
506
+ check_initial_params (init_params, nchains)
507
507
508
508
# Create a seed for each chain using the provided random number generator.
509
509
seeds = rand (rng, UInt, nchains)
@@ -525,10 +525,10 @@ function mcmcsample(
525
525
)
526
526
end
527
527
528
- chains = if _init_params === nothing
528
+ chains = if init_params === nothing
529
529
map (sample_chain, 1 : nchains, seeds)
530
530
else
531
- map (sample_chain, 1 : nchains, seeds, _init_params )
531
+ map (sample_chain, 1 : nchains, seeds, init_params )
532
532
end
533
533
534
534
# Concatenate the chains together.
538
538
tighten_eltype (x) = x
539
539
tighten_eltype (x:: Vector{Any} ) = map (identity, x)
540
540
541
- """
542
- _first_or_nothing(x, n::Int)
543
-
544
- Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
545
-
546
- If `x !== nothing`, then `x` has to contain at least `n` elements.
547
- """
548
- function _first_or_nothing (x, n:: Int )
549
- y = _first (x, n)
550
- length (y) == n || throw (
541
+ check_initial_params (x:: Nothing , n:: Int ) = nothing
542
+ function check_initial_params (x, n:: Int )
543
+ length (x) == n || throw (
551
544
ArgumentError (" not enough initial parameters (expected $n , received $(length (y)) " ),
552
545
)
553
- return y
554
- end
555
- _first_or_nothing (:: Nothing , :: Int ) = nothing
556
-
557
- # `first(x, n::Int)` requires Julia 1.6
558
- function _first (x, n:: Int )
559
- @static if VERSION >= v " 1.6.0-DEV.431"
560
- first (x, n)
561
- else
562
- if x isa AbstractVector
563
- @inbounds x[firstindex (x): min (firstindex (x) + n - 1 , lastindex (x))]
564
- else
565
- collect (Iterators. take (x, n))
566
- end
567
- end
568
546
end
0 commit comments