diff --git a/examples/run/alpaka/full_chain_algorithm.cpp b/examples/run/alpaka/full_chain_algorithm.cpp index 22c505daae..42eb4a9f08 100644 --- a/examples/run/alpaka/full_chain_algorithm.cpp +++ b/examples/run/alpaka/full_chain_algorithm.cpp @@ -21,6 +21,7 @@ full_chain_algorithm::full_chain_algorithm( const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, const finding_algorithm::config_type& finding_config, + [[maybe_unused]] const ambiguity_resolution_config& resolution_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, const magnetic_field& field, host_detector* detector, diff --git a/examples/run/alpaka/full_chain_algorithm.hpp b/examples/run/alpaka/full_chain_algorithm.hpp index 51668913a2..1f35a72e10 100644 --- a/examples/run/alpaka/full_chain_algorithm.hpp +++ b/examples/run/alpaka/full_chain_algorithm.hpp @@ -17,6 +17,7 @@ #include "traccc/alpaka/seeding/track_params_estimation.hpp" #include "traccc/alpaka/utils/get_device_info.hpp" #include "traccc/alpaka/utils/vecmem_objects.hpp" +#include "traccc/ambiguity_resolution/ambiguity_resolution_config.hpp" #include "traccc/bfield/magnetic_field.hpp" #include "traccc/clusterization/clustering_config.hpp" #include "traccc/edm/silicon_cell_collection.hpp" @@ -78,6 +79,7 @@ class full_chain_algorithm const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, const finding_algorithm::config_type& finding_config, + const ambiguity_resolution_config& resolution_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, const magnetic_field& field, host_detector* detector, diff --git a/examples/run/common/throughput_mt.ipp b/examples/run/common/throughput_mt.ipp index 19d3fa7829..9696c25d7d 100644 --- a/examples/run/common/throughput_mt.ipp +++ b/examples/run/common/throughput_mt.ipp @@ -25,6 +25,7 @@ #include "traccc/options/track_finding.hpp" #include "traccc/options/track_fitting.hpp" #include "traccc/options/track_propagation.hpp" +#include "traccc/options/track_resolution.hpp" #include "traccc/options/track_seeding.hpp" // I/O include(s). @@ -77,6 +78,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) { opts::track_seeding seeding_opts; opts::track_finding finding_opts; opts::track_propagation propagation_opts; + opts::track_resolution resolution_opts; opts::track_fitting fitting_opts; opts::throughput throughput_opts; opts::threading threading_opts; @@ -147,6 +149,8 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) { finding_opts); finding_cfg.propagation = propagation_config; + ambiguity_resolution_config resolution_cfg(resolution_opts); + typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg( fitting_opts); fitting_cfg.propagation = propagation_config; @@ -157,8 +161,8 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) { for (std::size_t i = 0; i < threading_opts.threads + 1; ++i) { algs.push_back({host_mr, clustering_cfg, seedfinder_config, spacepoint_grid_config, seedfilter_config, finding_cfg, - fitting_cfg, det_descr, field, &detector, - logger().clone()}); + resolution_cfg, fitting_cfg, det_descr, field, + &detector, logger().clone()}); } // Set up a lambda that calls the correct function on the algorithms. diff --git a/examples/run/common/throughput_st.ipp b/examples/run/common/throughput_st.ipp index 67ab2b7367..d840054422 100644 --- a/examples/run/common/throughput_st.ipp +++ b/examples/run/common/throughput_st.ipp @@ -24,6 +24,7 @@ #include "traccc/options/track_finding.hpp" #include "traccc/options/track_fitting.hpp" #include "traccc/options/track_propagation.hpp" +#include "traccc/options/track_resolution.hpp" #include "traccc/options/track_seeding.hpp" // I/O include(s). @@ -66,13 +67,14 @@ int throughput_st(std::string_view description, int argc, char* argv[]) { opts::track_seeding seeding_opts; opts::track_finding finding_opts; opts::track_propagation propagation_opts; + opts::track_resolution resolution_opts; opts::track_fitting fitting_opts; opts::throughput throughput_opts; opts::program_options program_opts{ description, {detector_opts, bfield_opts, input_opts, clusterization_opts, - seeding_opts, finding_opts, propagation_opts, fitting_opts, - throughput_opts}, + seeding_opts, finding_opts, propagation_opts, resolution_opts, + fitting_opts, throughput_opts}, argc, argv, logger->cloneWithSuffix("Options")}; @@ -128,6 +130,8 @@ int throughput_st(std::string_view description, int argc, char* argv[]) { finding_opts); finding_cfg.propagation = propagation_config; + ambiguity_resolution_config resolution_cfg(resolution_opts); + typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg( fitting_opts); fitting_cfg.propagation = propagation_config; @@ -135,8 +139,8 @@ int throughput_st(std::string_view description, int argc, char* argv[]) { // Set up the full-chain algorithm. std::unique_ptr alg = std::make_unique( host_mr, clustering_cfg, seedfinder_config, spacepoint_grid_config, - seedfilter_config, finding_cfg, fitting_cfg, det_descr, field, - &detector, logger->clone("FullChainAlg")); + seedfilter_config, finding_cfg, resolution_cfg, fitting_cfg, det_descr, + field, &detector, logger->clone("FullChainAlg")); // Seed the random number generator. if (throughput_opts.random_seed == 0) { diff --git a/examples/run/cpu/full_chain_algorithm.cpp b/examples/run/cpu/full_chain_algorithm.cpp index dcd99c05d5..a9f68f2a43 100644 --- a/examples/run/cpu/full_chain_algorithm.cpp +++ b/examples/run/cpu/full_chain_algorithm.cpp @@ -16,6 +16,7 @@ full_chain_algorithm::full_chain_algorithm( const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, const finding_algorithm::config_type& finding_config, + const ambiguity_solving_algorithm::config_type& resolution_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, const magnetic_field& field, const host_detector* detector, @@ -34,6 +35,8 @@ full_chain_algorithm::full_chain_algorithm( m_track_parameter_estimation(mr, logger->cloneWithSuffix("TrackParamEstAlg")), m_finding(finding_config, mr, logger->cloneWithSuffix("TrackFindingAlg")), + m_ambiguity_solving(resolution_config, mr, + logger->cloneWithSuffix("AmbiguityResolutionAlg")), m_fitting(fitting_config, mr, *m_copy, logger->cloneWithSuffix("TrackFittingAlg")), m_finder_config(finder_config), @@ -78,10 +81,14 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()( const finding_algorithm::output_type track_candidates = m_finding( *m_detector, m_field, measurements_view, track_params_view); + const ambiguity_solving_algorithm::output_type + resolved_track_candidates = m_ambiguity_solving( + {vecmem::get_data(track_candidates), measurements_view}); + // Run the track fitting, and return its results. - return m_fitting( - *m_detector, m_field, - {vecmem::get_data(track_candidates), measurements_view}) + return m_fitting(*m_detector, m_field, + {vecmem::get_data(resolved_track_candidates), + measurements_view}) .tracks; } // If not, just return an empty object. diff --git a/examples/run/cpu/full_chain_algorithm.hpp b/examples/run/cpu/full_chain_algorithm.hpp index a34bac892b..2be512fa90 100644 --- a/examples/run/cpu/full_chain_algorithm.hpp +++ b/examples/run/cpu/full_chain_algorithm.hpp @@ -8,6 +8,7 @@ #pragma once // Project include(s). +#include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp" #include "traccc/bfield/magnetic_field.hpp" #include "traccc/clusterization/clusterization_algorithm.hpp" #include "traccc/edm/silicon_cell_collection.hpp" @@ -56,6 +57,9 @@ class full_chain_algorithm /// Track finding algorithm type using finding_algorithm = traccc::host::combinatorial_kalman_filter_algorithm; + /// Ambiguity solving algorithm type + using ambiguity_solving_algorithm = + traccc::host::greedy_ambiguity_resolution_algorithm; /// Track fitting algorithm type using fitting_algorithm = traccc::host::kalman_fitting_algorithm; @@ -68,17 +72,19 @@ class full_chain_algorithm /// @param dummy This is not used anywhere. Allows templating CPU/Device /// algorithm. /// - full_chain_algorithm(vecmem::memory_resource& mr, - const clustering_algorithm::config_type& dummy, - const seedfinder_config& finder_config, - const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config, - const finding_algorithm::config_type& finding_config, - const fitting_algorithm::config_type& fitting_config, - const silicon_detector_description::host& det_descr, - const magnetic_field& field, - const host_detector* detector, - std::unique_ptr logger); + + full_chain_algorithm( + vecmem::memory_resource& mr, + const clustering_algorithm::config_type& dummy, + const seedfinder_config& finder_config, + const spacepoint_grid_config& grid_config, + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config, + const ambiguity_solving_algorithm::config_type& resolution_config, + const fitting_algorithm::config_type& fitting_config, + const silicon_detector_description::host& det_descr, + const magnetic_field& field, const host_detector* detector, + std::unique_ptr logger); /// Reconstruct track parameters in the entire detector /// @@ -126,6 +132,8 @@ class full_chain_algorithm /// Track finding algorithm finding_algorithm m_finding; + /// Ambiguity solving algorithm + ambiguity_solving_algorithm m_ambiguity_solving; /// Track fitting algorithm fitting_algorithm m_fitting; @@ -143,6 +151,8 @@ class full_chain_algorithm /// Configuration for the track finding finding_algorithm::config_type m_finding_config; + /// Configuration for the ambiguity solving + ambiguity_solving_algorithm::config_type m_resolution_config; /// Configuration for the track fitting fitting_algorithm::config_type m_fitting_config; diff --git a/examples/run/cuda/full_chain_algorithm.cpp b/examples/run/cuda/full_chain_algorithm.cpp index 76b81bffa3..ecd3d1fce3 100644 --- a/examples/run/cuda/full_chain_algorithm.cpp +++ b/examples/run/cuda/full_chain_algorithm.cpp @@ -37,6 +37,7 @@ full_chain_algorithm::full_chain_algorithm( const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, const finding_algorithm::config_type& finding_config, + const ambiguity_solving_algorithm::config_type& resolution_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, const magnetic_field& field, host_detector* detector, @@ -73,6 +74,9 @@ full_chain_algorithm::full_chain_algorithm( logger->cloneWithSuffix("TrackParEstAlg")), m_finding(finding_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, logger->cloneWithSuffix("TrackFindingAlg")), + m_ambiguity_solving( + resolution_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, + m_copy, m_stream, logger->cloneWithSuffix("AmbiguityResolutionAlg")), m_fitting(fitting_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, logger->cloneWithSuffix("TrackFittingAlg")), m_clustering_config(clustering_config), @@ -80,6 +84,7 @@ full_chain_algorithm::full_chain_algorithm( m_grid_config(grid_config), m_filter_config(filter_config), m_finding_config(finding_config), + m_resolution_config(resolution_config), m_fitting_config(fitting_config) { // Tell the user what device is being used. @@ -134,6 +139,10 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) m_finding(parent.m_finding_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, parent.logger().cloneWithSuffix("TrackFindingAlg")), + m_ambiguity_solving( + parent.m_resolution_config, + {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, + parent.logger().cloneWithSuffix("AmbiguityResolutionAlg")), m_fitting(parent.m_fitting_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, parent.logger().cloneWithSuffix("TrackFittingAlg")), @@ -142,6 +151,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) m_grid_config(parent.m_grid_config), m_filter_config(parent.m_filter_config), m_finding_config(parent.m_finding_config), + m_resolution_config(parent.m_resolution_config), m_fitting_config(parent.m_fitting_config) { // Copy the detector (description) to the device. @@ -180,9 +190,15 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()( const finding_algorithm::output_type track_candidates = m_finding(m_device_detector, m_field, measurements, track_params); + // Run the ambiguity solver (asynchronously). + const ambiguity_solving_algorithm::output_type + resolved_track_candidates = + m_ambiguity_solving({track_candidates, measurements}); + // Run the track fitting (asynchronously). - const fitting_algorithm::output_type track_states = m_fitting( - m_device_detector, m_field, {track_candidates, measurements}); + const fitting_algorithm::output_type track_states = + m_fitting(m_device_detector, m_field, + {resolved_track_candidates, measurements}); // Copy a limited amount of result data back to the host. const auto host_tracks = diff --git a/examples/run/cuda/full_chain_algorithm.hpp b/examples/run/cuda/full_chain_algorithm.hpp index 6462895c66..33eb1060e7 100644 --- a/examples/run/cuda/full_chain_algorithm.hpp +++ b/examples/run/cuda/full_chain_algorithm.hpp @@ -9,6 +9,7 @@ // Project include(s). #include "traccc/clusterization/clustering_config.hpp" +#include "traccc/cuda/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp" #include "traccc/cuda/clusterization/clusterization_algorithm.hpp" #include "traccc/cuda/clusterization/measurement_sorting_algorithm.hpp" #include "traccc/cuda/finding/combinatorial_kalman_filter_algorithm.hpp" @@ -61,6 +62,9 @@ class full_chain_algorithm /// Track finding algorithm type using finding_algorithm = traccc::cuda::combinatorial_kalman_filter_algorithm; + /// Ambiguity solving algorithm type + using ambiguity_solving_algorithm = + traccc::cuda::greedy_ambiguity_resolution_algorithm; /// Track fitting algorithm type using fitting_algorithm = traccc::cuda::kalman_fitting_algorithm; @@ -71,16 +75,19 @@ class full_chain_algorithm /// @param mr The memory resource to use for the intermediate and result /// objects /// - full_chain_algorithm(vecmem::memory_resource& host_mr, - const clustering_config& clustering_config, - const seedfinder_config& finder_config, - const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config, - const finding_algorithm::config_type& finding_config, - const fitting_algorithm::config_type& fitting_config, - const silicon_detector_description::host& det_descr, - const magnetic_field& field, host_detector* detector, - std::unique_ptr logger); + + full_chain_algorithm( + vecmem::memory_resource& host_mr, + const clustering_config& clustering_config, + const seedfinder_config& finder_config, + const spacepoint_grid_config& grid_config, + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config, + const ambiguity_solving_algorithm::config_type& resolution_config, + const fitting_algorithm::config_type& fitting_config, + const silicon_detector_description::host& det_descr, + const magnetic_field& field, host_detector* detector, + std::unique_ptr logger); /// Copy constructor /// @@ -157,6 +164,8 @@ class full_chain_algorithm /// Track finding algorithm finding_algorithm m_finding; + /// Ambiguity solving algorithm + ambiguity_solving_algorithm m_ambiguity_solving; /// Track fitting algorithm fitting_algorithm m_fitting; @@ -176,6 +185,8 @@ class full_chain_algorithm /// Configuration for the track finding finding_algorithm::config_type m_finding_config; + /// Configuration for the ambiguity solving + ambiguity_solving_algorithm::config_type m_resolution_config; /// Configuration for the track fitting fitting_algorithm::config_type m_fitting_config; diff --git a/examples/run/sycl/full_chain_algorithm.hpp b/examples/run/sycl/full_chain_algorithm.hpp index 9d5a9b2b03..7e54f85cb7 100644 --- a/examples/run/sycl/full_chain_algorithm.hpp +++ b/examples/run/sycl/full_chain_algorithm.hpp @@ -8,6 +8,7 @@ #pragma once // Project include(s). +#include "traccc/ambiguity_resolution/ambiguity_resolution_config.hpp" #include "traccc/edm/silicon_cell_collection.hpp" #include "traccc/edm/track_parameters.hpp" #include "traccc/geometry/detector.hpp" @@ -77,6 +78,7 @@ class full_chain_algorithm const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, const finding_algorithm::config_type& finding_config, + const ambiguity_resolution_config& resolution_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, const magnetic_field& field, host_detector* detector, diff --git a/examples/run/sycl/full_chain_algorithm.sycl b/examples/run/sycl/full_chain_algorithm.sycl index c8e1460ca3..4d8f1034fa 100644 --- a/examples/run/sycl/full_chain_algorithm.sycl +++ b/examples/run/sycl/full_chain_algorithm.sycl @@ -58,6 +58,7 @@ full_chain_algorithm::full_chain_algorithm( const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, const finding_algorithm::config_type& finding_config, + [[maybe_unused]] const ambiguity_resolution_config& resolution_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, const magnetic_field& field, host_detector* detector,