Skip to content

Commit bf45a4b

Browse files
committed
clean up initialize_shmem + docstrings
1 parent 10d29cd commit bf45a4b

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/sorting.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -503,10 +503,11 @@ shared memory. It provides a moderate speedup.
503503
Notation:
504504
`k`, `j` denote the level of the sorting network (equivalently, recursion depth).
505505
`vals` is the array of values of type `T` that is either being `sort`-ed or `sortperm`-ed.
506-
`inds` is an array of indices of type `J` that gets permuted in `sortperm!`.
506+
`inds` is an array of indices of type `J` that gets permuted in `sortperm!` (standard 1-indexed)
507507
`i1`, `i2` index either `vals` or `inds` depending on the operation.
508508
`lo`, `n`, and `m` are integers of type `I` used to denote/calculate ranges as
509-
described in the recursive algorithm link above.
509+
described in the recursive algorithm link above. Note these follow the 0-indexing
510+
convention from the above source.
510511
"""
511512
module BitonicSort
512513
export bitonic_sort!
@@ -729,6 +730,7 @@ end
729730
"""
730731
For sort/sort! `c`, allocate and return shared memory view of `c`
731732
Each view is indexed along block x dim: one view per pseudo-block
733+
`index` is expected to be from a 0-indexing context
732734
"""
733735
@inline function initialize_shmem!(vals::AbstractArray{T}, index::I, in_range, offset = zero(I)) where {T,I}
734736
swap = CuDynamicSharedArray(T, (blockDim().x, blockDim().y), offset)
@@ -741,26 +743,21 @@ end
741743

742744
"""
743745
For sortperm/sortperm!, allocate and return shared memory views of `c` and index
744-
array. Each view is indexed along block x dim: one view per pseudo-block
746+
array. Each view is indexed along block x dim: one view per pseudo-block.
747+
`index` is expected to be from a 0-indexing context, but the indices stored in
748+
`val_inds` are expected to be 1-indexed
745749
"""
746750
@inline function initialize_shmem!(vals_inds::Tuple{AbstractArray{T},AbstractArray{J}}, index, in_range) where {T,J}
747-
# NB: I tried creating both shmem arrays with `initialize_shmem!`
748-
# but the behavior changed - maybe it's necessary to alloc both before
749-
# writing to either?
750-
offset = prod(blockDim()) * sizeof(T)
751+
offset = prod(blockDim()) * sizeof(J)
751752
vals, inds = vals_inds
752-
swap_vals = CuDynamicSharedArray(T, (blockDim().x, blockDim().y))
753-
inds_view = initialize_shmem!(inds, index, in_range, offset)
754-
vals_view = @view swap_vals[:, threadIdx().y]
755-
if in_range
756-
@inbounds vals_view[threadIdx().x] = vals[inds_view[threadIdx().x]]
757-
end
758-
sync_threads()
753+
inds_view = initialize_shmem!(inds, index, in_range)
754+
vals_view = initialize_shmem!(vals, inds_view[threadIdx().x] - one(J), in_range, offset)
759755
return vals_view, inds_view
760756
end
761757

762758
"""
763759
For sort/sort!, copy shmem view `swap` back into global array `c`
760+
`index` is expected to be from a 0-indexing context
764761
"""
765762
@inline function finalize_shmem!(vals::AbstractArray, swap::AbstractArray, index::I, in_range::Bool) where {I}
766763
if in_range
@@ -770,6 +767,8 @@ end
770767

771768
"""
772769
For sortperm/sortperm!, copy shmem view `swap` back to global index array
770+
`index` is expected to be from a 0-indexing context, but the indices stored in
771+
`val_inds` are expected to be 1-indexed
773772
"""
774773
@inline function finalize_shmem!(vals_inds::Tuple, swap::Tuple, index, in_range::Bool)
775774
vals, inds = vals_inds

0 commit comments

Comments
 (0)