@@ -111,6 +111,7 @@ __launch_bounds__(512) __global__
111111 vecmem::device_vector<unsigned int > updated_tracks (
112112 payload.updated_tracks_view );
113113 vecmem::device_vector<int > is_updated (payload.is_updated_view );
114+ vecmem::device_vector<int > track_count (payload.track_count_view );
114115
115116 if (threadIndex == 0 ) {
116117 *(payload.n_removable_tracks ) = 0 ;
@@ -342,6 +343,7 @@ __launch_bounds__(512) __global__
342343
343344 bool active = false ;
344345 unsigned int pos1;
346+ int alive_trk_id = 0 ;
345347
346348 if (!is_duplicate && is_valid_thread) {
347349
@@ -397,46 +399,40 @@ __launch_bounds__(512) __global__
397399 track_status.begin ();
398400
399401 pos1 = atomicAdd (&n_updating_threads, 1 );
402+ alive_trk_id = static_cast <int >(tracks[alive_idx]);
400403
401- sh_buffer[pos1] = static_cast <int >(tracks[alive_idx]);
404+ sh_buffer[pos1] = alive_trk_id;
405+ atomicAdd (&track_count[alive_trk_id], 1 );
402406
403- auto tid = sh_buffer[pos1];
407+ const auto m_count = static_cast <unsigned int >(
408+ thrust::count (thrust::seq, meas_ids[alive_trk_id].begin (),
409+ meas_ids[alive_trk_id].end (), id));
404410
405- const auto m_count = static_cast <unsigned int >(thrust::count (
406- thrust::seq, meas_ids[tid].begin (), meas_ids[tid].end (), id));
407-
408- const unsigned int N_S =
409- vecmem::device_atomic_ref<unsigned int >(n_shared.at (tid))
410- .fetch_sub (m_count);
411+ const unsigned int N_S = vecmem::device_atomic_ref<unsigned int >(
412+ n_shared.at (alive_trk_id))
413+ .fetch_sub (m_count);
411414 }
412415 }
413416
414417 __syncthreads ();
415418
416419 if (active) {
417- auto tid = sh_buffer[pos1];
418- bool already_pushed = false ;
419- for (unsigned int i = 0 ; i < pos1; ++i) {
420- if (sh_buffer[i] == tid) {
421- already_pushed = true ;
422- break ;
423- }
424- }
425420
426- if (!already_pushed) {
421+ auto count = atomicAdd (&track_count[alive_trk_id], -1 );
422+ if (count == 1 ) {
427423
428424 // Write updated track IDs
429425 vecmem::device_atomic_ref<unsigned int > num_updated_tracks (
430426 *(payload.n_updated_tracks ));
431427
432428 const unsigned int pos2 = num_updated_tracks.fetch_add (1 );
433429
434- updated_tracks[pos2] = tid ;
435- is_updated[tid ] = 1 ;
430+ updated_tracks[pos2] = alive_trk_id ;
431+ is_updated[alive_trk_id ] = 1 ;
436432
437- rel_shared.at (tid ) =
438- math::div_ieee754 ( static_cast <traccc::scalar>(n_shared.at (tid )),
439- static_cast <traccc::scalar>(n_meas.at (tid )));
433+ rel_shared.at (alive_trk_id ) = math::div_ieee754 (
434+ static_cast <traccc::scalar>(n_shared.at (alive_trk_id )),
435+ static_cast <traccc::scalar>(n_meas.at (alive_trk_id )));
440436 }
441437 }
442438}
0 commit comments