@@ -38,6 +38,9 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
3838
3939 auto threadIndex = threadIdx .x ;
4040
41+ bool is_valid_thread = false ;
42+ bool is_duplicate = true ;
43+
4144 shared_tids[threadIndex] = std::numeric_limits<unsigned int >::max ();
4245
4346 vecmem::device_vector<const unsigned int > sorted_ids (
@@ -62,112 +65,128 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
6265 payload.meas_to_remove_view );
6366 vecmem::device_vector<unsigned int > threads (payload.threads_view );
6467
65- auto n_accepted_prev = (*payload.n_accepted );
68+ const unsigned n_accepted_prev = *(payload.n_accepted );
69+
70+ __syncthreads ();
71+
6672 if (threadIndex == 0 ) {
6773 (*payload.n_accepted ) -= *(payload.n_removable_tracks );
6874 }
6975
7076 if (threadIndex < *(payload.n_meas_to_remove )) {
7177 sh_meas_ids[threadIndex] = meas_to_remove[threadIndex];
7278 sh_threads[threadIndex] = threads[threadIndex];
73- } else {
74- return ;
79+ is_valid_thread = true ;
7580 }
7681
77- const auto id = sh_meas_ids[threadIndex];
82+ __syncthreads ();
83+
84+ if (is_valid_thread) {
85+
86+ const auto id = sh_meas_ids[threadIndex];
87+ is_duplicate = false ;
7888
79- bool is_duplicate = false ;
80- for ( unsigned int i = 0 ; i < threadIndex; ++i ) {
81- if (sh_meas_ids[i] == id) {
82- is_duplicate = true ;
83- break ;
89+ for ( unsigned int i = 0 ; i < threadIndex; ++i) {
90+ if (sh_meas_ids[i] == id ) {
91+ is_duplicate = true ;
92+ break ;
93+ }
8494 }
8595 }
86- if (is_duplicate) {
87- return ;
88- }
8996
90- const auto unique_meas_idx = meas_id_to_unique_id. at (id) ;
97+ bool active = false ;
9198
92- // If there is only one track associated with measurement, the
93- // number of shared measurement can be reduced by one
94- const auto & tracks = tracks_per_measurement[unique_meas_idx];
95- auto track_status = track_status_per_measurement[unique_meas_idx];
99+ if (!is_duplicate && is_valid_thread) {
96100
97- auto trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[threadIndex]];
101+ const auto id = sh_meas_ids[threadIndex];
102+ const auto unique_meas_idx = meas_id_to_unique_id.at (id);
98103
99- unsigned int worst_idx =
100- thrust::find (thrust::seq, tracks.begin (), tracks.end (), trk_id) -
101- tracks.begin ();
104+ // If there is only one track associated with measurement, the
105+ // number of shared measurement can be reduced by one
106+ const auto & tracks = tracks_per_measurement[unique_meas_idx];
107+ auto track_status = track_status_per_measurement[unique_meas_idx];
102108
103- track_status[worst_idx] = 0 ;
109+ auto trk_id =
110+ sorted_ids.at (n_accepted_prev - 1 - sh_threads[threadIndex]);
104111
105- int n_sharing_tracks = 1 ;
106- for ( unsigned int i = threadIndex + 1 ; i < *(payload. n_meas_to_remove );
107- ++i) {
112+ unsigned int worst_idx =
113+ thrust::find (thrust::seq, tracks. begin (), tracks. end (), trk_id) -
114+ tracks. begin ();
108115
109- if (sh_meas_ids[i] == id && sh_threads[i] != sh_threads[i - 1 ]) {
110- n_sharing_tracks++;
116+ track_status[worst_idx] = 0 ;
111117
112- trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[i]];
118+ int n_sharing_tracks = 1 ;
119+ for (unsigned int i = threadIndex + 1 ; i < *(payload.n_meas_to_remove );
120+ ++i) {
113121
114- worst_idx = thrust::find (thrust::seq, tracks.begin (), tracks.end (),
115- trk_id) -
116- tracks.begin ();
122+ if (sh_meas_ids[i] == id && sh_threads[i] != sh_threads[i - 1 ]) {
123+ n_sharing_tracks++;
117124
118- track_status[worst_idx] = 0 ;
125+ trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[i]] ;
119126
120- } else if (sh_meas_ids[i] != id) {
121- break ;
122- }
123- }
127+ worst_idx = thrust::find (thrust::seq, tracks.begin (),
128+ tracks.end (), trk_id) -
129+ tracks.begin ();
124130
125- vecmem::device_atomic_ref<unsigned int > n_accepted_per_meas (
126- n_accepted_tracks_per_measurement.at (
127- static_cast <unsigned int >(unique_meas_idx)));
128- const unsigned int N_A = n_accepted_per_meas.fetch_sub (n_sharing_tracks);
131+ track_status[worst_idx] = 0 ;
129132
130- if (N_A != 1 + n_sharing_tracks) {
131- return ;
132- }
133+ } else if (sh_meas_ids[i] != id) {
134+ break ;
135+ }
136+ }
133137
134- const unsigned int alive_idx =
135- thrust::find (thrust::seq, track_status.begin (), track_status.end (), 1 ) -
136- track_status.begin ();
138+ vecmem::device_atomic_ref<unsigned int > n_accepted_per_meas (
139+ n_accepted_tracks_per_measurement.at (
140+ static_cast <unsigned int >(unique_meas_idx)));
141+ const unsigned int N_A =
142+ n_accepted_per_meas.fetch_sub (n_sharing_tracks);
137143
138- shared_tids[threadIndex] = static_cast <unsigned int >(tracks[alive_idx]);
144+ if (N_A == 1 + n_sharing_tracks) {
145+ active = true ;
146+ const unsigned int alive_idx =
147+ thrust::find (thrust::seq, track_status.begin (),
148+ track_status.end (), 1 ) -
149+ track_status.begin ();
139150
140- auto tid = shared_tids[threadIndex];
151+ shared_tids[threadIndex] =
152+ static_cast <unsigned int >(tracks[alive_idx]);
141153
142- const auto m_count = static_cast <unsigned int >(thrust::count (
143- thrust::seq, meas_ids[tid].begin (), meas_ids[tid].end (), id));
154+ auto tid = shared_tids[threadIndex];
144155
145- const unsigned int N_S =
146- vecmem::device_atomic_ref<unsigned int >(n_shared.at (tid))
147- .fetch_sub (m_count);
156+ const auto m_count = static_cast <unsigned int >(thrust::count (
157+ thrust::seq, meas_ids[tid].begin (), meas_ids[tid].end (), id));
158+
159+ const unsigned int N_S =
160+ vecmem::device_atomic_ref<unsigned int >(n_shared.at (tid))
161+ .fetch_sub (m_count);
162+ }
163+ }
148164
149165 __syncthreads ();
150166
151- bool already_pushed = false ;
152- for (unsigned int i = 0 ; i < threadIndex; ++i) {
153- if (shared_tids[i] == tid) {
154- already_pushed = true ;
155- break ;
167+ if (active) {
168+ auto tid = shared_tids[threadIndex];
169+ bool already_pushed = false ;
170+ for (unsigned int i = 0 ; i < threadIndex; ++i) {
171+ if (shared_tids[i] == tid) {
172+ already_pushed = true ;
173+ break ;
174+ }
156175 }
157- }
158- if (!already_pushed) {
176+ if (!already_pushed) {
159177
160- // Write updated track IDs
161- vecmem::device_atomic_ref<unsigned int > num_updated_tracks (
162- *(payload.n_updated_tracks ));
178+ // Write updated track IDs
179+ vecmem::device_atomic_ref<unsigned int > num_updated_tracks (
180+ *(payload.n_updated_tracks ));
163181
164- const unsigned int pos = num_updated_tracks.fetch_add (1 );
182+ const unsigned int pos = num_updated_tracks.fetch_add (1 );
165183
166- updated_tracks[pos] = tid;
167- is_updated[tid] = 1 ;
184+ updated_tracks[pos] = tid;
185+ is_updated[tid] = 1 ;
168186
169- rel_shared.at (tid) = static_cast <traccc::scalar>(n_shared.at (tid)) /
170- static_cast <traccc::scalar>(n_meas.at (tid));
187+ rel_shared.at (tid) = static_cast <traccc::scalar>(n_shared.at (tid)) /
188+ static_cast <traccc::scalar>(n_meas.at (tid));
189+ }
171190 }
172191}
173192
0 commit comments