|
| 1 | + |
| 2 | +using StaticArrays |
| 3 | +using SyncBarriers |
| 4 | +using BenchmarkTools |
| 5 | +import AcceleratedKernels as AK |
| 6 | + |
| 7 | + |
| 8 | +using AllocCheck |
| 9 | + |
| 10 | +using Random |
| 11 | +Random.seed!(0) |
| 12 | + |
| 13 | + |
| 14 | + |
| 15 | + |
| 16 | +# @check_allocs ignore_throw=false |
| 17 | +function _sample_sort_histogram!(v, splitters, histograms, itask, irange) |
| 18 | + for i in irange |
| 19 | + ibucket = 1 + AK._searchsortedlast(splitters, v[i], 1, length(splitters), isless) |
| 20 | + histograms[ibucket, itask] += 1 |
| 21 | + end |
| 22 | + nothing |
| 23 | +end |
| 24 | + |
| 25 | + |
| 26 | +# @check_allocs ignore_throw=false |
| 27 | +function _sample_sort_parallel!( |
| 28 | + v, dest, comp, |
| 29 | + splitters, histograms, |
| 30 | + max_tasks, |
| 31 | +) |
| 32 | + # Compute the histogram for each task |
| 33 | + AK.itask_partition(length(v), max_tasks, 1) do itask, irange |
| 34 | + _sample_sort_histogram!( |
| 35 | + v, |
| 36 | + splitters, histograms, |
| 37 | + itask, irange, |
| 38 | + ) |
| 39 | + end |
| 40 | + nothing |
| 41 | +end |
| 42 | + |
| 43 | + |
| 44 | + |
| 45 | +function sample_sort!( |
| 46 | + v; |
| 47 | + max_tasks=Threads.nthreads(), |
| 48 | + |
| 49 | + lt=isless, |
| 50 | + by=identity, |
| 51 | + rev::Union{Bool, Nothing}=nothing, |
| 52 | + order::Base.Order.Ordering=Base.Order.Forward, |
| 53 | + |
| 54 | + temp=nothing |
| 55 | +) |
| 56 | + |
| 57 | + oversampling_factor = 4 |
| 58 | + num_elements = length(v) |
| 59 | + |
| 60 | + if num_elements < 2 |
| 61 | + return v |
| 62 | + end |
| 63 | + |
| 64 | + if max_tasks == 1 || num_elements < oversampling_factor * max_tasks |
| 65 | + return sort!(v, lt=lt, by=by, rev=rev, order=order) |
| 66 | + end |
| 67 | + |
| 68 | + # Create a temporary buffer for the sorted output |
| 69 | + if temp === nothing |
| 70 | + dest = similar(v) |
| 71 | + else |
| 72 | + # TODO add checks |
| 73 | + dest = temp |
| 74 | + end |
| 75 | + |
| 76 | + # Construct comparator |
| 77 | + ord = Base.Order.ord(lt, by, rev, order) |
| 78 | + comp = (x, y) -> Base.Order.lt(ord, x, y) |
| 79 | + |
| 80 | + # Take equally spaced samples, save them in dest |
| 81 | + num_samples = oversampling_factor * max_tasks |
| 82 | + isamples = IntLinSpace(1, num_elements, num_samples) |
| 83 | + @inbounds for i in 1:num_samples |
| 84 | + dest[i] = v[isamples[i]] |
| 85 | + end |
| 86 | + |
| 87 | + # Sort samples and choose splitters |
| 88 | + sort!(view(dest, 1:num_samples), lt=lt, by=by, rev=rev, order=order) |
| 89 | + splitters = Vector{eltype(v)}(undef, max_tasks - 1) |
| 90 | + for i in 1:(max_tasks - 1) |
| 91 | + splitters[i] = dest[div(i * num_samples, max_tasks)] |
| 92 | + end |
| 93 | + |
| 94 | + # Pre-allocate histogram for each task; each column is exclusive to the task |
| 95 | + histograms = zeros(Int, max_tasks + 8, max_tasks) # Add padding to avoid false sharing |
| 96 | + |
| 97 | + # Run threaded region |
| 98 | + _sample_sort_parallel!( |
| 99 | + v, dest, comp, |
| 100 | + splitters, histograms, |
| 101 | + max_tasks, |
| 102 | + ) |
| 103 | + |
| 104 | + dest |
| 105 | +end |
| 106 | + |
| 107 | + |
| 108 | + |
| 109 | + |
| 110 | + |
| 111 | +# Utilities |
| 112 | + |
| 113 | + |
| 114 | +# Create an integer linear space between start and stop on demand |
| 115 | +struct IntLinSpace{T <: Integer} |
| 116 | + start::T |
| 117 | + stop::T |
| 118 | + length::T |
| 119 | +end |
| 120 | + |
| 121 | +function IntLinSpace(start::Integer, stop::Integer, length::Integer) |
| 122 | + start <= stop || throw(ArgumentError("`start` must be <= `stop`")) |
| 123 | + length >= 2 || throw(ArgumentError("`length` must be >= 2")) |
| 124 | + |
| 125 | + IntLinSpace{typeof(start)}(start, stop, length) |
| 126 | +end |
| 127 | + |
| 128 | +Base.IndexStyle(::IntLinSpace) = IndexLinear() |
| 129 | +Base.length(ils::IntLinSpace) = ils.length |
| 130 | + |
| 131 | +Base.firstindex(::IntLinSpace) = 1 |
| 132 | +Base.lastindex(ils::IntLinSpace) = ils.length |
| 133 | + |
| 134 | +function Base.getindex(ils::IntLinSpace, i) |
| 135 | + @boundscheck 1 <= i <= ils.length || throw(BoundsError(ils, i)) |
| 136 | + |
| 137 | + if i == 1 |
| 138 | + ils.start |
| 139 | + elseif i == length |
| 140 | + ils.stop |
| 141 | + else |
| 142 | + ils.start + div((i - 1) * (ils.stop - ils.start), ils.length - 1, RoundUp) |
| 143 | + end |
| 144 | +end |
| 145 | + |
| 146 | + |
| 147 | + |
| 148 | + |
| 149 | + |
| 150 | + |
| 151 | + |
| 152 | + |
| 153 | +v = rand(Float32, 100_000) |
| 154 | + |
| 155 | +try |
| 156 | + temp = sample_sort!(v) |
| 157 | +catch e |
| 158 | + display(e.errors[1]) |
| 159 | + rethrow(e) |
| 160 | +end |
| 161 | + |
| 162 | + |
| 163 | +t = @timed sample_sort!(v) |
| 164 | + |
| 165 | + |
| 166 | +# @assert issorted(temp) |
| 167 | +# println("sorted") |
| 168 | + |
| 169 | + |
| 170 | +# display(@benchmark sort!(v) setup=(v=rand(Float64, 10_000_000))) |
| 171 | +display(@benchmark sample_sort!(v) setup=(v=rand(Float64, 100_000))) |
0 commit comments