diff --git a/device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cu b/device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cu index ac3330c287..475d0aad27 100644 --- a/device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cu +++ b/device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cu @@ -48,20 +48,22 @@ __launch_bounds__(1024) __global__ return; } - auto gid = threadIdx.x / nThreads_per_track + - blockIdx.x * (blockDim.x / nThreads_per_track); - const unsigned int n_accepted = *(payload.n_accepted); + // group (track) index in this block + const int lane = threadIdx.x % nThreads_per_track; + const int group = threadIdx.x / nThreads_per_track; + const bool leader = (lane == 0); - auto N = *(payload.n_updated_tracks); + auto gid = group + blockIdx.x * (blockDim.x / nThreads_per_track); + const unsigned int n_accepted = *(payload.n_accepted); + const int N = *(payload.n_updated_tracks); int neff_threads = (N + nThreads_per_track - 1) / nThreads_per_track; - if (neff_threads > nThreads_per_track) { neff_threads = nThreads_per_track; } bool is_valid_thread = true; - if (threadIdx.x % nThreads_per_track >= neff_threads || gid >= n_accepted) { + if (lane >= neff_threads || gid >= static_cast(n_accepted)) { is_valid_thread = false; } @@ -80,136 +82,142 @@ __launch_bounds__(1024) __global__ payload.temp_sorted_ids_view); __shared__ int shifted_indices[1024]; - auto& shifted_idx = shifted_indices[threadIdx.x / nThreads_per_track]; + auto& shifted_idx = shifted_indices[group]; + unsigned int tid = std::numeric_limits::max(); if (is_valid_thread) { tid = sorted_ids[gid]; - auto rel_sh_ref = rel_shared[tid]; - auto pval_ref = pvals[tid]; + const auto rel_sh_ref = rel_shared[tid]; + const auto pval_ref = pvals[tid]; + // initialize once by any lane (all lanes see same reference) shifted_idx = static_cast(gid); - int stride = (N + neff_threads - 1) / neff_threads; - - int ini_idx = stride * (threadIdx.x % nThreads_per_track); - int fin_idx = std::min(ini_idx + stride, static_cast(N)); + // work partition + const int stride = (N + neff_threads - 1) / neff_threads; + const int ini_idx = stride * lane; + const int fin_idx = min(ini_idx + stride, N); if (is_updated[tid]) { - if (gid > 0) { + // ---- group leader: compute base left-shift via binary search over + // valid (non-updated) slots + if (gid > 0 && leader) { unsigned int left = 0; unsigned int right = gid; - bool first_iteration = true; - if (threadIdx.x % nThreads_per_track == 0) { - - while (right > left) { + while (right > left) { - const bool find_left = find_valid_index( - left, 0, gid, sorted_ids, is_updated); + const bool find_left = + find_valid_index(left, 0, gid, sorted_ids, is_updated); + if (!find_left) + break; - if (!find_left) { - break; - } + const bool find_right = + find_valid_index(right, 0, gid, sorted_ids, is_updated); + if (!find_right) + break; - const bool find_right = find_valid_index( - right, 0, gid, sorted_ids, is_updated); + if (first_iteration) { + const auto right_idx = sorted_ids[right]; + const auto rel_sh = rel_shared[right_idx]; + const auto pval = pvals[right_idx]; - if (!find_right) { + if (rel_sh < rel_sh_ref || + (rel_sh == rel_sh_ref && pval >= pval_ref)) { + left = gid; break; } - - if (first_iteration) { - const auto right_idx = sorted_ids[right]; - auto rel_sh = rel_shared[right_idx]; - auto pval = pvals[right_idx]; - - if (rel_sh < rel_sh_ref || - (rel_sh == rel_sh_ref && pval >= pval_ref)) { - left = gid; - break; - } - } - - first_iteration = false; - - unsigned int mid = left + (right - left) / 2; - - const bool find_mid = find_valid_index( - mid, left, right - 1, sorted_ids, is_updated); - - if (find_mid) { - - const auto mid_idx = sorted_ids[mid]; - auto rel_sh = rel_shared[mid_idx]; - auto pval = pvals[mid_idx]; - - if (rel_sh < rel_sh_ref || - (rel_sh == rel_sh_ref && pval >= pval_ref)) { - - left = mid + 1; - } else { - right = mid; - } + } + first_iteration = false; + + unsigned int mid = left + (right - left) / 2; + const bool find_mid = find_valid_index( + mid, left, right - 1, sorted_ids, is_updated); + + if (find_mid) { + const auto mid_idx = sorted_ids[mid]; + const auto rel_sh = rel_shared[mid_idx]; + const auto pval = pvals[mid_idx]; + + if (rel_sh < rel_sh_ref || + (rel_sh == rel_sh_ref && pval >= pval_ref)) { + left = mid + 1; + } else { + right = mid; } } + } - int delta = delta = - gid - left - (prefix_sums[gid] - prefix_sums[left]); - - if (!is_updated[sorted_ids[left]]) { - delta++; - } + // BUGFIX: remove duplicate assignment ("delta = delta = ...") + int delta = static_cast( + gid - left - (prefix_sums[gid] - prefix_sums[left])); - atomicAdd(&shifted_idx, -delta); + if (!is_updated[sorted_ids[left]]) { + delta += 1; } + + atomicAdd(&shifted_idx, -delta); } - for (int i = ini_idx; i < fin_idx; i++) { + // ---- all lanes: single-pass over [ini_idx, fin_idx) to (a) count + // left-updates, (b) find offset + int local_delta = 0; + int local_offset = -1; - auto id = updated_tracks[i]; + for (int i = ini_idx; i < fin_idx; ++i) { + const auto id = updated_tracks[i]; - if (inverted_ids[id] < gid) { - atomicAdd(&shifted_idx, -1); + // how many updated tracks originally to the left of gid + if (inverted_ids[id] < static_cast(gid)) { + local_delta -= 1; } - } - int offset = 0; - for (int i = ini_idx; i < fin_idx; i++) { - if (updated_tracks[i] == tid) { - offset = i; - break; + // find the position of my tid in updated_tracks + if (local_offset < 0 && id == tid) { + local_offset = + i; // if i == 0, adding 0 is a no-op (original logic) } } - if (offset != 0) { - atomicAdd(&shifted_idx, offset); - } - } else { - for (int i = ini_idx; i < fin_idx; i++) { - - auto id = updated_tracks[i]; - auto rel_sh = rel_shared[id]; - auto pval = pvals[id]; + if (local_delta != 0) { + atomicAdd(&shifted_idx, local_delta); + } + if (local_offset > 0) { + atomicAdd(&shifted_idx, local_offset); + } - if (inverted_ids[id] > gid) { - if (rel_sh < rel_sh_ref) { - atomicAdd(&shifted_idx, 1); - } else if (rel_sh == rel_sh_ref && pval > pval_ref) { - atomicAdd(&shifted_idx, 1); + } else { + // tid is NOT updated: count how many updated tracks should move to + // the right of me + int local_delta = 0; + + for (int i = ini_idx; i < fin_idx; ++i) { + const auto id = updated_tracks[i]; + if (inverted_ids[id] > static_cast(gid)) { + const auto rel_sh = rel_shared[id]; + const auto pval = pvals[id]; + + if (rel_sh < rel_sh_ref || + (rel_sh == rel_sh_ref && pval > pval_ref)) { + local_delta += 1; } } } + + if (local_delta != 0) { + atomicAdd(&shifted_idx, local_delta); + } } } __syncthreads(); - if (is_valid_thread && (threadIdx.x % nThreads_per_track) == 0) { + if (is_valid_thread && leader) { temp_sorted_ids.at(shifted_idx) = tid; } }