Skip to content

Commit ca2cfcc

Browse files
Fix bugs in the greedy resolver (#1083)
* Fix several bugs in the greedy resolver, including a wrong size allocation, a deadlock, and unsafe usage of a local variable * Change the test name to register in CI run * Add CUDA keywords to the test suites * Rename the variables and use local memory for n_accepted_prev * Use at instead of [] operator --------- Co-authored-by: Stephen Nicholas Swatman <[email protected]>
1 parent a3b4e0c commit ca2cfcc

File tree

4 files changed

+114
-94
lines changed

4 files changed

+114
-94
lines changed

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ greedy_ambiguity_resolution_algorithm::operator()(
233233

234234
// Unique measurement ids
235235
vecmem::data::vector_buffer<measurement_id_type>
236-
meas_id_to_unique_id_buffer{max_meas.measurement_id, m_mr.main};
236+
meas_id_to_unique_id_buffer{max_meas.measurement_id + 1, m_mr.main};
237237

238238
// Make meas_id to meas vector
239239
{

device/cuda/src/ambiguity_resolution/kernels/count_removable_tracks.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ __device__ void count_tracks(int tid, int* sh_n_meas, int n_tracks,
3636
offset = sh_n_meas[count];
3737
add = stride * 2;
3838
}
39-
__syncthreads();
4039
}
40+
41+
__syncthreads();
4142
}
4243

4344
if (tid == 0) {

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 87 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
3838

3939
auto threadIndex = threadIdx.x;
4040

41+
bool is_valid_thread = false;
42+
bool is_duplicate = true;
43+
4144
shared_tids[threadIndex] = std::numeric_limits<unsigned int>::max();
4245

4346
vecmem::device_vector<const unsigned int> sorted_ids(
@@ -62,112 +65,128 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
6265
payload.meas_to_remove_view);
6366
vecmem::device_vector<unsigned int> threads(payload.threads_view);
6467

65-
auto n_accepted_prev = (*payload.n_accepted);
68+
const unsigned n_accepted_prev = *(payload.n_accepted);
69+
70+
__syncthreads();
71+
6672
if (threadIndex == 0) {
6773
(*payload.n_accepted) -= *(payload.n_removable_tracks);
6874
}
6975

7076
if (threadIndex < *(payload.n_meas_to_remove)) {
7177
sh_meas_ids[threadIndex] = meas_to_remove[threadIndex];
7278
sh_threads[threadIndex] = threads[threadIndex];
73-
} else {
74-
return;
79+
is_valid_thread = true;
7580
}
7681

77-
const auto id = sh_meas_ids[threadIndex];
82+
__syncthreads();
83+
84+
if (is_valid_thread) {
85+
86+
const auto id = sh_meas_ids[threadIndex];
87+
is_duplicate = false;
7888

79-
bool is_duplicate = false;
80-
for (unsigned int i = 0; i < threadIndex; ++i) {
81-
if (sh_meas_ids[i] == id) {
82-
is_duplicate = true;
83-
break;
89+
for (unsigned int i = 0; i < threadIndex; ++i) {
90+
if (sh_meas_ids[i] == id) {
91+
is_duplicate = true;
92+
break;
93+
}
8494
}
8595
}
86-
if (is_duplicate) {
87-
return;
88-
}
8996

90-
const auto unique_meas_idx = meas_id_to_unique_id.at(id);
97+
bool active = false;
9198

92-
// If there is only one track associated with measurement, the
93-
// number of shared measurement can be reduced by one
94-
const auto& tracks = tracks_per_measurement[unique_meas_idx];
95-
auto track_status = track_status_per_measurement[unique_meas_idx];
99+
if (!is_duplicate && is_valid_thread) {
96100

97-
auto trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[threadIndex]];
101+
const auto id = sh_meas_ids[threadIndex];
102+
const auto unique_meas_idx = meas_id_to_unique_id.at(id);
98103

99-
unsigned int worst_idx =
100-
thrust::find(thrust::seq, tracks.begin(), tracks.end(), trk_id) -
101-
tracks.begin();
104+
// If there is only one track associated with measurement, the
105+
// number of shared measurement can be reduced by one
106+
const auto& tracks = tracks_per_measurement[unique_meas_idx];
107+
auto track_status = track_status_per_measurement[unique_meas_idx];
102108

103-
track_status[worst_idx] = 0;
109+
auto trk_id =
110+
sorted_ids.at(n_accepted_prev - 1 - sh_threads[threadIndex]);
104111

105-
int n_sharing_tracks = 1;
106-
for (unsigned int i = threadIndex + 1; i < *(payload.n_meas_to_remove);
107-
++i) {
112+
unsigned int worst_idx =
113+
thrust::find(thrust::seq, tracks.begin(), tracks.end(), trk_id) -
114+
tracks.begin();
108115

109-
if (sh_meas_ids[i] == id && sh_threads[i] != sh_threads[i - 1]) {
110-
n_sharing_tracks++;
116+
track_status[worst_idx] = 0;
111117

112-
trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[i]];
118+
int n_sharing_tracks = 1;
119+
for (unsigned int i = threadIndex + 1; i < *(payload.n_meas_to_remove);
120+
++i) {
113121

114-
worst_idx = thrust::find(thrust::seq, tracks.begin(), tracks.end(),
115-
trk_id) -
116-
tracks.begin();
122+
if (sh_meas_ids[i] == id && sh_threads[i] != sh_threads[i - 1]) {
123+
n_sharing_tracks++;
117124

118-
track_status[worst_idx] = 0;
125+
trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[i]];
119126

120-
} else if (sh_meas_ids[i] != id) {
121-
break;
122-
}
123-
}
127+
worst_idx = thrust::find(thrust::seq, tracks.begin(),
128+
tracks.end(), trk_id) -
129+
tracks.begin();
124130

125-
vecmem::device_atomic_ref<unsigned int> n_accepted_per_meas(
126-
n_accepted_tracks_per_measurement.at(
127-
static_cast<unsigned int>(unique_meas_idx)));
128-
const unsigned int N_A = n_accepted_per_meas.fetch_sub(n_sharing_tracks);
131+
track_status[worst_idx] = 0;
129132

130-
if (N_A != 1 + n_sharing_tracks) {
131-
return;
132-
}
133+
} else if (sh_meas_ids[i] != id) {
134+
break;
135+
}
136+
}
133137

134-
const unsigned int alive_idx =
135-
thrust::find(thrust::seq, track_status.begin(), track_status.end(), 1) -
136-
track_status.begin();
138+
vecmem::device_atomic_ref<unsigned int> n_accepted_per_meas(
139+
n_accepted_tracks_per_measurement.at(
140+
static_cast<unsigned int>(unique_meas_idx)));
141+
const unsigned int N_A =
142+
n_accepted_per_meas.fetch_sub(n_sharing_tracks);
137143

138-
shared_tids[threadIndex] = static_cast<unsigned int>(tracks[alive_idx]);
144+
if (N_A == 1 + n_sharing_tracks) {
145+
active = true;
146+
const unsigned int alive_idx =
147+
thrust::find(thrust::seq, track_status.begin(),
148+
track_status.end(), 1) -
149+
track_status.begin();
139150

140-
auto tid = shared_tids[threadIndex];
151+
shared_tids[threadIndex] =
152+
static_cast<unsigned int>(tracks[alive_idx]);
141153

142-
const auto m_count = static_cast<unsigned int>(thrust::count(
143-
thrust::seq, meas_ids[tid].begin(), meas_ids[tid].end(), id));
154+
auto tid = shared_tids[threadIndex];
144155

145-
const unsigned int N_S =
146-
vecmem::device_atomic_ref<unsigned int>(n_shared.at(tid))
147-
.fetch_sub(m_count);
156+
const auto m_count = static_cast<unsigned int>(thrust::count(
157+
thrust::seq, meas_ids[tid].begin(), meas_ids[tid].end(), id));
158+
159+
const unsigned int N_S =
160+
vecmem::device_atomic_ref<unsigned int>(n_shared.at(tid))
161+
.fetch_sub(m_count);
162+
}
163+
}
148164

149165
__syncthreads();
150166

151-
bool already_pushed = false;
152-
for (unsigned int i = 0; i < threadIndex; ++i) {
153-
if (shared_tids[i] == tid) {
154-
already_pushed = true;
155-
break;
167+
if (active) {
168+
auto tid = shared_tids[threadIndex];
169+
bool already_pushed = false;
170+
for (unsigned int i = 0; i < threadIndex; ++i) {
171+
if (shared_tids[i] == tid) {
172+
already_pushed = true;
173+
break;
174+
}
156175
}
157-
}
158-
if (!already_pushed) {
176+
if (!already_pushed) {
159177

160-
// Write updated track IDs
161-
vecmem::device_atomic_ref<unsigned int> num_updated_tracks(
162-
*(payload.n_updated_tracks));
178+
// Write updated track IDs
179+
vecmem::device_atomic_ref<unsigned int> num_updated_tracks(
180+
*(payload.n_updated_tracks));
163181

164-
const unsigned int pos = num_updated_tracks.fetch_add(1);
182+
const unsigned int pos = num_updated_tracks.fetch_add(1);
165183

166-
updated_tracks[pos] = tid;
167-
is_updated[tid] = 1;
184+
updated_tracks[pos] = tid;
185+
is_updated[tid] = 1;
168186

169-
rel_shared.at(tid) = static_cast<traccc::scalar>(n_shared.at(tid)) /
170-
static_cast<traccc::scalar>(n_meas.at(tid));
187+
rel_shared.at(tid) = static_cast<traccc::scalar>(n_shared.at(tid)) /
188+
static_cast<traccc::scalar>(n_meas.at(tid));
189+
}
171190
}
172191
}
173192

0 commit comments

Comments
 (0)