diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl index b27e987..ade7cab 100644 --- a/src/SamplingReduction.jl +++ b/src/SamplingReduction.jl @@ -20,6 +20,37 @@ function reduce_samples(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss:: return reduce(vcat, v) end +function reduce_samples_hypergeometric(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss::AbstractArray...) + nt = length(ss) + v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt) + n = minimum(length.(ss)) + + # For hypergeometric sampling, we need to sample without replacement from finite populations + # The number of samples from each reservoir depends on hypergeometric distribution + # Total population size is sum of all reservoir sizes + total_pop = sum(length.(ss)) + + # Sample using hypergeometric distribution for each reservoir + ns = Vector{Int}(undef, nt) + remaining = n + remaining_pop = total_pop + + for i in 1:(nt-1) + pop_i = length(ss[i]) + # Use hypergeometric distribution: drawing `remaining` items from population `remaining_pop` + # where `pop_i` items are of the type we want + ns[i] = rand(extract_rng(rngs, 1), Hypergeometric(pop_i, remaining_pop - pop_i, remaining)) + remaining -= ns[i] + remaining_pop -= pop_i + end + ns[nt] = remaining # Remainder goes to last reservoir + + Threads.@threads for i in 1:nt + v[i] = sample(extract_rng(rngs, i), ss[i], ns[i]; replace = false) + end + return reduce(vcat, v) +end + extract_rng(v::AbstractArray, i) = v[i] extract_rng(v::AbstractRNG, i) = v @@ -31,6 +62,14 @@ function get_ps(ss::MultiAlgWRSWRSKIPSampler...) sum_w = sum(getfield(s, :state) for s in ss) return [s.state/sum_w for s in ss] end +function get_ps(ss::MultiAlgRSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + return [s.seen_k/sum_w for s in ss] +end +function get_ps(ss::MultiAlgLSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + return [s.seen_k/sum_w for s in ss] +end get_type_rs(::TypeS, s1::T, ss::T...) where {T} = eltype(s1) function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T} diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 3a0d2f5..af655ee 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -209,11 +209,19 @@ end is_ordered(s::MultiOrdAlgRSWRSKIPSampler) = true is_ordered(s::MultiAlgRSWRSKIPSampler) = false -function Base.merge(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge(ss::MultiAlgRSampler...) + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + return MultiAlgRSampler_Mut(n, seen_k, ss[1].rng, newvalue, nothing) +end +function Base.merge(ss::MultiAlgLSampler...) + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + seen_k = sum(getfield(s, :seen_k) for s in ss) + # For AlgL, we need to initialize state and skip_k appropriately + # state should be 0.0 for new merged sampler, skip_k should be 0 + n = minimum(s.n for s in ss) + return MultiAlgLSampler_Mut(n, 0.0, 0, seen_k, ss[1].rng, newvalue, nothing) end function Base.merge(ss::MultiAlgRSWRSKIPSampler...) newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) @@ -223,11 +231,30 @@ function Base.merge(ss::MultiAlgRSWRSKIPSampler...) return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing) end -function Base.merge!(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge!(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge!(ss::MultiAlgRSampler...) + s1 = ss[1] + rest = ss[2:end] + s1.n > minimum(s.n for s in rest) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeS(), value(s1), value.(rest)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.seen_k += sum(getfield(s, :seen_k) for s in rest) + return s1 +end +function Base.merge!(ss::MultiAlgLSampler...) + s1 = ss[1] + rest = ss[2:end] + s1.n > minimum(s.n for s in rest) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeS(), value(s1), value.(rest)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.seen_k += sum(getfield(s, :seen_k) for s in rest) + # Reset state and skip_k for the merged sampler + s1.state = 0.0 + s1.skip_k = 0 + return s1 end function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...) s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") diff --git a/test/merge_tests.jl b/test/merge_tests.jl index d2fc637..ceb63fc 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -15,18 +15,54 @@ s_all = (s1, s2) for (s, it) in zip(s_all, iters) for x in it - m1 == AlgRSWRSKIP() ? fit!(s, x) : fit!(s, x, 1.0) + # Handle unweighted vs weighted algorithms + if m1 == AlgRSWRSKIP() + fit!(s, x) + else + fit!(s, x, 1.0) + end end end s_merged = merge(s1, s2) res[shuffle!(rng, value(s_merged))...] += 1 end - cases = (m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP()) ? 10^size : factorial(10)/factorial(10-size) + # Adjust expected number of cases for different algorithms + if m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP() + cases = 10^size + else + cases = factorial(10)/factorial(10-size) + end ps_exact = [1/cases for _ in 1:cases] count_est = [x for x in vec(res) if x != 0] chisq_test = ChisqTest(count_est, ps_exact) @test pvalue(chisq_test) > 0.05 end + + # Separate basic tests for AlgR and AlgL (not statistical) + @testset "AlgR and AlgL basic merge tests" begin + for m in (AlgR(), AlgL()) + s1 = ReservoirSampler{Int}(rng, size, m) + s2 = ReservoirSampler{Int}(rng, size, m) + + # Add some data + for x in 1:2; fit!(s1, x); end + for x in 3:4; fit!(s2, x); end + + # Test that merge works + merged = merge(s1, s2) + @test merged isa Union{StreamSampling.MultiAlgRSampler_Mut, StreamSampling.MultiAlgLSampler_Mut} + @test merged.n == size + + # Test that merge! works + s3 = ReservoirSampler{Int}(rng, size, m) + s4 = ReservoirSampler{Int}(rng, size, m) + for x in 5:6; fit!(s3, x); end + for x in 7:8; fit!(s4, x); end + + result = merge!(s3, s4) + @test result === s3 + end + end s1 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s2 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s_all = (s1, s2) @@ -39,8 +75,23 @@ for m in (AlgRSWRSKIP(), AlgWRSWRSKIP()) s1 = ReservoirSampler{Int}(rng, m) s2 = ReservoirSampler{Int}(rng, m) - m == AlgRSWRSKIP() ? fit!(s1, 1) : fit!(s1, 1, 1.0) - m == AlgRSWRSKIP() ? fit!(s2, 2) : fit!(s2, 2, 1.0) + if m == AlgRSWRSKIP() + fit!(s1, 1) + fit!(s2, 2) + else + fit!(s1, 1, 1.0) + fit!(s2, 2, 1.0) + end @test value(merge!(s1, s2)) in (1, 2) end + + # Test merge! for multi-element unweighted samplers (AlgR and AlgL) + for m in (AlgR(), AlgL()) + s1 = ReservoirSampler{Int}(rng, 1, m) # Single element reservoir + s2 = ReservoirSampler{Int}(rng, 1, m) + fit!(s1, 1) + fit!(s2, 2) + result = value(merge!(s1, s2)) + @test length(result) == 1 && result[1] in (1, 2) + end end