diff --git a/benchmarks/cpu/toy_detector_cpu.cpp b/benchmarks/cpu/toy_detector_cpu.cpp index b9203f4455..bf3d40dd31 100644 --- a/benchmarks/cpu/toy_detector_cpu.cpp +++ b/benchmarks/cpu/toy_detector_cpu.cpp @@ -51,7 +51,9 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { // Algorithms traccc::host::seeding_algorithm sa(seeding_cfg, grid_cfg, filter_cfg, host_mr); - traccc::host::track_params_estimation tp(host_mr); + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr); traccc::host::combinatorial_kalman_filter_algorithm host_finding( finding_cfg, host_mr); traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg, host_mr, diff --git a/benchmarks/cuda/toy_detector_cuda.cpp b/benchmarks/cuda/toy_detector_cuda.cpp index 3f8ad3b808..316cdf4ee3 100644 --- a/benchmarks/cuda/toy_detector_cuda.cpp +++ b/benchmarks/cuda/toy_detector_cuda.cpp @@ -61,7 +61,9 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) { // Algorithms traccc::cuda::seeding_algorithm sa_cuda(seeding_cfg, grid_cfg, filter_cfg, mr, async_copy, stream); - traccc::cuda::track_params_estimation tp_cuda(mr, async_copy, stream); + traccc::track_params_estimation_config track_params_estimation_config; + traccc::cuda::track_params_estimation tp_cuda( + track_params_estimation_config, mr, async_copy, stream); traccc::cuda::combinatorial_kalman_filter_algorithm device_finding( finding_cfg, mr, async_copy, stream); traccc::cuda::kalman_fitting_algorithm device_fitting(fitting_cfg, mr, diff --git a/core/include/traccc/seeding/detail/track_params_estimation_config.hpp b/core/include/traccc/seeding/detail/track_params_estimation_config.hpp new file mode 100644 index 0000000000..377050ca36 --- /dev/null +++ b/core/include/traccc/seeding/detail/track_params_estimation_config.hpp @@ -0,0 +1,35 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2021-2022 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +#include + +#include "traccc/definitions/common.hpp" +#include "traccc/definitions/primitives.hpp" +#include "traccc/definitions/track_parametrization.hpp" + +namespace traccc { + +struct track_params_estimation_config { + std::array initial_sigma = { + 1.f * unit::mm, + 1.f * unit::mm, + 1.f * unit::degree, + 1.f * unit::degree, + 0.f * unit::e / unit::GeV, + 1.f * unit::ns}; + + scalar initial_sigma_qopt = 0.1f * unit::e / unit::GeV; + + scalar initial_sigma_pt_rel = 0.1f; + + std::array initial_inflation = {1.f, 1.f, 1.f, + 1.f, 1.f, 100.f}; +}; + +} // namespace traccc diff --git a/core/include/traccc/seeding/track_params_estimation.hpp b/core/include/traccc/seeding/track_params_estimation.hpp index a40ef1fb6c..5c662b7b2a 100644 --- a/core/include/traccc/seeding/track_params_estimation.hpp +++ b/core/include/traccc/seeding/track_params_estimation.hpp @@ -12,6 +12,7 @@ #include "traccc/edm/seed_collection.hpp" #include "traccc/edm/spacepoint_collection.hpp" #include "traccc/edm/track_parameters.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" #include "traccc/utils/algorithm.hpp" #include "traccc/utils/messaging.hpp" @@ -31,8 +32,7 @@ class track_params_estimation : public algorithm&)>, + const edm::seed_collection::const_view&, const vector3&)>, public messaging { public: @@ -40,6 +40,7 @@ class track_params_estimation /// /// @param mr is the memory resource track_params_estimation( + const track_params_estimation_config& config, vecmem::memory_resource& mr, std::unique_ptr logger = getDummyLogger().clone()); @@ -56,16 +57,11 @@ class track_params_estimation output_type operator()( const measurement_collection_types::const_view& measurements, const edm::spacepoint_collection::const_view& spacepoints, - const edm::seed_collection::const_view& seeds, const vector3& bfield, - const std::array& stddev = { - 0.02f * traccc::unit::mm, - 0.03f * traccc::unit::mm, - 1.f * traccc::unit::degree, - 1.f * traccc::unit::degree, - 0.01f / traccc::unit::GeV, - 1.f * traccc::unit::ns}) const override; + const edm::seed_collection::const_view& seeds, + const vector3& bfield) const override; private: + const track_params_estimation_config m_config; /// The memory resource to use in the algorithm std::reference_wrapper m_mr; }; // class track_params_estimation diff --git a/core/src/seeding/track_params_estimation.cpp b/core/src/seeding/track_params_estimation.cpp index 017b8741a7..ddec3fb1a1 100644 --- a/core/src/seeding/track_params_estimation.cpp +++ b/core/src/seeding/track_params_estimation.cpp @@ -16,14 +16,15 @@ namespace traccc::host { track_params_estimation::track_params_estimation( - vecmem::memory_resource& mr, std::unique_ptr logger) - : messaging(std::move(logger)), m_mr(mr) {} + const track_params_estimation_config& config, vecmem::memory_resource& mr, + std::unique_ptr logger) + : messaging(std::move(logger)), m_config(config), m_mr(mr) {} track_params_estimation::output_type track_params_estimation::operator()( const measurement_collection_types::const_view& measurements_view, const edm::spacepoint_collection::const_view& spacepoints_view, - const edm::seed_collection::const_view& seeds_view, const vector3& bfield, - const std::array& stddev) const { + const edm::seed_collection::const_view& seeds_view, + const vector3& bfield) const { // Set up the input / output objects. const measurement_collection_types::const_device measurements( @@ -61,8 +62,27 @@ track_params_estimation::output_type track_params_estimation::operator()( // Set Covariance for (std::size_t j = 0; j < e_bound_size; ++j) { - getter::element(track_params.covariance(), j, j) = - stddev[j] * stddev[j]; + scalar var = m_config.initial_sigma[i] * m_config.initial_sigma[i]; + + if (i == e_bound_qoverp) { + scalar var_theta = getter::element( + track_params.covariance(), e_bound_theta, e_bound_theta); + + var += math::pow(m_config.initial_sigma_qopt * + math::sin(track_params[e_bound_theta]), + 2.f); + var += math::pow(m_config.initial_sigma_pt_rel * + track_params[e_bound_qoverp], + 2.f); + var += var_theta * + math::pow(track_params[e_bound_qoverp] / + math::tan(track_params[e_bound_theta]), + 2.f); + } + + var *= m_config.initial_inflation[i]; + + getter::element(track_params.covariance(), i, i) = var; } } diff --git a/device/common/include/traccc/seeding/device/estimate_track_params.hpp b/device/common/include/traccc/seeding/device/estimate_track_params.hpp index 18f009dd99..c4d1b63e91 100644 --- a/device/common/include/traccc/seeding/device/estimate_track_params.hpp +++ b/device/common/include/traccc/seeding/device/estimate_track_params.hpp @@ -15,6 +15,7 @@ #include "traccc/edm/seed_collection.hpp" #include "traccc/edm/spacepoint_collection.hpp" #include "traccc/edm/track_parameters.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" namespace traccc::device { @@ -30,11 +31,10 @@ namespace traccc::device { /// TRACCC_HOST_DEVICE inline void estimate_track_params( - global_index_t globalIndex, + global_index_t globalIndex, const track_params_estimation_config& config, const measurement_collection_types::const_view& measurements_view, const edm::spacepoint_collection::const_view& spacepoints_view, const edm::seed_collection::const_view& seeds_view, const vector3& bfield, - const std::array& stddev, bound_track_parameters_collection_types::view params_view); } // namespace traccc::device diff --git a/device/common/include/traccc/seeding/device/impl/estimate_track_params.ipp b/device/common/include/traccc/seeding/device/impl/estimate_track_params.ipp index 4b1392c92b..0beede4037 100644 --- a/device/common/include/traccc/seeding/device/impl/estimate_track_params.ipp +++ b/device/common/include/traccc/seeding/device/impl/estimate_track_params.ipp @@ -8,6 +8,7 @@ #pragma once // Project include(s). +#include "traccc/seeding/detail/track_params_estimation_config.hpp" #include "traccc/seeding/device/estimate_track_params.hpp" #include "traccc/seeding/track_params_estimation_helper.hpp" @@ -19,10 +20,10 @@ namespace traccc::device { TRACCC_HOST_DEVICE inline void estimate_track_params( const global_index_t globalIndex, + const track_params_estimation_config& config, const measurement_collection_types::const_view& measurements_view, const edm::spacepoint_collection::const_view& spacepoints_view, const edm::seed_collection::const_view& seeds_view, const vector3& bfield, - const std::array& stddev, bound_track_parameters_collection_types::view params_view) { // Check if anything needs to be done. @@ -46,10 +47,37 @@ inline void estimate_track_params( seed_to_bound_param_vector(track_params, measurements_device, spacepoints_device, this_seed, bfield); + // NOTE: The code below uses the covariance of theta in the calculation of + // the calculation of q/p. Thus, theta must be computed first. + static_assert(e_bound_qoverp > e_bound_theta); + // Set Covariance for (std::size_t i = 0; i < e_bound_size; i++) { - getter::element(track_params.covariance(), i, i) = - stddev[i] * stddev[i]; + scalar var = config.initial_sigma[i] * config.initial_sigma[i]; + + if (i == e_bound_qoverp) { + const scalar var_theta = getter::element( + track_params.covariance(), e_bound_theta, e_bound_theta); + + // Contribution from sigma(q/pt) + const scalar sigma_qopt = config.initial_sigma_qopt * + math::sin(track_params[e_bound_theta]); + var += sigma_qopt * sigma_qopt; + + // Contribution from sigma(pt)/pt + const scalar sigma_pt_rel = + config.initial_sigma_pt_rel * track_params[e_bound_qoverp]; + var += sigma_pt_rel * sigma_pt_rel; + + // Contribution from sigma(theta) + scalar sigma_theta = track_params[e_bound_qoverp] / + math::tan(track_params[e_bound_theta]); + var += var_theta * sigma_theta * sigma_theta; + } + + var *= config.initial_inflation[i]; + + getter::element(track_params.covariance(), i, i) = var; } } diff --git a/device/cuda/include/traccc/cuda/seeding/track_params_estimation.hpp b/device/cuda/include/traccc/cuda/seeding/track_params_estimation.hpp index 92d88f7e97..75e70fe7bf 100644 --- a/device/cuda/include/traccc/cuda/seeding/track_params_estimation.hpp +++ b/device/cuda/include/traccc/cuda/seeding/track_params_estimation.hpp @@ -13,6 +13,7 @@ #include "traccc/edm/seed_collection.hpp" #include "traccc/edm/spacepoint_collection.hpp" #include "traccc/edm/track_parameters.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" #include "traccc/utils/algorithm.hpp" #include "traccc/utils/memory_resource.hpp" #include "traccc/utils/messaging.hpp" @@ -31,8 +32,7 @@ struct track_params_estimation : public algorithm&)>, + const edm::seed_collection::const_view&, const vector3&)>, public messaging { public: @@ -43,6 +43,7 @@ struct track_params_estimation /// and host memory blocks /// @param str The CUDA stream to perform the operations in track_params_estimation( + const track_params_estimation_config& config, const traccc::memory_resource& mr, vecmem::copy& copy, stream& str, std::unique_ptr logger = getDummyLogger().clone()); @@ -59,16 +60,12 @@ struct track_params_estimation output_type operator()( const measurement_collection_types::const_view& measurements, const edm::spacepoint_collection::const_view& spacepoints, - const edm::seed_collection::const_view& seeds, const vector3& bfield, - const std::array& = { - 0.02f * traccc::unit::mm, - 0.03f * traccc::unit::mm, - 1.f * traccc::unit::degree, - 1.f * traccc::unit::degree, - 0.01f / traccc::unit::GeV, - 1.f * traccc::unit::ns}) const override; + const edm::seed_collection::const_view& seeds, + const vector3& bfield) const override; private: + const track_params_estimation_config m_config; + /// Memory resource used by the algorithm traccc::memory_resource m_mr; /// The copy object to use diff --git a/device/cuda/src/seeding/track_params_estimation.cu b/device/cuda/src/seeding/track_params_estimation.cu index c275ab3bc0..909ef9f8bb 100644 --- a/device/cuda/src/seeding/track_params_estimation.cu +++ b/device/cuda/src/seeding/track_params_estimation.cu @@ -23,22 +23,24 @@ namespace traccc::cuda { namespace kernels { /// CUDA kernel for running @c traccc::device::estimate_track_params __global__ void estimate_track_params( + const track_params_estimation_config config, measurement_collection_types::const_view measurements_view, edm::spacepoint_collection::const_view spacepoints_view, edm::seed_collection::const_view seed_view, const vector3 bfield, - const std::array stddev, bound_track_parameters_collection_types::view params_view) { - device::estimate_track_params(details::global_index1(), measurements_view, - spacepoints_view, seed_view, bfield, stddev, - params_view); + device::estimate_track_params(details::global_index1(), config, + measurements_view, spacepoints_view, + seed_view, bfield, params_view); } } // namespace kernels track_params_estimation::track_params_estimation( + const track_params_estimation_config& config, const traccc::memory_resource& mr, vecmem::copy& copy, stream& str, std::unique_ptr logger) : messaging(std::move(logger)), + m_config(config), m_mr(mr), m_copy(copy), m_stream(str), @@ -47,8 +49,8 @@ track_params_estimation::track_params_estimation( track_params_estimation::output_type track_params_estimation::operator()( const measurement_collection_types::const_view& measurements_view, const edm::spacepoint_collection::const_view& spacepoints_view, - const edm::seed_collection::const_view& seeds_view, const vector3& bfield, - const std::array& stddev) const { + const edm::seed_collection::const_view& seeds_view, + const vector3& bfield) const { // Get a convenience variable for the stream that we'll be using. cudaStream_t stream = details::get_stream(m_stream); @@ -82,7 +84,7 @@ track_params_estimation::output_type track_params_estimation::operator()( // run the kernel kernels::estimate_track_params<<>>( - measurements_view, spacepoints_view, seeds_view, bfield, stddev, + m_config, measurements_view, spacepoints_view, seeds_view, bfield, params_buffer); TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); diff --git a/examples/run/common/throughput_mt.ipp b/examples/run/common/throughput_mt.ipp index 952a5912a3..c72fa10255 100644 --- a/examples/run/common/throughput_mt.ipp +++ b/examples/run/common/throughput_mt.ipp @@ -13,6 +13,7 @@ // Project include(s) #include "traccc/geometry/detector.hpp" #include "traccc/geometry/host_detector.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" // Command line option include(s). #include "traccc/options/clusterization.hpp" @@ -38,6 +39,7 @@ #include "traccc/performance/throughput.hpp" #include "traccc/performance/timer.hpp" #include "traccc/performance/timing_info.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" // VecMem include(s). #include @@ -146,6 +148,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) { const traccc::seedfinder_config seedfinder_config(seeding_opts); const traccc::seedfilter_config seedfilter_config(seeding_opts); const traccc::spacepoint_grid_config spacepoint_grid_config(seeding_opts); + const traccc::track_params_estimation_config track_params_estimation_config; detray::propagation::config propagation_config(propagation_opts); typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg( finding_opts); @@ -159,10 +162,10 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) { std::vector algs; algs.reserve(threading_opts.threads + 1); for (std::size_t i = 0; i < threading_opts.threads + 1; ++i) { - algs.push_back({host_mr, clustering_cfg, seedfinder_config, - spacepoint_grid_config, seedfilter_config, finding_cfg, - fitting_cfg, det_descr, field, &detector, - logger().clone()}); + algs.push_back( + {host_mr, clustering_cfg, seedfinder_config, spacepoint_grid_config, + seedfilter_config, track_params_estimation_config, finding_cfg, + fitting_cfg, det_descr, field, &detector, logger().clone()}); } // Set up a lambda that calls the correct function on the algorithms. diff --git a/examples/run/common/throughput_st.ipp b/examples/run/common/throughput_st.ipp index 3eeed4b880..7fec0685e9 100644 --- a/examples/run/common/throughput_st.ipp +++ b/examples/run/common/throughput_st.ipp @@ -13,6 +13,7 @@ // Project include(s) #include "traccc/geometry/detector.hpp" #include "traccc/geometry/host_detector.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" // Command line option include(s). #include "traccc/options/clusterization.hpp" @@ -129,6 +130,8 @@ int throughput_st(std::string_view description, int argc, char* argv[]) { const traccc::seedfilter_config seedfilter_config(seeding_opts); const traccc::spacepoint_grid_config spacepoint_grid_config(seeding_opts); + const traccc::track_params_estimation_config track_params_estimation_config; + typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg( finding_opts); finding_cfg.propagation = propagation_config; @@ -140,8 +143,9 @@ int throughput_st(std::string_view description, int argc, char* argv[]) { // Set up the full-chain algorithm. std::unique_ptr alg = std::make_unique( host_mr, clustering_cfg, seedfinder_config, spacepoint_grid_config, - seedfilter_config, finding_cfg, fitting_cfg, det_descr, field, - &detector, logger().clone("FullChainAlg")); + seedfilter_config, track_params_estimation_config, finding_cfg, + fitting_cfg, det_descr, field, &detector, + logger().clone("FullChainAlg")); // Seed the random number generator. if (throughput_opts.random_seed == 0) { diff --git a/examples/run/cpu/full_chain_algorithm.cpp b/examples/run/cpu/full_chain_algorithm.cpp index dcd99c05d5..7a0b366a6c 100644 --- a/examples/run/cpu/full_chain_algorithm.cpp +++ b/examples/run/cpu/full_chain_algorithm.cpp @@ -15,6 +15,7 @@ full_chain_algorithm::full_chain_algorithm( const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, + const track_params_estimation_config& track_params_estimation_config, const finding_algorithm::config_type& finding_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, @@ -31,7 +32,7 @@ full_chain_algorithm::full_chain_algorithm( m_spacepoint_formation(mr, logger->cloneWithSuffix("SpFormationAlg")), m_seeding(finder_config, grid_config, filter_config, mr, logger->cloneWithSuffix("SeedingAlg")), - m_track_parameter_estimation(mr, + m_track_parameter_estimation(track_params_estimation_config, mr, logger->cloneWithSuffix("TrackParamEstAlg")), m_finding(finding_config, mr, logger->cloneWithSuffix("TrackFindingAlg")), m_fitting(fitting_config, mr, *m_copy, @@ -39,6 +40,7 @@ full_chain_algorithm::full_chain_algorithm( m_finder_config(finder_config), m_grid_config(grid_config), m_filter_config(filter_config), + m_track_params_estimation_config(track_params_estimation_config), m_finding_config(finding_config), m_fitting_config(fitting_config) {} diff --git a/examples/run/cpu/full_chain_algorithm.hpp b/examples/run/cpu/full_chain_algorithm.hpp index a34bac892b..ef8881268f 100644 --- a/examples/run/cpu/full_chain_algorithm.hpp +++ b/examples/run/cpu/full_chain_algorithm.hpp @@ -68,17 +68,18 @@ class full_chain_algorithm /// @param dummy This is not used anywhere. Allows templating CPU/Device /// algorithm. /// - full_chain_algorithm(vecmem::memory_resource& mr, - const clustering_algorithm::config_type& dummy, - const seedfinder_config& finder_config, - const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config, - const finding_algorithm::config_type& finding_config, - const fitting_algorithm::config_type& fitting_config, - const silicon_detector_description::host& det_descr, - const magnetic_field& field, - const host_detector* detector, - std::unique_ptr logger); + full_chain_algorithm( + vecmem::memory_resource& mr, + const clustering_algorithm::config_type& dummy, + const seedfinder_config& finder_config, + const spacepoint_grid_config& grid_config, + const seedfilter_config& filter_config, + const track_params_estimation_config& track_params_estimation_config, + const finding_algorithm::config_type& finding_config, + const fitting_algorithm::config_type& fitting_config, + const silicon_detector_description::host& det_descr, + const magnetic_field& field, const host_detector* detector, + std::unique_ptr logger); /// Reconstruct track parameters in the entire detector /// @@ -141,6 +142,9 @@ class full_chain_algorithm /// Configuration for the seed filtering seedfilter_config m_filter_config; + /// Configuration for track parameter estimation + track_params_estimation_config m_track_params_estimation_config; + /// Configuration for the track finding finding_algorithm::config_type m_finding_config; /// Configuration for the track fitting diff --git a/examples/run/cpu/seeding_example.cpp b/examples/run/cpu/seeding_example.cpp index 811af8f85e..74d2444e27 100644 --- a/examples/run/cpu/seeding_example.cpp +++ b/examples/run/cpu/seeding_example.cpp @@ -129,7 +129,9 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::host::seeding_algorithm sa( seedfinder_config, spacepoint_grid_config, seedfilter_config, host_mr, logger().clone("SeedingAlg")); - traccc::host::track_params_estimation tp(host_mr, + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr, logger().clone("TrackParEstAlg")); // Propagation configuration diff --git a/examples/run/cpu/seq_example.cpp b/examples/run/cpu/seq_example.cpp index a5d38a415b..26c2982d27 100644 --- a/examples/run/cpu/seq_example.cpp +++ b/examples/run/cpu/seq_example.cpp @@ -8,6 +8,7 @@ // core #include "traccc/geometry/detector.hpp" #include "traccc/geometry/host_detector.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" #include "traccc/utils/memory_resource.hpp" #include "traccc/utils/propagation.hpp" @@ -148,7 +149,9 @@ int seq_run(const traccc::opts::input_data& input_opts, traccc::host::seeding_algorithm sa( seedfinder_config, spacepoint_grid_config, seedfilter_config, host_mr, logger().clone("SeedingAlg")); - traccc::host::track_params_estimation tp(host_mr, + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr, logger().clone("TrackParEstAlg")); finding_algorithm finding_alg(finding_cfg, host_mr, diff --git a/examples/run/cuda/full_chain_algorithm.cpp b/examples/run/cuda/full_chain_algorithm.cpp index 76b81bffa3..55a807515c 100644 --- a/examples/run/cuda/full_chain_algorithm.cpp +++ b/examples/run/cuda/full_chain_algorithm.cpp @@ -10,6 +10,7 @@ // Project include(s). #include "traccc/cuda/utils/make_magnetic_field.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" // CUDA include(s). #include @@ -36,6 +37,7 @@ full_chain_algorithm::full_chain_algorithm( const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, const seedfilter_config& filter_config, + const track_params_estimation_config& track_params_estimation_config, const finding_algorithm::config_type& finding_config, const fitting_algorithm::config_type& fitting_config, const silicon_detector_description::host& det_descr, @@ -69,6 +71,7 @@ full_chain_algorithm::full_chain_algorithm( {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, logger->cloneWithSuffix("SeedingAlg")), m_track_parameter_estimation( + track_params_estimation_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, logger->cloneWithSuffix("TrackParEstAlg")), m_finding(finding_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, @@ -79,6 +82,7 @@ full_chain_algorithm::full_chain_algorithm( m_finder_config(finder_config), m_grid_config(grid_config), m_filter_config(filter_config), + m_track_params_estimation_config(track_params_estimation_config), m_finding_config(finding_config), m_fitting_config(fitting_config) { @@ -129,6 +133,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, parent.logger().cloneWithSuffix("SeedingAlg")), m_track_parameter_estimation( + parent.m_track_params_estimation_config, {m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream, parent.logger().cloneWithSuffix("TrackParamEstAlg")), m_finding(parent.m_finding_config, @@ -141,6 +146,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) m_finder_config(parent.m_finder_config), m_grid_config(parent.m_grid_config), m_filter_config(parent.m_filter_config), + m_track_params_estimation_config(parent.m_track_params_estimation_config), m_finding_config(parent.m_finding_config), m_fitting_config(parent.m_fitting_config) { diff --git a/examples/run/cuda/full_chain_algorithm.hpp b/examples/run/cuda/full_chain_algorithm.hpp index 6462895c66..6b1ca0dbbe 100644 --- a/examples/run/cuda/full_chain_algorithm.hpp +++ b/examples/run/cuda/full_chain_algorithm.hpp @@ -71,16 +71,18 @@ class full_chain_algorithm /// @param mr The memory resource to use for the intermediate and result /// objects /// - full_chain_algorithm(vecmem::memory_resource& host_mr, - const clustering_config& clustering_config, - const seedfinder_config& finder_config, - const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config, - const finding_algorithm::config_type& finding_config, - const fitting_algorithm::config_type& fitting_config, - const silicon_detector_description::host& det_descr, - const magnetic_field& field, host_detector* detector, - std::unique_ptr logger); + full_chain_algorithm( + vecmem::memory_resource& host_mr, + const clustering_config& clustering_config, + const seedfinder_config& finder_config, + const spacepoint_grid_config& grid_config, + const seedfilter_config& filter_config, + const track_params_estimation_config& track_params_estimation_config, + const finding_algorithm::config_type& finding_config, + const fitting_algorithm::config_type& fitting_config, + const silicon_detector_description::host& det_descr, + const magnetic_field& field, host_detector* detector, + std::unique_ptr logger); /// Copy constructor /// @@ -174,6 +176,9 @@ class full_chain_algorithm /// Configuration for the seed filtering seedfilter_config m_filter_config; + /// Configuration for track parameter estimation + track_params_estimation_config m_track_params_estimation_config; + /// Configuration for the track finding finding_algorithm::config_type m_finding_config; /// Configuration for the track fitting diff --git a/examples/run/cuda/seeding_example_cuda.cpp b/examples/run/cuda/seeding_example_cuda.cpp index 11d77240fb..9f76e21718 100644 --- a/examples/run/cuda/seeding_example_cuda.cpp +++ b/examples/run/cuda/seeding_example_cuda.cpp @@ -45,6 +45,7 @@ #include "traccc/performance/soa_comparator.hpp" #include "traccc/performance/timer.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" +#include "traccc/seeding/detail/track_params_estimation_config.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" #include "traccc/utils/propagation.hpp" @@ -153,8 +154,10 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::host::seeding_algorithm sa( seedfinder_config, spacepoint_grid_config, seedfilter_config, host_mr, logger().clone("HostSeedingAlg")); + const traccc::track_params_estimation_config track_params_estimation_config; traccc::host::track_params_estimation tp( - host_mr, logger().clone("HostTrackParEstAlg")); + track_params_estimation_config, host_mr, + logger().clone("HostTrackParEstAlg")); traccc::cuda::stream stream; @@ -168,7 +171,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, stream, logger().clone("CudaSeedingAlg")}; traccc::cuda::track_params_estimation tp_cuda{ - mr, async_copy, stream, logger().clone("CudaTrackParEstAlg")}; + track_params_estimation_config, mr, async_copy, stream, + logger().clone("CudaTrackParEstAlg")}; // Propagation configuration detray::propagation::config propagation_config(propagation_opts); diff --git a/examples/run/cuda/seq_example_cuda.cpp b/examples/run/cuda/seq_example_cuda.cpp index 364923b1f5..e6789329da 100644 --- a/examples/run/cuda/seq_example_cuda.cpp +++ b/examples/run/cuda/seq_example_cuda.cpp @@ -174,8 +174,10 @@ int seq_run(const traccc::opts::detector& detector_opts, traccc::host::seeding_algorithm sa( seedfinder_config, spacepoint_grid_config, seedfilter_config, host_mr, logger().clone("HostSeedingAlg")); + traccc::track_params_estimation_config track_params_estimation_config; traccc::host::track_params_estimation tp( - host_mr, logger().clone("HostTrackParEstAlg")); + track_params_estimation_config, host_mr, + logger().clone("HostTrackParEstAlg")); host_finding_algorithm finding_alg(finding_cfg, host_mr, logger().clone("HostFindingAlg")); traccc::host::greedy_ambiguity_resolution_algorithm resolution_alg_cpu( @@ -195,7 +197,8 @@ int seq_run(const traccc::opts::detector& detector_opts, seedfinder_config, spacepoint_grid_config, seedfilter_config, mr, copy, stream, logger().clone("CudaSeedingAlg")); traccc::cuda::track_params_estimation tp_cuda( - mr, copy, stream, logger().clone("CudaTrackParEstAlg")); + track_params_estimation_config, mr, copy, stream, + logger().clone("CudaTrackParEstAlg")); device_finding_algorithm finding_alg_cuda(finding_cfg, mr, copy, stream, logger().clone("CudaFindingAlg")); traccc::cuda::greedy_ambiguity_resolution_algorithm resolution_alg_cuda( diff --git a/tests/cpu/test_seeding.cpp b/tests/cpu/test_seeding.cpp index cf4b099aa8..ca8d825b71 100644 --- a/tests/cpu/test_seeding.cpp +++ b/tests/cpu/test_seeding.cpp @@ -88,7 +88,9 @@ TEST(seeding, case1) { // The number of seeds should be eqaul to one ASSERT_EQ(seeds.size(), 1u); - traccc::host::track_params_estimation tp(host_mr); + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr); auto bound_params = tp(vecmem::get_data(measurements), vecmem::get_data(spacepoints), @@ -160,7 +162,9 @@ TEST(seeding, case2) { // The number of seeds should be eqaul to one ASSERT_EQ(seeds.size(), 1u); - traccc::host::track_params_estimation tp(host_mr); + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr); auto bound_params = tp(vecmem::get_data(measurements), vecmem::get_data(spacepoints), diff --git a/tests/cpu/test_track_params_estimation.cpp b/tests/cpu/test_track_params_estimation.cpp index b7e18c05ef..4b7d35eb93 100644 --- a/tests/cpu/test_track_params_estimation.cpp +++ b/tests/cpu/test_track_params_estimation.cpp @@ -66,7 +66,9 @@ TEST(track_params_estimation, helix_negative_charge) { seeds.push_back({0, 1, 2}); // Run track parameter estimation - traccc::host::track_params_estimation tp(host_mr); + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr); auto bound_params = tp(vecmem::get_data(measurements), vecmem::get_data(spacepoints), vecmem::get_data(seeds), B); @@ -114,7 +116,9 @@ TEST(track_params_estimation, helix_positive_charge) { seeds.push_back({0, 1, 2}); // Run track parameter estimation - traccc::host::track_params_estimation tp(host_mr); + traccc::track_params_estimation_config track_params_estimation_config; + traccc::host::track_params_estimation tp(track_params_estimation_config, + host_mr); auto bound_params = tp(vecmem::get_data(measurements), vecmem::get_data(spacepoints), vecmem::get_data(seeds), B);