Skip to content

Commit 724ffd0

Browse files
committed
use warp shuffle in smem sorting
1 parent 9fa193b commit 724ffd0

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

main.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ bool isSorted(int *arr, int size) {
1313
}
1414

1515
int main() {
16-
const int SIZE = 4096; // Must be a multiple of 32 for this example
17-
const int BLOCK_SIZE = 256;
16+
const int SIZE = 1024; // Must be a multiple of 32 for this example
1817

1918
// Allocate and initialize host array
2019
int *h_arr = new int[SIZE];

smem_bitonic_sort.cu

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
*/
99

1010
#include "bitonic_sort.cuh"
11-
#include <stdio.h>
1211

1312
/**
1413
* Swap
@@ -51,41 +50,42 @@ __device__ int swap(int x, int mask, int dir) {
5150
*/
5251
__global__ void smemBitonicSort(int *arr, int size) {
5352
// shared memory for block of 1024 threads
54-
__shared__ int smem[1 << 10];
53+
extern __shared__ int smem[];
5554

5655
// local thread id in block
5756
int thread_id = threadIdx.x;
5857

5958
// seed shared memory array with value from global array
6059
// pad overflow threads with INT_MAX
6160
smem[thread_id] = thread_id < size ? arr[thread_id] : INT_MAX;
61+
__syncthreads();
6262

6363
// make bitonic sequence and sort
64-
for (int i = 0; (1 << i) <= blockDim.x; i++) {
64+
for (int i = 0; (1 << i) <= size; i++) {
6565
for (int j = 0; j <= i; j++) {
6666
// distance between caller and source lanes
67-
int offset = 1 << (i - j);
68-
// number of elements in each sorted subset
69-
int sort_size = offset << 1;
70-
// id into smem array
71-
int arr_id =
72-
(thread_id / sort_size * sort_size) + (thread_id % sort_size / 2) ^
73-
(thread_id % 2 * offset); // apply xor to odd threads
74-
printf("thread %d arr %d\n", thread_id, arr_id);
67+
int offset = 1 << (i - j - 1);
7568
// direction to swap caller and source lanes
7669
int dir;
7770
// only alternate direction when forming bitonic sequence
7871
if (1 << i == blockDim.x) {
79-
dir = (arr_id >> (i - j)) & 1;
72+
dir = (thread_id >> (i - j)) & 1;
8073
} else {
81-
dir = (arr_id >> (i + 1)) & 1 ^ (arr_id >> (i - j)) & 1;
74+
dir = (thread_id >> (i + 1)) & 1 ^ (thread_id >> (i - j)) & 1;
75+
}
76+
if (1 << i <= warpSize) {
77+
smem[thread_id] = swap(smem[thread_id], offset, dir);
78+
} else {
79+
__syncthreads();
80+
int partner_val = smem[thread_id ^ offset];
81+
int val = smem[thread_id];
82+
// compare and swap elements
83+
smem[thread_id] = val < partner_val == dir ? val : partner_val;
84+
smem[thread_id ^ offset] = val < partner_val == dir ? partner_val : val;
8285
}
83-
// elements to compare and swap are directly next to each other in warp
84-
smem[arr_id] = swap(smem[arr_id], 1, dir);
85-
// wait for all warps to finish swap before going to next layer
86-
__syncthreads();
8786
}
8887
}
88+
__syncthreads();
8989

9090
// update value in array with sorted value
9191
if (thread_id < size) {
@@ -95,6 +95,6 @@ __global__ void smemBitonicSort(int *arr, int size) {
9595

9696
void launchBitonicSort(int *arr, int size) {
9797
const int BLOCK_SIZE = 1024;
98-
smemBitonicSort<<<(size + (BLOCK_SIZE - 1)) / BLOCK_SIZE, BLOCK_SIZE>>>(arr,
99-
size);
98+
smemBitonicSort<<<size / BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE * sizeof(int)>>>(
99+
arr, size);
100100
}

0 commit comments

Comments
 (0)