diff --git a/Project.toml b/Project.toml
index 0051650a..c7e78589 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,8 +3,10 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 version = "11.2.3"
 
 [deps]
+AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -22,8 +24,10 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 JLD2Ext = "JLD2"
 
 [compat]
+AcceleratedKernels = "0.4"
 Adapt = "4.0"
 GPUArraysCore = "= 0.2.0"
+GPUToolbox = "0.2, 0.3"
 JLD2 = "0.4, 0.5"
 KernelAbstractions = "0.9.28"
 LLVM = "3.9, 4, 5, 6, 7, 8, 9"
diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl
index 8c1fc14e..07b9cdb0 100644
--- a/src/GPUArrays.jl
+++ b/src/GPUArrays.jl
@@ -1,5 +1,6 @@
 module GPUArrays
 
+using GPUToolbox
 using KernelAbstractions
 using Serialization
 using Random
@@ -15,7 +16,8 @@ using LLVM.Interop
 using Reexport
 @reexport using GPUArraysCore
 
-using KernelAbstractions
+import KernelAbstractions as KA
+import AcceleratedKernels as AK
 
 # device functionality
 include("device/abstractarray.jl")
@@ -31,6 +33,7 @@ include("host/mapreduce.jl")
 include("host/linalg.jl")
 include("host/math.jl")
 include("host/random.jl")
+include("host/sorting.jl")
 include("host/quirks.jl")
 include("host/uniformscaling.jl")
 include("host/statistics.jl")
diff --git a/src/host/sorting.jl b/src/host/sorting.jl
new file mode 100644
index 00000000..a2a7c570
--- /dev/null
+++ b/src/host/sorting.jl
@@ -0,0 +1,116 @@
+# Sorting for GPU arrays, implemented on top of AcceleratedKernels (`AK`).
+
+abstract type SortingAlgorithm end
+struct MergeSortAlg <: SortingAlgorithm end
+
+# NOTE(review): presumably shadows Base's `MergeSort` inside this module - confirm intended.
+const MergeSort = MergeSortAlg()
+
+"""
+    sort!(c::AnyGPUVector, alg::MergeSortAlg; lt=isless, by=identity, rev=false)
+
+Sort `c` in place with `AK.merge_sort!` and return `c`.
+
+With `rev=true` the operands of `lt` are swapped. Negating the predicate (`!lt`) would
+treat equal elements as ordered, which is not a valid strict `lt`.
+"""
+function Base.sort!(c::AnyGPUVector, alg::MergeSortAlg; lt=isless, by=identity, rev=false)
+    if rev
+        # reverse by swapping operands, keeping the comparison strict
+        AK.merge_sort!(c; lt=((a, b) -> lt(b, a)), by)
+    else
+        AK.merge_sort!(c; lt, by)
+    end
+    return c
+end
+
+"""
+    sort!(c::AnyGPUArray; alg::SortingAlgorithm=MergeSort, kwargs...)
+
+Sort `c` in place using `alg`; the remaining keywords are forwarded unchanged.
+"""
+function Base.sort!(c::AnyGPUArray; alg::SortingAlgorithm = MergeSort, kwargs...)
+    return sort!(c, alg; kwargs...)
+end
+
+"""
+    sort(c::AnyGPUArray; kwargs...)
+
+Return a sorted copy of `c`; keywords are forwarded to `sort!`.
+"""
+function Base.sort(c::AnyGPUArray; kwargs...)
+    return sort!(copy(c); kwargs...)
+end
+
+"""
+    partialsort!(c::AnyGPUVector, k, alg::MergeSortAlg; lt=isless, by=identity, rev=false)
+
+Sort all of `c` in place, then return `c[k]`.
+"""
+function Base.partialsort!(c::AnyGPUVector, k::Union{Integer, OrdinalRange},
+                           alg::MergeSortAlg; lt=isless, by=identity, rev=false)
+    sort!(c, alg; lt, by, rev)
+    # indexing already yields a fresh value, so no extra `copy` is needed
+    return @allowscalar c[k]
+end
+
+"""
+    partialsort!(c::AnyGPUArray, k; alg::SortingAlgorithm=MergeSort, kwargs...)
+
+Keyword front-end that forwards to the method taking `alg` positionally.
+"""
+function Base.partialsort!(c::AnyGPUArray, k::Union{Integer, OrdinalRange};
+                           alg::SortingAlgorithm=MergeSort, kwargs...)
+    return partialsort!(c, k, alg; kwargs...)
+end
+
+"""
+    partialsort(c::AnyGPUArray, k; kwargs...)
+
+Apply `partialsort!` to a copy of `c`.
+"""
+function Base.partialsort(c::AnyGPUArray, k::Union{Integer, OrdinalRange}; kwargs...)
+    return partialsort!(copy(c), k; kwargs...)
+end
+
+"""
+    sortperm!(ix::AnyGPUArray, A::AnyGPUArray; initialized=false, dims=nothing, kwargs...)
+
+Run `AK.merge_sortperm!(ix, A; kwargs...)` and return `ix`.
+
+Throws `ArgumentError` when `axes(ix) != axes(A)` or when `dims` is given.
+`initialized` is accepted but not used.
+"""
+function Base.sortperm!(ix::AnyGPUArray, A::AnyGPUArray; initialized=false, dims=nothing, kwargs...)
+    if axes(ix) != axes(A)
+        throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
+    end
+    if !isnothing(dims)
+        throw(ArgumentError("GPUArrays sort with `dims` kwarg not yet implemented."))
+    end
+
+    AK.merge_sortperm!(ix, A; kwargs...)
+    return ix
+end
+
+"""
+    sortperm(c::AnyGPUVector; initialized=false, kwargs...)
+
+Allocate an `Int` index vector on the backend of `c` and return the result of
+`AK.merge_sortperm!` applied to it. `initialized` is accepted but not used.
+"""
+function Base.sortperm(c::AnyGPUVector; initialized=false, kwargs...)
+    ix = KA.allocate(get_backend(c), Int, length(c))
+    return AK.merge_sortperm!(ix, c; kwargs...)
+end
+
+"""
+    sortperm(c::AnyGPUArray; dims, kwargs...)
+
+Not implemented yet: always throws.
+"""
+function Base.sortperm(c::AnyGPUArray; dims, kwargs...)
+    # Base errors for Matrices without dims arg, we should too
+    error("GPU sort with `dims` kwarg not yet implemented.")
+    # sortperm!(reshape(adapt(get_backend(c), collect(1:length(c))), size(c)), c; initialized=true, dims, kwargs...)
+end
diff --git a/test/testsuite.jl b/test/testsuite.jl
index b48d7ccd..4cef234d 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -94,6 +94,7 @@ include("testsuite/broadcasting.jl")
 include("testsuite/linalg.jl")
 include("testsuite/math.jl")
 include("testsuite/random.jl")
