Skip to content

Commit 991f3ae

Browse files
committed
alg argument for sort vectors
1 parent bf45a4b commit 991f3ae

File tree

2 files changed

+75
-31
lines changed

2 files changed

+75
-31
lines changed

src/sorting.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -887,28 +887,44 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false) where {T}
887887
end
888888
end
889889

890+
abstract type CuSortingAlgorithm end
891+
struct CuQuickSortAlg <: CuSortingAlgorithm end
892+
struct CuBitonicSortAlg <: CuSortingAlgorithm end
893+
const CuQuickSort = CuQuickSortAlg()
894+
const CuBitonicSort = CuBitonicSortAlg()
895+
export CuQuickSort, CuBitonicSort
896+
890897
# Base interface implementation
891898

892899
using .BitonicSort
893900
using .Quicksort
894901

895-
function Base.sort!(c::AnyCuArray; dims::Integer, lt=isless, by=identity, rev=false)
902+
903+
function Base.sort!(c::AnyCuVector, alg::CuQuickSortAlg; lt=isless, by=identity, rev=false)
896904
# for reverse sorting, invert the less-than function
897905
if rev
898906
lt = !lt
899907
end
900908

901-
quicksort!(c; lt, by, dims)
909+
quicksort!(c; lt, by, dims=1)
902910
return c
903911
end
904912

905-
function Base.sort!(c::AnyCuVector; lt=isless, by=identity, rev=false)
906-
# for reverse sorting, invert the less-than function
913+
function Base.sort!(c::AnyCuVector, alg::CuBitonicSortAlg; kwargs...)
914+
return bitonic_sort!(c; kwargs...)
915+
end
916+
917+
function Base.sort!(c::AnyCuVector; alg :: CuSortingAlgorithm = CuBitonicSort, kwargs...)
918+
return sort!(c, alg; kwargs...)
919+
end
920+
921+
function Base.sort!(c::AnyCuArray; dims::Integer, lt=isless, by=identity, rev=false)
922+
# for multi dim sorting, only quicksort is supported so no alg keyword
907923
if rev
908924
lt = !lt
909925
end
910926

911-
quicksort!(c; lt, by, dims=1)
927+
quicksort!(c; lt, by, dims)
912928
return c
913929
end
914930

test/sorting.jl

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ end
169169
"""
170170
Tests if `c` is a valid sort of `a`
171171
"""
172-
function check_equivalence(a::Vector, c::Vector; kwargs...)
172+
function check_equivalence(a::Vector, c::Vector; alg=nothing, kwargs...)
173173
counter(a) == counter(c) && issorted(c; kwargs...)
174174
end
175175

@@ -281,58 +281,86 @@ end
281281
end
282282

283283
@testset "interface" begin
284+
@testset "quicksort" begin
284285
# pre-sorted
285-
@test check_sort!(Int, 1000000)
286-
@test check_sort!(Int32, 1000000)
287-
@test check_sort!(Float64, 1000000)
288-
@test check_sort!(Float32, 1000000)
286+
@test check_sort!(Int, 1000000; alg=CuQuickSort)
287+
@test check_sort!(Int32, 1000000; alg=CuQuickSort)
288+
@test check_sort!(Float64, 1000000; alg=CuQuickSort)
289+
@test check_sort!(Float32, 1000000; alg=CuQuickSort)
289290
@test check_sort!(Int32, 1000000; rev=true)
290291
@test check_sort!(Float32, 1000000; rev=true)
291292

292293
# reverse sorted
293-
@test check_sort!(Int32, 1000000, x -> -x)
294-
@test check_sort!(Float32, 1000000, x -> -x)
295-
@test check_sort!(Int32, 1000000, x -> -x; rev=true)
296-
@test check_sort!(Float32, 1000000, x -> -x; rev=true)
297-
298-
@test check_sort!(Int, 10000, x -> rand(Int))
299-
@test check_sort!(Int32, 10000, x -> rand(Int32))
300-
@test check_sort!(Int8, 10000, x -> rand(Int8))
301-
@test check_sort!(Float64, 10000, x -> rand(Float64))
302-
@test check_sort!(Float32, 10000, x -> rand(Float32))
303-
@test check_sort!(Float16, 10000, x -> rand(Float16))
294+
@test check_sort!(Int32, 1000000, x -> -x; alg=CuQuickSort)
295+
@test check_sort!(Float32, 1000000, x -> -x; alg=CuQuickSort)
296+
@test check_sort!(Int32, 1000000, x -> -x; rev=true, alg=CuQuickSort)
297+
@test check_sort!(Float32, 1000000, x -> -x; rev=true, alg=CuQuickSort)
298+
299+
@test check_sort!(Int, 10000, x -> rand(Int); alg=CuQuickSort)
300+
@test check_sort!(Int32, 10000, x -> rand(Int32); alg=CuQuickSort)
301+
@test check_sort!(Int8, 10000, x -> rand(Int8); alg=CuQuickSort)
302+
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CuQuickSort)
303+
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CuQuickSort)
304+
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CuQuickSort)
304305

