diff --git a/device/common/include/traccc/finding/device/impl/find_tracks.ipp b/device/common/include/traccc/finding/device/impl/find_tracks.ipp index 801ecd8527..4aee587bc1 100644 --- a/device/common/include/traccc/finding/device/impl/find_tracks.ipp +++ b/device/common/include/traccc/finding/device/impl/find_tracks.ipp @@ -572,15 +572,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * measurements. */ if (local_num_params > 0 || in_param_can_create_hole) { - unsigned int desired_params_to_add = std::max(1u, local_num_params); - - vecmem::device_atomic_ref num_tracks_per_seed( - n_tracks_per_seed.at(seed_idx)); - params_to_add = std::min(desired_params_to_add, - cfg.max_num_branches_per_seed - - std::min(cfg.max_num_branches_per_seed, - num_tracks_per_seed.fetch_add( - desired_params_to_add))); + params_to_add = std::max(1u, local_num_params); local_out_offset = vecmem::device_atomic_ref::max()); if (local_num_params == 0) { - assert(params_to_add <= 1); + assert(params_to_add == 1); - if (in_param_can_create_hole && params_to_add == 1) { + if (in_param_can_create_hole) { const unsigned int out_offset = shared_payload.shared_out_offset + local_out_offset; @@ -640,8 +632,15 @@ TRACCC_HOST_DEVICE inline void find_tracks( const unsigned int n_cands = payload.step - n_skipped; if (n_cands >= cfg.min_track_candidates_per_track) { - auto tip_pos = tips.push_back(prev_link_idx); - tip_lengths.at(tip_pos) = n_cands; + vecmem::device_atomic_ref num_tracks_per_seed( + n_tracks_per_seed.at(seed_idx)); + + auto pos = num_tracks_per_seed.fetch_add(1u); + + if (pos < cfg.max_num_branches_per_seed) { + auto tip_pos = tips.push_back(prev_link_idx); + tip_lengths.at(tip_pos) = n_cands; + } } } } else { @@ -664,8 +663,15 @@ TRACCC_HOST_DEVICE inline void find_tracks( if (last_step && n_cands >= cfg.min_track_candidates_per_track) { - auto tip_pos = tips.push_back(param_pos); - tip_lengths.at(tip_pos) = n_cands; + vecmem::device_atomic_ref num_tracks_per_seed( + n_tracks_per_seed.at(seed_idx)); + + auto pos = num_tracks_per_seed.fetch_add(1u); + + if (pos < cfg.max_num_branches_per_seed) { + auto tip_pos = tips.push_back(param_pos); + tip_lengths.at(tip_pos) = n_cands; + } } } } diff --git a/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp b/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp index a23e555886..f788de91fb 100644 --- a/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp +++ b/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp @@ -29,6 +29,8 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface( // Theta id vecmem::device_vector param_ids(payload.param_ids_view); + vecmem::device_vector n_tracks_per_seed( + payload.n_tracks_per_seed_view); const unsigned int param_id = param_ids.at(globalIndex); @@ -58,6 +60,11 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface( return; } + if (n_tracks_per_seed.at(link.seed_idx) >= cfg.max_num_branches_per_seed) { + params_liveness.at(param_id) = 0; + return; + } + // Input bound track parameter const bound_track_parameters<> in_par = params.at(param_id); @@ -108,8 +115,15 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface( params_liveness[param_id] = 0u; if (n_cands >= cfg.min_track_candidates_per_track) { - auto tip_pos = tips.push_back(link_idx); - tip_lengths.at(tip_pos) = n_cands; + vecmem::device_atomic_ref num_tracks_per_seed( + n_tracks_per_seed.at(link.seed_idx)); + + unsigned int pos = num_tracks_per_seed.fetch_add(1u); + + if (pos < cfg.max_num_branches_per_seed) { + auto tip_pos = tips.push_back(link_idx); + tip_lengths.at(tip_pos) = n_cands; + } } } } diff --git a/device/common/include/traccc/finding/device/propagate_to_next_surface.hpp b/device/common/include/traccc/finding/device/propagate_to_next_surface.hpp index fe827101d3..4d4ff7ccf0 100644 --- a/device/common/include/traccc/finding/device/propagate_to_next_surface.hpp +++ b/device/common/include/traccc/finding/device/propagate_to_next_surface.hpp @@ -79,6 +79,12 @@ struct propagate_to_next_surface_payload { * @brief Vector to hold the number of track states per tip */ vecmem::data::vector_view tip_lengths_view; + + /** + * @brief View object to the vector of the number of tracks per initial + * input seed + */ + vecmem::data::vector_view n_tracks_per_seed_view; }; /// Function for propagating the kalman-updated tracks to the next surface diff --git a/device/cuda/src/finding/combinatorial_kalman_filter.cuh b/device/cuda/src/finding/combinatorial_kalman_filter.cuh index 8aebbe9cfc..96db866208 100644 --- a/device/cuda/src/finding/combinatorial_kalman_filter.cuh +++ b/device/cuda/src/finding/combinatorial_kalman_filter.cuh @@ -166,6 +166,7 @@ combinatorial_kalman_filter( vecmem::data::vector_buffer n_tracks_per_seed_buffer(n_seeds, mr.main); copy.setup(n_tracks_per_seed_buffer)->ignore(); + copy.memset(n_tracks_per_seed_buffer, 0)->ignore(); // Create a buffer for links unsigned int link_buffer_capacity = config.initial_links_per_seed * n_seeds; @@ -228,9 +229,6 @@ combinatorial_kalman_filter( n_max_candidates, mr.main); copy.setup(updated_liveness_buffer)->ignore(); - // Reset the number of tracks per seed - copy.memset(n_tracks_per_seed_buffer, 0)->ignore(); - const unsigned int links_size = copy.get_size(links_buffer); if (links_size + n_max_candidates > link_buffer_capacity) { @@ -434,7 +432,8 @@ combinatorial_kalman_filter( .step = step, .n_in_params = n_candidates, .tips_view = tips_buffer, - .tip_lengths_view = tip_length_buffer}; + .tip_lengths_view = tip_length_buffer, + .n_tracks_per_seed_view = n_tracks_per_seed_buffer}; const unsigned int nThreads = warp_size * 4; const unsigned int nBlocks = diff --git a/device/sycl/src/finding/combinatorial_kalman_filter.hpp b/device/sycl/src/finding/combinatorial_kalman_filter.hpp index ad5eb24cda..5958a74443 100644 --- a/device/sycl/src/finding/combinatorial_kalman_filter.hpp +++ b/device/sycl/src/finding/combinatorial_kalman_filter.hpp @@ -175,6 +175,7 @@ combinatorial_kalman_filter( vecmem::data::vector_buffer n_tracks_per_seed_buffer(n_seeds, mr.main); copy.setup(n_tracks_per_seed_buffer)->wait(); + copy.memset(n_tracks_per_seed_buffer, 0)->wait(); // Create a buffer for links unsigned int link_buffer_capacity = config.initial_links_per_seed * n_seeds; @@ -237,9 +238,6 @@ combinatorial_kalman_filter( n_max_candidates, mr.main); copy.setup(updated_liveness_buffer)->wait(); - // Reset the number of tracks per seed - copy.memset(n_tracks_per_seed_buffer, 0)->wait(); - const unsigned int links_size = copy.get_size(links_buffer); if (links_size + n_max_candidates > link_buffer_capacity) {