+include("testsuite/sorting.jl")
 include("testsuite/uniformscaling.jl")
 include("testsuite/statistics.jl")
 include("testsuite/alloc_cache.jl")
diff --git a/test/testsuite/sorting.jl b/test/testsuite/sorting.jl
new file mode 100644
index 00000000..386f518d
--- /dev/null
+++ b/test/testsuite/sorting.jl
@@ -0,0 +1,51 @@
+@testsuite "sorting/sort" (AT, eltypes)->begin
+    # Fuzzy correctness testing
+    @testset "$ET" for ET in filter(x -> x <: Real, eltypes)
+        for _ in 1:10
+            num_elems = rand(1:100_000)
+            @test compare((A)->Base.sort!(A), AT, rand(ET, num_elems))
+        end
+        # Not yet implemented
+        # for _ in 1:5
+        #     size = rand(1:100, 2)
+        #     @test compare((A)->Base.sort!(A; dims=1), AT, rand(ET, size...))
+        #     @test compare((A)->Base.sort!(A; dims=2), AT, rand(ET, size...))
+        # end
+    end
+end
+
+@testsuite "sorting/sortperm" (AT, eltypes)->begin
+    # Fuzzy correctness testing
+    @testset "$ET" for ET in filter(x -> x <: Real, eltypes)
+        for _ in 1:10
+            num_elems = rand(1:100_000)
+            @test compare((ix, A)->Base.sortperm!(ix, A), AT, zeros(Int32, num_elems), rand(ET, num_elems))
+        end
+        # Not yet implemented
+        # for _ in 1:5
+        #     size = rand(1:100, 2)
+        #     @test compare((A)->Base.sort!(A; dims=1), AT, zeros(Int32, size...), rand(ET, size...))
+        #     @test compare((A)->Base.sort!(A; dims=2), AT, zeros(Int32, size...), rand(ET, size...))
+        # end
+    end
+end
+
+@testsuite "sorting/partialsort" (AT, eltypes)->begin
+    local N = 10000
+    @testset "$ET" for ET in filter(x -> x <: Real, eltypes)
+        @test compare((A)->Base.partialsort!(A, 1), AT, rand(ET, N))
+        @test compare((A)->Base.partialsort!(A, 1; rev=true), AT, rand(ET, N))
+
+        @test compare((A)->Base.partialsort!(A, N), AT, rand(ET, N))
+        @test compare((A)->Base.partialsort!(A, N; rev=true), AT, rand(ET, N))
+
+        @test compare((A)->Base.partialsort!(A, N÷2), AT, rand(ET, N))
+        @test compare((A)->Base.partialsort!(A, N÷2; rev=true), AT, rand(ET, N))
+
+        @test compare((A)->Base.partialsort!(A, (N÷10):(2N÷10)), AT, rand(ET, N))
+        @test compare((A)->Base.partialsort!(A, (N÷10):(2N÷10); rev=true), AT, rand(ET, N))
+
+        @test compare((A)->Base.partialsort!(A, 1:N), AT, rand(ET, N))
+        @test compare((A)->Base.partialsort!(A, 1:N; rev=true), AT, rand(ET, N))
+    end
+end