305306
# non-uniform distributions
306-
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 2))
307-
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 3))
307+
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 2); alg=CuQuickSort)
308+
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 3); alg=CuQuickSort)
308309

309310
# more copies of each value than can fit in one block
310-
@test check_sort!(Int8, 4000000, x -> rand(Int8))
311+
@test check_sort!(Int8, 4000000, x -> rand(Int8); alg=CuQuickSort)
311312

312313
# multiple dimensions
313314
@test check_sort!(Int32, (4, 50000, 4); dims=2)
314315
@test check_sort!(Int32, (4, 4, 50000); dims=3, rev=true)
315316

316317
# large sizes
317-
@test check_sort!(Float32, 2^25)
318+
@test check_sort!(Float32, 2^25; alg=CuQuickSort)
318319

319320
# various sync depths
320321
for depth in 0:4
321322
CUDA.limit!(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH, depth)
322-
@test check_sort!(Int, 100000, x -> rand(Int))
323+
@test check_sort!(Int, 100000, x -> rand(Int); alg=CuQuickSort)
323324
end
324325

325326
# using a `by` argument
326-
@test check_sort(Float32, 100000; by=x->abs(x - 0.5))
327+
@test check_sort(Float32, 100000; by=x->abs(x - 0.5), alg=CuQuickSort)
327328
@test check_sort!(Float32, (100000, 4); by=x->abs(x - 0.5), dims=1)
328329
@test check_sort!(Float32, (4, 100000); by=x->abs(x - 0.5), dims=2)
329-
@test check_sort!(Float64, 400000; by=x->8*x-round(8*x))
330+
@test check_sort!(Float64, 400000; by=x->8*x-round(8*x), alg=CuQuickSort)
330331
@test check_sort!(Float64, (100000, 4); by=x->8*x-round(8*x), dims=1)
331332
@test check_sort!(Float64, (4, 100000); by=x->8*x-round(8*x), dims=2)
332333
# target bubble sort by using sub-blocksize input:
333-
@test check_sort!(Int, 200; by=x->x % 2)
334-
@test check_sort!(Int, 200; by=x->x % 3)
335-
@test check_sort!(Int, 200; by=x->x % 4)
334+
@test check_sort!(Int, 200; by=x->x % 2, alg=CuQuickSort)
335+
@test check_sort!(Int, 200; by=x->x % 3, alg=CuQuickSort)
336+
@test check_sort!(Int, 200; by=x->x % 4, alg=CuQuickSort)
337+
end # end quicksort tests
338+
339+
@testset "bitonic sort" begin
340+
# test various types
341+
@test check_sort!(Int, 10000, x -> rand(Int); alg=CuBitonicSort)
342+
@test check_sort!(Int32, 10000, x -> rand(Int32); alg=CuBitonicSort)
343+
@test check_sort!(Int8, 10000, x -> rand(Int8); alg=CuBitonicSort)
344+
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CuBitonicSort)
345+
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CuBitonicSort)
346+
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CuBitonicSort)
347+
348+
# test various sizes
349+
@test check_sort!(Float32, 1, x -> rand(Float32); alg=CuBitonicSort)
350+
@test check_sort!(Float32, 2, x -> rand(Float32); alg=CuBitonicSort)
351+
@test check_sort!(Float32, 3, x -> rand(Float32); alg=CuBitonicSort)
352+
@test check_sort!(Float32, 4, x -> rand(Float32); alg=CuBitonicSort)
353+
@test check_sort!(Float32, 1 << 16 + 0, x -> rand(Float32); alg=CuBitonicSort)
354+
@test check_sort!(Float32, 1 << 16 + 1, x -> rand(Float32); alg=CuBitonicSort)
355+
@test check_sort!(Float32, 1 << 16 + 31, x -> rand(Float32); alg=CuBitonicSort)
356+
@test check_sort!(Float32, 1 << 16 + 32, x -> rand(Float32); alg=CuBitonicSort)
357+
@test check_sort!(Float32, 1 << 16 + 33, x -> rand(Float32); alg=CuBitonicSort)
358+
@test check_sort!(Float32, 1 << 16 + 127, x -> rand(Float32); alg=CuBitonicSort)
359+
@test check_sort!(Float32, 1 << 16 + 128, x -> rand(Float32); alg=CuBitonicSort)
360+
@test check_sort!(Float32, 1 << 16 + 129, x -> rand(Float32); alg=CuBitonicSort)
361+
end # end bitonic tests
362+
363+
@test_throws MethodError check_sort!(Int, (100, 100); alg=CuBitonicSort, dims=1)
336364

337365
#partial sort
338366
@test check_partialsort!(Int, 100000, 1)

0 commit comments

Comments
 (0)