Skip to content

Commit e897b8a

Browse files
committed
remove _first_or_nothing and just check if init_params is of the right length
1 parent 2e6e23d commit e897b8a

File tree

1 file changed

+15
-37
lines changed

1 file changed

+15
-37
lines changed

src/sample.jl

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ function mcmcsample(
312312
# Create a seed for each chain using the provided random number generator.
313313
seeds = rand(rng, UInt, nchains)
314314

315-
# Ensure that initial parameters are `nothing` or indexable
316-
_init_params = _first_or_nothing(init_params, nchains)
315+
# Ensure that initial parameters are `nothing` or of the correct length
316+
check_initial_params(init_params, nchains)
317317

318318
# Set up a chains vector.
319319
chains = Vector{Any}(undef, nchains)
@@ -364,10 +364,10 @@ function mcmcsample(
364364
_sampler,
365365
N;
366366
progress=false,
367-
init_params=if _init_params === nothing
367+
init_params=if init_params === nothing
368368
nothing
369369
else
370-
_init_params[chainidx]
370+
init_params[chainidx]
371371
end,
372372
kwargs...,
373373
)
@@ -410,8 +410,8 @@ function mcmcsample(
410410
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
411411
end
412412

413-
# Ensure that initial parameters are `nothing` or indexable
414-
_init_params = _first_or_nothing(init_params, nchains)
413+
# Ensure that initial parameters are `nothing` or of the correct length
414+
check_initial_params(init_params, nchains)
415415

416416
# Create a seed for each chain using the provided random number generator.
417417
seeds = rand(rng, UInt, nchains)
@@ -469,10 +469,10 @@ function mcmcsample(
469469
# Return the new chain.
470470
return chain
471471
end
472-
chains = if _init_params === nothing
472+
chains = if init_params === nothing
473473
Distributed.pmap(sample_chain, pool, seeds)
474474
else
475-
Distributed.pmap(sample_chain, pool, seeds, _init_params)
475+
Distributed.pmap(sample_chain, pool, seeds, init_params)
476476
end
477477
finally
478478
# Stop updating the progress bar.
@@ -502,8 +502,8 @@ function mcmcsample(
502502
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
503503
end
504504

505-
# Ensure that initial parameters are `nothing` or indexable
506-
_init_params = _first_or_nothing(init_params, nchains)
505+
# Ensure that initial parameters are `nothing` or of the correct length
506+
check_initial_params(init_params, nchains)
507507

508508
# Create a seed for each chain using the provided random number generator.
509509
seeds = rand(rng, UInt, nchains)
@@ -525,10 +525,10 @@ function mcmcsample(
525525
)
526526
end
527527

528-
chains = if _init_params === nothing
528+
chains = if init_params === nothing
529529
map(sample_chain, 1:nchains, seeds)
530530
else
531-
map(sample_chain, 1:nchains, seeds, _init_params)
531+
map(sample_chain, 1:nchains, seeds, init_params)
532532
end
533533

534534
# Concatenate the chains together.
@@ -538,31 +538,9 @@ end
538538
tighten_eltype(x) = x
539539
tighten_eltype(x::Vector{Any}) = map(identity, x)
540540

541-
"""
542-
_first_or_nothing(x, n::Int)
543-
544-
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
545-
546-
If `x !== nothing`, then `x` has to contain at least `n` elements.
547-
"""
548-
function _first_or_nothing(x, n::Int)
549-
y = _first(x, n)
550-
length(y) == n || throw(
541+
check_initial_params(x::Nothing, n::Int) = nothing
542+
function check_initial_params(x, n::Int)
543+
length(x) == n || throw(
551544
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
552545
)
553-
return y
554-
end
555-
_first_or_nothing(::Nothing, ::Int) = nothing
556-
557-
# `first(x, n::Int)` requires Julia 1.6
558-
function _first(x, n::Int)
559-
@static if VERSION >= v"1.6.0-DEV.431"
560-
first(x, n)
561-
else
562-
if x isa AbstractVector
563-
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
564-
else
565-
collect(Iterators.take(x, n))
566-
end
567-
end
568546
end

0 commit comments

Comments
 (0)