diff --git a/src/Primes.jl b/src/Primes.jl
index 2a0e075..83e07ef 100644
--- a/src/Primes.jl
+++ b/src/Primes.jl
@@ -9,9 +9,11 @@ using Base: BitSigned
 using Base.Checked: checked_neg
 
 export isprime, primes, primesmask, factor, ismersenneprime, isrieselprime,
-       nextprime, nextprimes, prevprime, prevprimes, prime, prodfactors, radical, totient
+       nextprime, nextprimes, prevprime, prevprimes, prime, prodfactors, radical, totient,
+       SegmentedSieve
 
 include("factorization.jl")
+include("segmented_sieve/SegmentedSieve.jl")
 
 # Primes generating functions
 # https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
diff --git a/src/segmented_sieve/SegmentedSieve.jl b/src/segmented_sieve/SegmentedSieve.jl
new file mode 100644
index 0000000..201c6c8
--- /dev/null
+++ b/src/segmented_sieve/SegmentedSieve.jl
@@ -0,0 +1,44 @@
+module SegmentedSieve
+
+const ps = (1, 7, 11, 13, 17, 19, 23, 29)
+
+"""
+Population count of a vector of UInt8s for counting prime numbers.
+See https://github.com/JuliaLang/julia/issues/34059
+"""
+function vec_count_ones(xs::Union{Vector{UInt8}, Base.FastContiguousSubArray{UInt8}})
+    n = length(xs)
+    count = 0
+    chunks = n ÷ sizeof(UInt)
+    GC.@preserve xs begin
+        ptr = Ptr{UInt}(pointer(xs))
+        for i in 1:chunks
+            count += count_ones(unsafe_load(ptr, i))
+        end
+    end
+
+    @inbounds for i in sizeof(UInt) * chunks + 1 : n # tail bytes not covered by a full word
+        count += count_ones(xs[i])
+    end
+
+    count
+end
+
+function to_idx(x)
+    x == 1 && return 1
+    x == 7 && return 2
+    x == 11 && return 3
+    x == 13 && return 4
+    x == 17 && return 5
+    x == 19 && return 6
+    x == 23 && return 7
+    return 8
+end
+
+include("generate_sieving_loop.jl")
+include("sieve_small.jl")
+include("presieve.jl")
+include("siever.jl")
+include("sieve.jl")
+
+end # module
diff --git a/src/segmented_sieve/generate_sieving_loop.jl b/src/segmented_sieve/generate_sieving_loop.jl
new file mode 100644
index 0000000..689a0c4
--- /dev/null
+++ b/src/segmented_sieve/generate_sieving_loop.jl
@@ -0,0 +1,145 @@
+"""
+For a prime number p and a multiple q, the wheel index encodes the prime number index of
+p and q modulo 30, which can be encoded to a single number from 1 ... 64. This allows us
+to jump immediately into the correct loop at the correct offset.
+"""
+create_jump(wheel_index, i) = :($wheel_index === $i && @goto $(Symbol(:x, i)))
+
+create_label(wheel_index) = :(@label $(Symbol(:x, wheel_index)))
+
+wheel_mask(prime_mod_30)::UInt8 = ~(0x01 << (to_idx(prime_mod_30) - 1))
+
+"""
+For any prime number `p` we compute its prime number index modulo 30 (here `wheel`) and we
+generate the loop that crosses off the next 8 multiples that, modulo 30, are
+p * {1, 7, 11, 13, 17, 19, 23, 29}.
+"""
+function unrolled_loop(p_idx)
+    p = ps[p_idx]
+
+    # First push the stopping criterion
+    unrolled_loop_body = Any[:(byte_idx > unrolled_max && break)]
+
+    # Cross off the 8 next multiples
+    for q in ps
+        div, rem = divrem(p * q, 30)
+        bit = wheel_mask(rem)
+        push!(unrolled_loop_body, :(xs[byte_idx + increment * $(q - 1) + $div] &= $bit))
+    end
+
+    # Increment the byte index to where the next / 9th multiple is located
+    push!(unrolled_loop_body, :(byte_idx += increment * 30 + $p))
+
+    quote
+        while true
+            $(unrolled_loop_body...)
+        end
+    end
+end
+
+"""
+The fan-in / fan-out phase that crosses off one multiple and then checks bounds; this is
+before and after the unrolled loop starts and finishes respectively.
+"""
+function single_loop_item_not_unrolled(p_idx, q_idx, save_on_exit = true)
+    # Our prime number modulo 30
+    p = ps[p_idx]
+
+    ps_next = (1, 7, 11, 13, 17, 19, 23, 29, 31)
+
+    # Label name
+    jump_idx = 8 * (p_idx - 1) + q_idx
+
+    # Current and next multiplier modulo 30
+    q_curr, q_next = ps_next[q_idx], ps_next[q_idx + 1]
+
+    # Get the bit mask for crossing off p * q_curr
+    div_curr, rem_curr = divrem(p * q_curr, 30)
+    bit = wheel_mask(rem_curr)
+
+    # Compute the increments for the byte index for the next multiple
+    incr_bytes = p * q_next ÷ 30 - div_curr
+    incr_multiple = q_next - q_curr
+
+    quote
+        # Todo: this generates an extra jump, maybe conditional moves are possible?
+        if byte_idx > n_bytes
+
+            # For a segmented sieve we store where we exit the loop, since that is the
+            # entrypoint in the next loop; to avoid modulo computation to find the offset
+            $(save_on_exit ? :(last_idx = $jump_idx) : nothing)
+            @goto out
+        end
+
+        # Cross off the multiple
+        xs[byte_idx] &= $bit
+
+        # Increment the byte index to where the next multiple is located
+        byte_idx += increment * $incr_multiple + $incr_bytes
+    end
+end
+
+"""
+Full loop generates a potentially unrolled loop for a particular wheel
+that may or may not save the exit point.
+"""
+function full_loop_for_wheel(wheel, unroll = true, save_on_exit = true)
+    loop_statements = []
+
+    for i = 1 : 8
+        push!(loop_statements, create_label(8 * (wheel - 1) + i))
+        unroll && i == 1 && push!(loop_statements, unrolled_loop(wheel))
+        push!(loop_statements, single_loop_item_not_unrolled(wheel, i, save_on_exit))
+    end
+
+    quote
+        while true
+            $(loop_statements...)
+        end
+    end
+end
+
+"""
+Generates a sieving loop that crosses off multiples of a given prime number.
+
+    @sieve_loop :unroll :save_on_exit
+    @sieve_loop
+"""
+macro sieve_loop(options...)
+    unroll, save_on_exit = :(:unroll) in options, :(:save_on_exit) in options
+
+    # When crossing off p * q where `p` is the siever prime and `q` the current multiplier
+    # we have that p and q are {1, 7, 11, 13, 17, 19, 23, 29} mod 30.
+    # For each of these 8 possibilities for `p` we create a loop, and per loop we
+    # create 8 entrypoints to jump into. The first entrypoint is the unrolled loop for
+    # whenever we can remove 8 multiples at the same time when all 8 fit in the interval
+    # between byte_start:byte_next_start-1. Otherwise we can only remove one multiple at
+    # a time. With 8 loops and 8 entrypoints per loop we have 64 different labels, numbered
+    # x1 ... x64.
+
+    # As an example, take p = 7 as a prime number and q = 23 as the first multiplier, and
+    # assume our number line starts at 1 (so byte 1 represents 1:30, byte 2 represents 31:60).
+    # We have to cross off 7 * 23 = 161 first, which has byte index 6. Our prime number `p`
+    # is in the 2nd spoke of the wheel and q is in the 7th spoke. This means we have to jump
+    # to the 7th label in the 2nd loop; that is label 8 * (2 - 1) + 7 = 15. There we cross
+    # off the multiple (since 161 % 30 = 11 is the 3rd spoke, we "and" the byte with 0b11111011)
+    # Then we move to 7 * 29 (increment the byte index accordingly), cross it off as well.
+    # And now we enter the unrolled loop where 7 * {31, 37, ..., 59} are crossed off, then
+    # 7 * {61, 67, ..., 89} etc. Lastly we reach the end of the sieving interval, we cross
+    # off the remaining multiples one by one, until the byte index is past the end.
+    # When that is the case, we save at which multiple / label we exited, so we can jump
+    # there without computation when the next interval of the number line is sieved.
+
+    esc(quote
+        $(unroll ? :(unrolled_max = n_bytes - increment * 28 - 28) : nothing)
+
+        # Create jumps inside loops
+        $([create_jump(:wheel_idx, i) for i = 1 : 64]...)
+
+        # Create loops
+        $([full_loop_for_wheel(wheel, unroll, save_on_exit) for wheel in 1 : 8]...)
+
+        # Point of exit
+        @label out
+    end)
+end
\ No newline at end of file
diff --git a/src/segmented_sieve/presieve.jl b/src/segmented_sieve/presieve.jl
new file mode 100644
index 0000000..36615bf
--- /dev/null
+++ b/src/segmented_sieve/presieve.jl
@@ -0,0 +1,62 @@
+"""
+The {2, 3, 5}-wheel is efficient because it compresses memory
+perfectly (1 byte = 30 numbers) and leaves only 4/15th of all the
+numbers as candidates. We don't get that memory efficiency when
+extending the wheel to {2, 3, 5, 7}, since we need to store
+48 bits per 210 numbers, which could be done with one 64-bit
+integer per 210 numbers, which in fact compresses worse
+(64 bits / 210 numbers) than the {2, 3, 5}-wheel
+(8 bits / 30 numbers).
+
+What we can do however, is compute the repeating pattern that
+the first `n` primes create, and copy that pattern over. That
+is, we look at the numbers modulo p₁ * p₂ * ⋯ * pₙ.
+
+For instance, when presieving all multiples of {2, 3, ..., 19}
+we allocate a buffer for the range 1 : 2 * 3 * ... * 19 =
+1:9_699_690. In a {2, 3, 5} wheel this means a buffer of
+9_699_690 ÷ 30 = 323_323 bytes.
+"""
+function create_presieve_buffer()
+    n_bytes = 7 * 11 * 13 * 17
+    xs = fill(0xFF, n_bytes)
+
+    @inbounds for p in (7, 11, 13, 17)
+        p² = p * p
+        byte_idx = p² ÷ 30 + 1
+        wheel = to_idx(p)
+        wheel_idx = 8 * (wheel - 1) + wheel
+        increment = 0
+        @sieve_loop :unroll
+    end
+
+    @inbounds xs[1] = 0b11100001 # remove 7, 11, 13 and 17
+    return xs
+end
+
+"""
+Fill `xs` with the periodic presieve `buffer` for the bytes `byte_start:byte_stop`.
+"""
+function apply_presieve_buffer!(xs::Vector{UInt8}, buffer::Vector{UInt8}, byte_start, byte_stop)
+
+    len = byte_stop - byte_start + 1
+
+    # todo, clean this up a bit.
+    from_idx = (byte_start - 1) % length(buffer) + 1
+    to = min(len, length(buffer) - from_idx + 1)
+
+    # First copy the remainder of buffer at the front
+    copyto!(view(xs, Base.OneTo(to)), view(buffer, from_idx:from_idx + to - 1))
+    from = to + 1
+
+    # Then copy buffer multiple times
+    while from + length(buffer) - 1 <= len
+        copyto!(view(xs, from : from + length(buffer) - 1), buffer)
+        from += length(buffer)
+    end
+
+    # And finally copy the remainder of buffer again
+    copyto!(view(xs, from:len), view(buffer, Base.OneTo(length(from:len))))
+
+    xs
+end
\ No newline at end of file
diff --git a/src/segmented_sieve/sieve.jl b/src/segmented_sieve/sieve.jl
new file mode 100644
index 0000000..5152f5c
--- /dev/null
+++ b/src/segmented_sieve/sieve.jl
@@ -0,0 +1,124 @@
+import Base: iterate
+
+export SegmentedSieve
+
+function generate_siever_primes(small_sieve::SmallSieve, segment_lo)
+    xs = small_sieve.xs
+    sievers = Vector{Siever}(undef, vec_count_ones(xs))
+    j = 0
+    @inbounds for i = eachindex(xs)
+        x = xs[i]
+        while x != 0x00
+            sievers[j += 1] = Siever(compute_prime(x, i), segment_lo)
+            x &= x - 0x01
+        end
+    end
+    return sievers
+end
+
+struct SegmentIterator{T<:AbstractUnitRange}
+    range::T
+    segment_length::Int
+    first_byte::Int
+    last_byte::Int
+    sievers::Vector{Siever}
+    presieve_buffer::Vector{UInt8}
+    segment::Vector{UInt8}
+end
+
+struct Segment{Tr,Ts}
+    range::Tr
+    segment::Ts
+end
+
+function Base.show(io::IO, s::Segment)
+    # compute left padding; ndigits also handles a range whose last element is 0
+    padding = ndigits(last(s.range))
+
+    padding_str = " " ^ padding
+
+    print(io, padding_str, " ")
+    for p in ps
+        print(io, lpad(p, 2, "0"), " ")
+    end
+
+    println(io)
+
+    for (start, byte) in zip(s.range, s.segment)
+        mask = 0b00000001
+        print(io, lpad(start, padding, "0"), " ")
+        for i = 1 : 8
+            print(io, (byte & mask) == mask ? " x " : " . ")
+            mask <<= 1
+        end
+        println(io)
+    end
+
+    io
+end
+
+function SegmentIterator(range::T, segment_length::Integer) where {T<:AbstractUnitRange}
+    from, to = first(range), last(range)
+    first_byte, last_byte = cld(first(range), 30), cld(last(range), 30)
+    sievers = generate_siever_primes(SmallSieve(isqrt(to)), 30 * (first_byte - 1) + 1)
+    presieve_buffer = create_presieve_buffer()
+    xs = zeros(UInt8, segment_length)
+
+    return SegmentIterator{T}(range, segment_length, first_byte, last_byte, sievers, presieve_buffer, xs)
+end
+
+function iterate(iter::SegmentIterator, segment_index_start = iter.first_byte)
+    @inbounds begin
+        if segment_index_start > iter.last_byte # `last_byte` itself still has to be sieved
+            return nothing
+        end
+
+        from, to = first(iter.range), last(iter.range)
+
+        segment_index_next = min(segment_index_start + iter.segment_length, iter.last_byte + 1)
+        segment_curr_len = segment_index_next - segment_index_start
+
+        # Presieve
+        apply_presieve_buffer!(iter.segment, iter.presieve_buffer, segment_index_start, segment_index_next - 1)
+
+        # Set the preceding so many bits before `from` to 0
+        if segment_index_start == iter.first_byte
+            if iter.first_byte === 1
+                iter.segment[1] = 0b11111110 # just make 1 not a prime.
+            end
+            for i = 1 : 8
+                30 * (segment_index_start - 1) + ps[i] >= from && break
+                iter.segment[1] &= wheel_mask(ps[i])
+            end
+        end
+
+        # Set the remaining so many bits after `to` to 0
+        if segment_index_next == iter.last_byte + 1
+            for i = 8 : -1 : 1
+                to ≥ 30 * (segment_index_next - 2) + ps[i] && break
+                iter.segment[segment_curr_len] &= wheel_mask(ps[i])
+            end
+        end
+
+        # Sieve the interval, but skip the pre-sieved primes
+        xs = iter.segment
+
+        for p_idx in 5:length(iter.sievers)
+            p = iter.sievers[p_idx]
+            last_idx = 0
+            n_bytes = segment_index_next - segment_index_start
+            byte_idx = p.byte_index - segment_index_start + 1
+            wheel_idx = p.wheel_index
+            increment = p.prime_div_30
+            @sieve_loop :unroll :save_on_exit
+            iter.sievers[p_idx] = Siever(increment, segment_index_start + byte_idx - 1, last_idx)
+        end
+
+        segment_start = 30 * (segment_index_start - 1)
+        segment_stop = 30 * (segment_index_next - 1) - 1
+
+        segment_index_start += iter.segment_length
+
+        return Segment(segment_start:30:segment_stop, view(xs, Base.OneTo(segment_curr_len))), segment_index_start
+    end
+end
diff --git a/src/segmented_sieve/sieve_small.jl b/src/segmented_sieve/sieve_small.jl
new file mode 100644
index 0000000..e9f7256
--- /dev/null
+++ b/src/segmented_sieve/sieve_small.jl
@@ -0,0 +1,94 @@
+@inline compute_prime(x::UInt8, i) = 30 * (i - 1) + ps[trailing_zeros(x) + 0x01]
+
+"""
+An iterator for finding small (<= 1_000_000) primes.
+
+    for p in SmallSieve(1_000_000)
+        println(p)
+    end
+
+Uses n ÷ 30 bytes of memory. Skips 2, 3, and 5; starts at 7.
+"""
+struct SmallSieve
+    xs::Vector{UInt8}
+
+    function SmallSieve(n::Integer)
+        # Unrolled loop without segments
+        n_bytes = cld(n, 30)
+        xs = fill(0xFF, n_bytes)
+
+        # Ensure `1` is not a prime number (nothing to do when the sieve is empty)
+        isempty(xs) || (@inbounds xs[1] &= wheel_mask(1))
+
+        # And ensure numbers > n are not prime since we are not considering them
+        @inbounds for i = 8 : -1 : 1
+            (n_bytes < 1 || n >= 30 * (n_bytes - 1) + ps[i]) && break
+            xs[n_bytes] &= wheel_mask(ps[i])
+        end
+
+        hi = isqrt(n)
+
+        @inbounds for i = eachindex(xs)
+            x = xs[i]
+            while x != 0x00
+                # The next prime number
+                p = compute_prime(x, i)
+
+                # Are we done yet?
+                p > hi && @goto done
+
+                # Otherwise cross off multiples of p starting at p²
+                p² = p * p
+                byte_idx = p² ÷ 30 + 1
+                wheel = to_idx(p % 30)
+                wheel_idx = 8 * (wheel - 1) + wheel
+                increment = i - 1
+                @sieve_loop :unroll # Just unroll -- no segmented business here
+                x &= x - 0x01
+            end
+        end
+
+        @label done
+
+        return new(xs)
+    end
+end
+
+# Yes, this is an O(n) computation, but it should be much faster than computing
+# the prime numbers from the bitmask anyways
+Base.length(s::SmallSieve) = vec_count_ones(s.xs)
+Base.eltype(s::SmallSieve) = Int
+
+@inline function Base.iterate(s::SmallSieve)
+    @inbounds for i = eachindex(s.xs)
+        x = s.xs[i]
+        x !== 0x00 && return compute_prime(x, i), (x & (x - 0x01), i)
+    end
+
+    return nothing
+end
+
+@inline function Base.iterate(s::SmallSieve, state::Tuple{UInt8, Int})
+    x, i = state
+    @inbounds while true
+        x !== 0x00 && return compute_prime(x, i), (x & (x - 0x01), i)
+        i === length(s.xs) && return nothing
+        x = s.xs[i += 1]
+    end
+end
+
+function small_primes(n)
+    # This runs 20% faster than collect(SmallSieve(n)) :(
+    xs = SmallSieve(n).xs
+    primes = Vector{Int}(undef, vec_count_ones(xs))
+    j = 0
+    @inbounds for i = eachindex(xs)
+        x = xs[i]
+        while x != 0x00
+            primes[j += 1] = compute_prime(x, i)
+            x &= x - 0x01
+        end
+    end
+
+    primes
+end
\ No newline at end of file
diff --git a/src/segmented_sieve/siever.jl b/src/segmented_sieve/siever.jl
new file mode 100644
index 0000000..ebf0309
--- /dev/null
+++ b/src/segmented_sieve/siever.jl
@@ -0,0 +1,61 @@
+function compute_offset(p::Integer, segment_lo::Integer)
+    p² = p * p
+    if p² >= segment_lo
+        byte = cld(p², 30)
+
+        # Wheel index stores both the index of the prime number mod 30
+        # and the index of the active multiple. We start crossing off
+        # p * p, so that would be to_idx(p % 30) twice. We combine those values
+        # as a number between 1 ... 64.
+        wheel = to_idx(p % 30)
+        wheel_index = 8 * (wheel - 1) + wheel
+
+        return byte, wheel_index
+    else
+        # p * q will be the first number to cross off
+        q = cld(segment_lo, p)
+        q_quot, q_rem = divrem(q, 30)
+
+        remainders = (1, 7, 11, 13, 17, 19, 23, 29, 31)
+        i = 1
+        while remainders[i] < q_rem
+            i += 1
+        end
+
+        # maybe wrap around
+        q_rem = i == 9 ? 1 : remainders[i]
+        q_quot = i == 9 ? q_quot + 1 : q_quot
+        i = i == 9 ? 1 : i
+
+        # Our actual first acceptable multiple
+        q = 30q_quot + q_rem
+
+        byte = cld(p * q, 30)
+        wheel_index = 8 * (to_idx(p % 30) - 1) + i
+
+        return byte, wheel_index
+    end
+end
+
+struct Siever
+    prime_div_30::Int
+
+    # byte_index is the integer range 30(byte_index - 1) up to 30byte_index - 1
+    byte_index::Int
+
+    # Encodes the next multiple to be crossed off.
+    # If `p` is the prime number and `q` the next multiplier, the wheel index is
+    # 8 * (to_idx(p % 30) - 1) + to_idx(q % 30), a number between 1 ... 64.
+    wheel_index::Int
+
+    function Siever(p::Int, segment_lo::Int)
+        byte, wheel = compute_offset(p, segment_lo)
+        return new(p ÷ 30, byte, wheel)
+    end
+
+    Siever(prime_div_30, byte_index, wheel_index) = new(prime_div_30, byte_index, wheel_index)
+end
+
+function Base.show(io::IO, p::Siever)
+    print(io, 30p.prime_div_30 + ps[(p.wheel_index - 1) ÷ 8 + 1], " (", p.byte_index, ", ", p.wheel_index, ")")
+end