Skip to content

Commit a1a7a57

Browse files
committed
sort in smem
1 parent de7da54 commit a1a7a57

File tree

1 file changed

+37
-32
lines changed

1 file changed

+37
-32
lines changed

smem_bitonic_sort.cu

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,30 @@
1717
* @param x caller line id's value
1818
* @param mask source lane id = caller line id ^ mask
1919
* @param dir direction to swap
20+
* @param arr shared memory
2021
*
21-
* @return min or max of source and caller
2222
*/
23-
__device__ int swap(int x, int mask, int dir) {
23+
__device__ void swap(int x, int mask, int dir, int *arr) {
2424
// get correspondin element to x in butterfly diagram
25-
int y = __shfl_xor_sync(0xffffffff, x, mask);
26-
// return smaller or larger value based on direction of swap
27-
return x < y == dir ? y : x;
25+
int y = x ^ mask;
26+
// lower ids thread perform swap
27+
if (y > x) {
28+
if (dir) {
29+
// sort ascending
30+
if (arr[x] < arr[y]) {
31+
int temp = arr[x];
32+
arr[x] = arr[y];
33+
arr[y] = temp;
34+
}
35+
} else {
36+
// sort descending
37+
if (arr[x] > arr[y]) {
38+
int temp = arr[x];
39+
arr[x] = arr[y];
40+
arr[y] = temp;
41+
}
42+
}
43+
}
2844
}
2945

3046
/**
@@ -53,48 +69,37 @@ __global__ void smemBitonicSort(int *arr, int size) {
5369
extern __shared__ int smem[];
5470

5571
// local thread id in block
56-
int thread_id = threadIdx.x;
72+
int thread_id = threadIdx.x + blockIdx.x * blockDim.x;
73+
// id if thread within its block
74+
int local_id = threadIdx.x;
5775

5876
// seed shared memory array with value from global array
5977
// pad overflow threads with INT_MAX
60-
smem[thread_id] = thread_id < size ? arr[thread_id] : INT_MAX;
78+
smem[local_id] = thread_id < size ? arr[thread_id] : INT_MAX;
6179
__syncthreads();
6280

6381
// make bitonic sequence and sort
64-
for (int i = 0; (1 << i) <= size; i++) {
65-
for (int j = 0; j <= i; j++) {
82+
for (int i = 0; (1 << i) <= blockDim.x; i++) {
83+
for (int j = 1; j <= i; j++) {
6684
// distance between caller and source lanes
67-
int offset = 1 << (i - j - 1);
68-
// direction to swap caller and source lanes
69-
int dir;
70-
// only alternate direction when forming bitonic sequence
71-
if (1 << i == blockDim.x) {
72-
dir = (thread_id >> (i - j)) & 1;
73-
} else {
74-
dir = (thread_id >> (i + 1)) & 1 ^ (thread_id >> (i - j)) & 1;
75-
}
76-
if (1 << i <= warpSize) {
77-
smem[thread_id] = swap(smem[thread_id], offset, dir);
78-
} else {
79-
__syncthreads();
80-
int partner_val = smem[thread_id ^ offset];
81-
int val = smem[thread_id];
82-
// compare and swap elements
83-
smem[thread_id] = val < partner_val == dir ? val : partner_val;
84-
smem[thread_id ^ offset] = val < partner_val == dir ? partner_val : val;
85-
}
85+
int mask = 1 << (i - j);
86+
87+
// perform compare and swap
88+
int dir = local_id & (1 << i);
89+
swap(local_id, mask, dir, smem);
90+
__syncthreads();
8691
}
8792
}
88-
__syncthreads();
8993

9094
// update value in array with sorted value
9195
if (thread_id < size) {
92-
arr[thread_id] = smem[thread_id];
96+
arr[thread_id] = smem[local_id];
9397
}
98+
__syncthreads();
9499
}
95100

96101
void launchBitonicSort(int *arr, int size) {
97102
const int BLOCK_SIZE = 1024;
98-
smemBitonicSort<<<size / BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE * sizeof(int)>>>(
99-
arr, size);
103+
smemBitonicSort<<<(size + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE,
104+
BLOCK_SIZE * sizeof(int)>>>(arr, size);
100105
}

0 commit comments

Comments
 (0)