@@ -853,13 +853,14 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false) where {T}
853
853
j_final = (1 + k0 - k)
854
854
for j = 1 : j_final
855
855
856
- # use Int32 args for indexing --> ~10% faster kernels
857
- args1 = (c, map (Int32, (c_len, k, j, j_final))... , by, lt, Val (rev))
856
+ # use Int32 args for indexing when possible --> ~10% faster kernels
857
+ I = c_len <= typemax (Int32) ? Int32 : Int
858
+ args1 = (c, map (I, (c_len, k, j, j_final))... , by, lt, Val (rev))
858
859
kernel1 = @cuda launch = false comparator_small_kernel (args1... )
859
860
config1 = launch_configuration (kernel1. fun, shmem = threads -> bitonic_shmem (c, threads))
860
861
threads1 = prevpow (2 , config1. threads)
861
862
862
- args2 = (c, map (Int32 , (c_len, k, j))... , by, lt, Val (rev))
863
+ args2 = (c, map (I , (c_len, k, j))... , by, lt, Val (rev))
863
864
kernel2 = @cuda launch = false comparator_kernel (args2... )
864
865
config2 = launch_configuration (kernel2. fun, shmem = threads -> bitonic_shmem (c, threads))
865
866
threads2 = prevpow (2 , config2. threads)
@@ -939,11 +940,17 @@ function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs
939
940
return partialsort! (copy (c), k; kwargs... )
940
941
end
941
942
942
- function Base. sortperm! (I:: AnyCuArray , c:: AnyCuArray ; kwargs... )
943
+ function Base. sortperm! (I:: AnyCuArray{T} , c:: AnyCuArray ; initialized= false , kwargs... ) where T
944
+ if length (I) != length (c)
945
+ throw (ArgumentError (" index vector must have the same length/indices as the source vector" ))
946
+ end
947
+ if ! initialized
948
+ I .= one (T): T (length (I))
949
+ end
943
950
bitonic_sort! ((c, I); kwargs... )
944
951
return I
945
952
end
946
953
947
954
function Base. sortperm (c:: AnyCuArray ; kwargs... )
948
- sortperm! (CuArray (1 : length (c)), c; kwargs... )
955
+ sortperm! (CuArray (1 : length (c)), c; initialized = true , kwargs... )
949
956
end
0 commit comments