88 */
99
1010#include " bitonic_sort.cuh"
11- #include < stdio.h>
1211
1312/* *
1413 * Swap
@@ -51,41 +50,42 @@ __device__ int swap(int x, int mask, int dir) {
5150 */
5251__global__ void smemBitonicSort (int *arr, int size) {
5352 // shared memory for block of 1024 threads
54- __shared__ int smem[1 << 10 ];
53+ extern __shared__ int smem[];
5554
5655 // local thread id in block
5756 int thread_id = threadIdx .x ;
5857
5958 // seed shared memory array with value from global array
6059 // pad overflow threads with INT_MAX
6160 smem[thread_id] = thread_id < size ? arr[thread_id] : INT_MAX;
61+ __syncthreads ();
6262
6363 // make bitonic sequence and sort
64- for (int i = 0 ; (1 << i) <= blockDim . x ; i++) {
64+ for (int i = 0 ; (1 << i) <= size ; i++) {
6565 for (int j = 0 ; j <= i; j++) {
6666 // distance between caller and source lanes
67- int offset = 1 << (i - j);
68- // number of elements in each sorted subset
69- int sort_size = offset << 1 ;
70- // id into smem array
71- int arr_id =
72- (thread_id / sort_size * sort_size) + (thread_id % sort_size / 2 ) ^
73- (thread_id % 2 * offset); // apply xor to odd threads
74- printf (" thread %d arr %d\n " , thread_id, arr_id);
67+ int offset = 1 << (i - j - 1 );
7568 // direction to swap caller and source lanes
7669 int dir;
7770 // only alternate direction when forming bitonic sequence
7871 if (1 << i == blockDim .x ) {
79- dir = (arr_id >> (i - j)) & 1 ;
72+ dir = (thread_id >> (i - j)) & 1 ;
8073 } else {
81- dir = (arr_id >> (i + 1 )) & 1 ^ (arr_id >> (i - j)) & 1 ;
74+ dir = (thread_id >> (i + 1 )) & 1 ^ (thread_id >> (i - j)) & 1 ;
75+ }
76+ if (1 << i <= warpSize ) {
77+ smem[thread_id] = swap (smem[thread_id], offset, dir);
78+ } else {
79+ __syncthreads ();
80+ int partner_val = smem[thread_id ^ offset];
81+ int val = smem[thread_id];
82+ // compare and swap elements
83+ smem[thread_id] = val < partner_val == dir ? val : partner_val;
84+ smem[thread_id ^ offset] = val < partner_val == dir ? partner_val : val;
8285 }
83- // elements to compare and swap are directly next to eachother in warp
84- smem[arr_id] = swap (smem[arr_id], 1 , dir);
85- // wait for all warps to finish swap before going to next layer
86- __syncthreads ();
8786 }
8887 }
88+ __syncthreads ();
8989
9090 // update value in array with sorted value
9191 if (thread_id < size) {
@@ -95,6 +95,6 @@ __global__ void smemBitonicSort(int *arr, int size) {
9595
/**
 * Host-side launcher for smemBitonicSort.
 *
 * @param arr  pointer handed straight to the kernel (presumably device
 *             memory — confirm against the caller)
 * @param size number of ints in arr to sort; nothing is launched when
 *             size <= 0
 *
 * Launches blocks of BLOCK_SIZE threads and passes BLOCK_SIZE * sizeof(int)
 * bytes of dynamic shared memory, one int per thread, for the kernel's
 * `extern __shared__ int smem[]` buffer.
 */
void launchBitonicSort(int *arr, int size) {
  const int BLOCK_SIZE = 1024;

  // A grid dimension of 0 is an invalid launch configuration, and there is
  // nothing to sort anyway.
  if (size <= 0) {
    return;
  }

  // Round up: plain `size / BLOCK_SIZE` truncates to 0 whenever
  // size < BLOCK_SIZE, so the kernel would never run for small inputs.
  const int num_blocks = (size + (BLOCK_SIZE - 1)) / BLOCK_SIZE;

  // NOTE(review): the kernel indexes arr with threadIdx.x only, so every
  // block works on the same leading BLOCK_SIZE elements — confirm whether
  // sizes above BLOCK_SIZE are meant to be supported.
  const size_t smem_bytes = BLOCK_SIZE * sizeof(int);
  smemBitonicSort<<<num_blocks, BLOCK_SIZE, smem_bytes>>>(arr, size);
}
0 commit comments