Skip to content

Commit 214f238

Browse files
committed
improve interface
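- add `initialized` kwarg to `sortperm!`
- use safe index typing (Int32 only when the array length fits, otherwise Int)
- throw an ArgumentError if the index array length != the value array length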
1 parent dc39bd8 commit 214f238

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

src/sorting.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -853,13 +853,14 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false) where {T}
853853
j_final = (1 + k0 - k)
854854
for j = 1:j_final
855855

856-
# use Int32 args for indexing --> ~10% faster kernels
857-
args1 = (c, map(Int32, (c_len, k, j, j_final))..., by, lt, Val(rev))
856+
# use Int32 args for indexing when possible --> ~10% faster kernels
857+
I = c_len <= typemax(Int32) ? Int32 : Int
858+
args1 = (c, map(I, (c_len, k, j, j_final))..., by, lt, Val(rev))
858859
kernel1 = @cuda launch = false comparator_small_kernel(args1...)
859860
config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads))
860861
threads1 = prevpow(2, config1.threads)
861862

862-
args2 = (c, map(Int32, (c_len, k, j))..., by, lt, Val(rev))
863+
args2 = (c, map(I, (c_len, k, j))..., by, lt, Val(rev))
863864
kernel2 = @cuda launch = false comparator_kernel(args2...)
864865
config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads))
865866
threads2 = prevpow(2, config2.threads)
@@ -939,11 +940,17 @@ function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs
939940
return partialsort!(copy(c), k; kwargs...)
940941
end
941942

942-
function Base.sortperm!(I::AnyCuArray, c::AnyCuArray; kwargs...)
943+
function Base.sortperm!(I::AnyCuArray{T}, c::AnyCuArray; initialized=false, kwargs...) where T
944+
if length(I) != length(c)
945+
throw(ArgumentError("index vector must have the same length/indices as the source vector"))
946+
end
947+
if !initialized
948+
I .= one(T):T(length(I))
949+
end
943950
bitonic_sort!((c, I); kwargs...)
944951
return I
945952
end
946953

947954
function Base.sortperm(c::AnyCuArray; kwargs...)
948-
sortperm!(CuArray(1:length(c)), c; kwargs...)
955+
sortperm!(CuArray(1:length(c)), c; initialized=true, kwargs...)
949956
end

test/sorting.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,13 @@ end
356356
@test check_sortperm(Float64, 1000000; rev=true)
357357
@test check_sortperm(Float64, 1000000; by=x->abs(x-0.5))
358358
@test check_sortperm(Float64, 1000000; rev=true, by=x->abs(x-0.5))
359-
# check with int32 indices
359+
# check with Int32 indices
360360
@test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000)
361-
361+
# `initialized` kwarg
362+
@test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000; initialized=true)
363+
@test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000; initialized=false)
364+
# expected error case
365+
@test_throws ArgumentError sortperm!(CuArray(1:3), CuArray(1:4))
362366
end
363367

364368
end

0 commit comments

Comments
 (0)