@@ -36,36 +36,62 @@ __device__ __forceinline__ void unpack_key(uint64_t k, uint32_t& meas,
 }
 
 __device__ void count_tracks(int tid, int* sh_n_meas, int n_tracks,
-                             unsigned int& bound, unsigned int& count,
-                             bool& stop) {
+                             unsigned int& bound, unsigned int& count) {
 
     unsigned int add = 0;
-    unsigned int offset = 0;
-    for (unsigned int stride = 1; stride < (n_tracks - count); stride *= 2) {
-        if ((count + tid + stride) < n_tracks) {
-            sh_n_meas[count + tid] += sh_n_meas[count + tid + stride];
+
+    // --- Warp-level phase: handle strides < 32 using warp shuffle (no
+    // __syncthreads needed) ---
+    const int lane = threadIdx.x & 31;
+    const unsigned int full_mask = 0xFFFFFFFFu;
+
+    // Load this thread's value into a register if it's in range
+    int v = (tid < n_tracks) ? sh_n_meas[tid] : 0;
+
+    // Mask for active lanes in this warp
+    const unsigned int mask = __ballot_sync(full_mask, tid < n_tracks);
+
+    const int max_stride = min(n_tracks, 32);
+
+    for (int stride = 1; stride < max_stride; stride <<= 1) {
+        // Accumulate neighbor's value via warp shuffle
+        unsigned int other = __shfl_down_sync(mask, v, stride);
+        if (lane + stride < 32 && (tid + stride) < n_tracks) {
+            v += other;
         }
-        __syncthreads();
 
-        if (sh_n_meas[count] < bound) {
-            if (tid == 0) {
-                offset = sh_n_meas[count];
-                add = stride * 2;
+        // Thread 0 can directly check its register value in the warp phase
+        if (tid == 0) {
+            if (v < bound) {
+                add = stride << 1;
             }
         }
+    }
 
-        __syncthreads();
+    // Write warp-phase result back to shared memory
+    if (tid < n_tracks) {
+        sh_n_meas[tid] = static_cast<int>(v);
     }
+    __syncthreads();
 
-    if (tid == 0) {
-        bound -= offset;
-        count += add;
+    // --- Block-level phase: handle strides >= 32 (minimal required
+    // synchronizations) ---
+    for (int stride = 32; stride < n_tracks; stride <<= 1) {
+        if ((tid + stride) < n_tracks) {
+            sh_n_meas[tid] += sh_n_meas[tid + stride];
+        }
+        __syncthreads();
 
-        if (add == 0) {
-            stop = true;
+        if (tid == 0 && sh_n_meas[0] < bound) {
+            add = stride << 1;
         }
+        __syncthreads();
     }
 
+    // --- Final update ---
+    if (tid == 0) {
+        count += add;
+    }
     __syncthreads();
 }
 
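For reference, a minimal standalone sketch of the __shfl_down_sync reduction pattern that the new warp-level phase relies on (the kernel name and host harness below are illustrative, not part of this PR): each lane folds in a neighbour's register value at doubling strides, so a warp sums 32 values without touching shared memory or calling __syncthreads().

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel: per-warp sum via shuffles. After the loop, lane 0 of
// each warp holds that warp's partial sum.
__global__ void warp_sum_demo(const int* in, int* out, int n) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int lane = threadIdx.x & 31;

    // Out-of-range threads contribute 0, mirroring the guard in count_tracks.
    int v = (tid < n) ? in[tid] : 0;

    for (int stride = 1; stride < 32; stride <<= 1) {
        // Every lane issues the shuffle (all 32 lanes are active here); the
        // guard only decides whether the received value is accumulated.
        const int other = __shfl_down_sync(0xFFFFFFFFu, v, stride);
        if (lane + stride < 32) {
            v += other;
        }
    }

    if (lane == 0) {
        atomicAdd(out, v);  // combine the per-warp partial sums
    }
}

int main() {
    const int n = 100;
    int h_in[n];
    for (int i = 0; i < n; ++i) h_in[i] = 1;  // expected total: 100

    int *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(int));
    cudaMalloc(&d_out, sizeof(int));
    cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemset(d_out, 0, sizeof(int));

    warp_sum_demo<<<1, 128>>>(d_in, d_out, n);

    int h_out = 0;
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("sum = %d\n", h_out);  // prints: sum = 100

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}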
@@ -96,7 +122,6 @@ __launch_bounds__(512) __global__
     __shared__ unsigned int n_tracks_to_iterate;
     __shared__ unsigned int min_thread;
     __shared__ unsigned int N;
-    __shared__ bool stop;
     __shared__ unsigned int n_updating_threads;
 
     auto threadIndex = threadIdx.x;
@@ -135,7 +160,6 @@ __launch_bounds__(512) __global__
         N = 1;
         n_tracks_to_iterate = 0;
         min_thread = std::numeric_limits<unsigned int>::max();
-        stop = false;
     }
 
     __syncthreads();
@@ -158,26 +182,8 @@ __launch_bounds__(512) __global__
      * Count the number of removable tracks
      ****************************************/
 
-    // @TODO: Improve the logic
     count_tracks(threadIdx.x, sh_buffer, n_tracks_total, bound,
-                 n_tracks_to_iterate, stop);
-    /*
-    for (int i = 0; i < 100; i++) {
-        count_tracks(threadIdx.x, shared_n_meas, n_tracks_total, bound,
-                     n_tracks_to_iterate, stop);
-        __syncthreads();
-        if (stop)
-            break;
-
-        if (gid >= 0 && static_cast<unsigned int>(gid) < sorted_ids.size()) {
-            const auto trk_id = sorted_ids[gid];
-            if (trk_id < n_meas.size()) {
-                shared_n_meas[threadIndex] = n_meas[trk_id];
-            }
-        }
-        __syncthreads();
-    }
-    */
+                 n_tracks_to_iterate);
 
     if (threadIndex == 0 && n_tracks_to_iterate == 0) {
         n_tracks_to_iterate = 1;
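For context on the block-level phase: it is a shared-memory suffix reduction, where after the final step sh_n_meas[0] holds the sum over all remaining tracks, which thread 0 compares against bound. Below is a minimal standalone sketch of that reduction shape (illustrative names, not part of this PR). It splits each step into a read, a barrier, a write, and a second barrier, so no thread ever reads a slot that another thread is updating in the same step.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel: after the loop, sh[tid] holds the sum of elements
// tid..n-1, so sh[0] carries the block-wide total.
__global__ void suffix_sum_demo(const int* in, int* out, int n) {
    __shared__ int sh[512];
    const int tid = threadIdx.x;

    sh[tid] = (tid < n) ? in[tid] : 0;
    __syncthreads();

    for (int stride = 1; stride < blockDim.x; stride <<= 1) {
        // Read the neighbour's slot first ...
        const int other = (tid + stride < blockDim.x) ? sh[tid + stride] : 0;
        __syncthreads();
        // ... then update, and wait again before the next round of reads.
        sh[tid] += other;
        __syncthreads();
    }

    if (tid == 0) {
        *out = sh[0];
    }
}

int main() {
    const int n = 100;
    int h_in[n];
    for (int i = 0; i < n; ++i) h_in[i] = 1;  // expected total: 100

    int *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(int));
    cudaMalloc(&d_out, sizeof(int));
    cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

    suffix_sum_demo<<<1, 512>>>(d_in, d_out, n);

    int h_out = 0;
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("total = %d\n", h_out);  // prints: total = 100

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}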