Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 99 additions & 91 deletions device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,22 @@ __launch_bounds__(1024) __global__
return;
}

auto gid = threadIdx.x / nThreads_per_track +
blockIdx.x * (blockDim.x / nThreads_per_track);
const unsigned int n_accepted = *(payload.n_accepted);
// group (track) index in this block
const int lane = threadIdx.x % nThreads_per_track;
const int group = threadIdx.x / nThreads_per_track;
const bool leader = (lane == 0);

auto N = *(payload.n_updated_tracks);
auto gid = group + blockIdx.x * (blockDim.x / nThreads_per_track);
const unsigned int n_accepted = *(payload.n_accepted);
const int N = *(payload.n_updated_tracks);

int neff_threads = (N + nThreads_per_track - 1) / nThreads_per_track;

if (neff_threads > nThreads_per_track) {
neff_threads = nThreads_per_track;
}

bool is_valid_thread = true;
if (threadIdx.x % nThreads_per_track >= neff_threads || gid >= n_accepted) {
if (lane >= neff_threads || gid >= static_cast<int>(n_accepted)) {
is_valid_thread = false;
}

Expand All @@ -80,136 +82,142 @@ __launch_bounds__(1024) __global__
payload.temp_sorted_ids_view);

__shared__ int shifted_indices[1024];
auto& shifted_idx = shifted_indices[threadIdx.x / nThreads_per_track];
auto& shifted_idx = shifted_indices[group];

unsigned int tid = std::numeric_limits<unsigned int>::max();

if (is_valid_thread) {

tid = sorted_ids[gid];
auto rel_sh_ref = rel_shared[tid];
auto pval_ref = pvals[tid];
const auto rel_sh_ref = rel_shared[tid];
const auto pval_ref = pvals[tid];

// initialize once by any lane (all lanes see same reference)
shifted_idx = static_cast<int>(gid);

int stride = (N + neff_threads - 1) / neff_threads;

int ini_idx = stride * (threadIdx.x % nThreads_per_track);
int fin_idx = std::min(ini_idx + stride, static_cast<int>(N));
// work partition
const int stride = (N + neff_threads - 1) / neff_threads;
const int ini_idx = stride * lane;
const int fin_idx = min(ini_idx + stride, N);

if (is_updated[tid]) {

if (gid > 0) {
// ---- group leader: compute base left-shift via binary search over
// valid (non-updated) slots
if (gid > 0 && leader) {

unsigned int left = 0;
unsigned int right = gid;

bool first_iteration = true;

if (threadIdx.x % nThreads_per_track == 0) {

while (right > left) {
while (right > left) {

const bool find_left = find_valid_index(
left, 0, gid, sorted_ids, is_updated);
const bool find_left =
find_valid_index(left, 0, gid, sorted_ids, is_updated);
if (!find_left)
break;

if (!find_left) {
break;
}
const bool find_right =
find_valid_index(right, 0, gid, sorted_ids, is_updated);
if (!find_right)
break;

const bool find_right = find_valid_index(
right, 0, gid, sorted_ids, is_updated);
if (first_iteration) {
const auto right_idx = sorted_ids[right];
const auto rel_sh = rel_shared[right_idx];
const auto pval = pvals[right_idx];

if (!find_right) {
if (rel_sh < rel_sh_ref ||
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
left = gid;
break;
}

if (first_iteration) {
const auto right_idx = sorted_ids[right];
auto rel_sh = rel_shared[right_idx];
auto pval = pvals[right_idx];

if (rel_sh < rel_sh_ref ||
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
left = gid;
break;
}
}

first_iteration = false;

unsigned int mid = left + (right - left) / 2;

const bool find_mid = find_valid_index(
mid, left, right - 1, sorted_ids, is_updated);

if (find_mid) {

const auto mid_idx = sorted_ids[mid];
auto rel_sh = rel_shared[mid_idx];
auto pval = pvals[mid_idx];

if (rel_sh < rel_sh_ref ||
(rel_sh == rel_sh_ref && pval >= pval_ref)) {

left = mid + 1;
} else {
right = mid;
}
}
first_iteration = false;

unsigned int mid = left + (right - left) / 2;
const bool find_mid = find_valid_index(
mid, left, right - 1, sorted_ids, is_updated);

if (find_mid) {
const auto mid_idx = sorted_ids[mid];
const auto rel_sh = rel_shared[mid_idx];
const auto pval = pvals[mid_idx];

if (rel_sh < rel_sh_ref ||
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
left = mid + 1;
} else {
right = mid;
}
}
}

int delta = delta =
gid - left - (prefix_sums[gid] - prefix_sums[left]);

if (!is_updated[sorted_ids[left]]) {
delta++;
}
// BUGFIX: remove duplicate assignment ("delta = delta = ...")
int delta = static_cast<int>(
gid - left - (prefix_sums[gid] - prefix_sums[left]));

atomicAdd(&shifted_idx, -delta);
if (!is_updated[sorted_ids[left]]) {
delta += 1;
}

atomicAdd(&shifted_idx, -delta);
}

for (int i = ini_idx; i < fin_idx; i++) {
// ---- all lanes: single-pass over [ini_idx, fin_idx) to (a) count
// left-updates, (b) find offset
int local_delta = 0;
int local_offset = -1;

auto id = updated_tracks[i];
for (int i = ini_idx; i < fin_idx; ++i) {
const auto id = updated_tracks[i];

if (inverted_ids[id] < gid) {
atomicAdd(&shifted_idx, -1);
// how many updated tracks originally to the left of gid
if (inverted_ids[id] < static_cast<unsigned int>(gid)) {
local_delta -= 1;
}
}

int offset = 0;
for (int i = ini_idx; i < fin_idx; i++) {
if (updated_tracks[i] == tid) {
offset = i;
break;
// find the position of my tid in updated_tracks
if (local_offset < 0 && id == tid) {
local_offset =
i; // if i == 0, adding 0 is a no-op (original logic)
}
}
if (offset != 0) {
atomicAdd(&shifted_idx, offset);
}
} else {

for (int i = ini_idx; i < fin_idx; i++) {

auto id = updated_tracks[i];
auto rel_sh = rel_shared[id];
auto pval = pvals[id];
if (local_delta != 0) {
atomicAdd(&shifted_idx, local_delta);
}
if (local_offset > 0) {
atomicAdd(&shifted_idx, local_offset);
}

if (inverted_ids[id] > gid) {
if (rel_sh < rel_sh_ref) {
atomicAdd(&shifted_idx, 1);
} else if (rel_sh == rel_sh_ref && pval > pval_ref) {
atomicAdd(&shifted_idx, 1);
} else {
// tid is NOT updated: count how many updated tracks should move to
// the right of me
int local_delta = 0;

for (int i = ini_idx; i < fin_idx; ++i) {
const auto id = updated_tracks[i];
if (inverted_ids[id] > static_cast<unsigned int>(gid)) {
const auto rel_sh = rel_shared[id];
const auto pval = pvals[id];

if (rel_sh < rel_sh_ref ||
(rel_sh == rel_sh_ref && pval > pval_ref)) {
local_delta += 1;
}
}
}

if (local_delta != 0) {
atomicAdd(&shifted_idx, local_delta);
}
}
}

__syncthreads();

if (is_valid_thread && (threadIdx.x % nThreads_per_track) == 0) {
if (is_valid_thread && leader) {
temp_sorted_ids.at(shifted_idx) = tid;
}
}
Expand Down
Loading