diff --git a/src/SamplingInterface.jl b/src/SamplingInterface.jl index 4f37cfa..527f030 100644 --- a/src/SamplingInterface.jl +++ b/src/SamplingInterface.jl @@ -197,6 +197,10 @@ struct SequentialSampler{S} end Base.iterate(s::SequentialSampler) = iterate(s.s) Base.iterate(s::SequentialSampler, state) = iterate(s.s, state) +Base.IteratorEltype(::SequentialSampler) = Base.HasEltype() +Base.eltype(::SequentialSampler) = Int +Base.IteratorSize(::SequentialSampler) = Base.HasLength() +Base.length(s::SequentialSampler) = s.s.n """ itsample([rng], iter, method = AlgRSWRSKIP()) diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl index b27e987..045dfa7 100644 --- a/src/SamplingReduction.jl +++ b/src/SamplingReduction.jl @@ -11,11 +11,18 @@ function reduce_samples(t::Union{TypeS,TypeUnion}, ss::BinaryHeap...) end function reduce_samples(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss::AbstractArray...) nt = length(ss) - v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt) + T = get_type_rs(t, ss...) + v = Vector{Vector{T}}(undef, nt) n = minimum(length.(ss)) ns = rand(extract_rng(rngs, 1), Multinomial(n, ps)) Threads.@threads for i in 1:nt - v[i] = sample(extract_rng(rngs, i), ss[i], ns[i]; replace = false) + s = ss[i] + vi = Vector{T}(undef, ns[i]) + @inbounds for (q, j) in enumerate(SequentialSampler(extract_rng(rngs, i), + ns[i], length(s), AlgHiddenShuffle())) + vi[q] = s[j] + end + v[i] = vi end return reduce(vcat, v) end