@@ -503,10 +503,11 @@ shared memory. It provides a moderate speedup.
503
503
Notation:
504
504
`k`, `j` denote the level of the sorting network (equivalently, recursion depth).
505
505
`vals` is the array of values of type `T` that is either being `sort`-ed or `sortperm`-ed.
506
- `inds` is an array of indices of type `J` that gets permuted in `sortperm!`.
506
+ `inds` is an array of indices of type `J` that gets permuted in `sortperm!` (standard 1-indexed)
507
507
`i1`, `i2` index either `vals` or `inds` depending on the operation.
508
508
`lo`, `n`, and `m` are integers of type `I` used to denote/calculate ranges as
509
- described in the recursive algorithm link above.
509
+ described in the recursive algorithm link above. Note these follow the 0-indexing
510
+ convention from the above source.
510
511
"""
511
512
module BitonicSort
512
513
export bitonic_sort!
729
730
"""
730
731
For sort/sort! `c`, allocate and return shared memory view of `c`
731
732
Each view is indexed along block x dim: one view per pseudo-block
733
+ `index` is expected to be from a 0-indexing context
732
734
"""
733
735
@inline function initialize_shmem! (vals:: AbstractArray{T} , index:: I , in_range, offset = zero (I)) where {T,I}
734
736
swap = CuDynamicSharedArray (T, (blockDim (). x, blockDim (). y), offset)
@@ -741,26 +743,21 @@ end
741
743
742
744
"""
743
745
For sortperm/sortperm!, allocate and return shared memory views of `c` and index
744
- array. Each view is indexed along block x dim: one view per pseudo-block
746
+ array. Each view is indexed along block x dim: one view per pseudo-block.
747
+ `index` is expected to be from a 0-indexing context, but the indices stored in
748
+ `val_inds` are expected to be 1-indexed
745
749
"""
746
750
@inline function initialize_shmem! (vals_inds:: Tuple{AbstractArray{T},AbstractArray{J}} , index, in_range) where {T,J}
747
- # NB: I tried creating both shmem arrays with `initialize_shmem!`
748
- # but the behavior changed - maybe it's necessary to alloc both before
749
- # writing to either?
750
- offset = prod (blockDim ()) * sizeof (T)
751
+ offset = prod (blockDim ()) * sizeof (J)
751
752
vals, inds = vals_inds
752
- swap_vals = CuDynamicSharedArray (T, (blockDim (). x, blockDim (). y))
753
- inds_view = initialize_shmem! (inds, index, in_range, offset)
754
- vals_view = @view swap_vals[:, threadIdx (). y]
755
- if in_range
756
- @inbounds vals_view[threadIdx (). x] = vals[inds_view[threadIdx (). x]]
757
- end
758
- sync_threads ()
753
+ inds_view = initialize_shmem! (inds, index, in_range)
754
+ vals_view = initialize_shmem! (vals, inds_view[threadIdx (). x] - one (J), in_range, offset)
759
755
return vals_view, inds_view
760
756
end
761
757
762
758
"""
763
759
For sort/sort!, copy shmem view `swap` back into global array `c`
760
+ `index` is expected to be from a 0-indexing context
764
761
"""
765
762
@inline function finalize_shmem! (vals:: AbstractArray , swap:: AbstractArray , index:: I , in_range:: Bool ) where {I}
766
763
if in_range
770
767
771
768
"""
772
769
For sortperm/sortperm!, copy shmem view `swap` back to global index array
770
+ `index` is expected to be from a 0-indexing context, but the indices stored in
771
+ `val_inds` are expected to be 1-indexed
773
772
"""
774
773
@inline function finalize_shmem! (vals_inds:: Tuple , swap:: Tuple , index, in_range:: Bool )
775
774
vals, inds = vals_inds
0 commit comments