Skip to content

Commit 0817b70

Browse files
committed
Add ambiguity resolution to the chain
1 parent 1241c09 commit 0817b70

10 files changed

+90
-32
lines changed

examples/run/alpaka/full_chain_algorithm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ full_chain_algorithm::full_chain_algorithm(
2121
const spacepoint_grid_config& grid_config,
2222
const seedfilter_config& filter_config,
2323
const finding_algorithm::config_type& finding_config,
24+
[[maybe_unused]] const ambiguity_resolution_config& resolution_config,
2425
const fitting_algorithm::config_type& fitting_config,
2526
const silicon_detector_description::host& det_descr,
2627
const magnetic_field& field, host_detector* detector,

examples/run/alpaka/full_chain_algorithm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "traccc/alpaka/seeding/track_params_estimation.hpp"
1818
#include "traccc/alpaka/utils/get_device_info.hpp"
1919
#include "traccc/alpaka/utils/vecmem_objects.hpp"
20+
#include "traccc/ambiguity_resolution/ambiguity_resolution_config.hpp"
2021
#include "traccc/bfield/magnetic_field.hpp"
2122
#include "traccc/clusterization/clustering_config.hpp"
2223
#include "traccc/edm/silicon_cell_collection.hpp"
@@ -78,6 +79,7 @@ class full_chain_algorithm
7879
const spacepoint_grid_config& grid_config,
7980
const seedfilter_config& filter_config,
8081
const finding_algorithm::config_type& finding_config,
82+
const ambiguity_resolution_config& resolution_config,
8183
const fitting_algorithm::config_type& fitting_config,
8284
const silicon_detector_description::host& det_descr,
8385
const magnetic_field& field, host_detector* detector,

examples/run/common/throughput_mt.ipp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "traccc/options/track_finding.hpp"
2626
#include "traccc/options/track_fitting.hpp"
2727
#include "traccc/options/track_propagation.hpp"
28+
#include "traccc/options/track_resolution.hpp"
2829
#include "traccc/options/track_seeding.hpp"
2930

3031
// I/O include(s).
@@ -77,6 +78,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) {
7778
opts::track_seeding seeding_opts;
7879
opts::track_finding finding_opts;
7980
opts::track_propagation propagation_opts;
81+
opts::track_resolution resolution_opts;
8082
opts::track_fitting fitting_opts;
8183
opts::throughput throughput_opts;
8284
opts::threading threading_opts;
@@ -147,6 +149,8 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) {
147149
finding_opts);
148150
finding_cfg.propagation = propagation_config;
149151

152+
ambiguity_resolution_config resolution_cfg(resolution_opts);
153+
150154
typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg(
151155
fitting_opts);
152156
fitting_cfg.propagation = propagation_config;
@@ -157,8 +161,8 @@ int throughput_mt(std::string_view description, int argc, char* argv[]) {
157161
for (std::size_t i = 0; i < threading_opts.threads + 1; ++i) {
158162
algs.push_back({host_mr, clustering_cfg, seedfinder_config,
159163
spacepoint_grid_config, seedfilter_config, finding_cfg,
160-
fitting_cfg, det_descr, field, &detector,
161-
logger().clone()});
164+
resolution_cfg, fitting_cfg, det_descr, field,
165+
&detector, logger().clone()});
162166
}
163167

164168
// Set up a lambda that calls the correct function on the algorithms.

examples/run/common/throughput_st.ipp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "traccc/options/track_finding.hpp"
2525
#include "traccc/options/track_fitting.hpp"
2626
#include "traccc/options/track_propagation.hpp"
27+
#include "traccc/options/track_resolution.hpp"
2728
#include "traccc/options/track_seeding.hpp"
2829

2930
// I/O include(s).
@@ -66,13 +67,14 @@ int throughput_st(std::string_view description, int argc, char* argv[]) {
6667
opts::track_seeding seeding_opts;
6768
opts::track_finding finding_opts;
6869
opts::track_propagation propagation_opts;
70+
opts::track_resolution resolution_opts;
6971
opts::track_fitting fitting_opts;
7072
opts::throughput throughput_opts;
7173
opts::program_options program_opts{
7274
description,
7375
{detector_opts, bfield_opts, input_opts, clusterization_opts,
74-
seeding_opts, finding_opts, propagation_opts, fitting_opts,
75-
throughput_opts},
76+
seeding_opts, finding_opts, propagation_opts, resolution_opts,
77+
fitting_opts, throughput_opts},
7678
argc,
7779
argv,
7880
logger->cloneWithSuffix("Options")};
@@ -128,15 +130,17 @@ int throughput_st(std::string_view description, int argc, char* argv[]) {
128130
finding_opts);
129131
finding_cfg.propagation = propagation_config;
130132

133+
ambiguity_resolution_config resolution_cfg(resolution_opts);
134+
131135
typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg(
132136
fitting_opts);
133137
fitting_cfg.propagation = propagation_config;
134138

135139
// Set up the full-chain algorithm.
136140
std::unique_ptr<FULL_CHAIN_ALG> alg = std::make_unique<FULL_CHAIN_ALG>(
137141
host_mr, clustering_cfg, seedfinder_config, spacepoint_grid_config,
138-
seedfilter_config, finding_cfg, fitting_cfg, det_descr, field,
139-
&detector, logger->clone("FullChainAlg"));
142+
seedfilter_config, finding_cfg, resolution_cfg, fitting_cfg, det_descr,
143+
field, &detector, logger->clone("FullChainAlg"));
140144

141145
// Seed the random number generator.
142146
if (throughput_opts.random_seed == 0) {

examples/run/cpu/full_chain_algorithm.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ full_chain_algorithm::full_chain_algorithm(
1616
const spacepoint_grid_config& grid_config,
1717
const seedfilter_config& filter_config,
1818
const finding_algorithm::config_type& finding_config,
19+
const ambiguity_solving_algorithm::config_type& resolution_config,
1920
const fitting_algorithm::config_type& fitting_config,
2021
const silicon_detector_description::host& det_descr,
2122
const magnetic_field& field, const host_detector* detector,
@@ -34,6 +35,8 @@ full_chain_algorithm::full_chain_algorithm(
3435
m_track_parameter_estimation(mr,
3536
logger->cloneWithSuffix("TrackParamEstAlg")),
3637
m_finding(finding_config, mr, logger->cloneWithSuffix("TrackFindingAlg")),
38+
m_ambiguity_solving(resolution_config, mr,
39+
logger->cloneWithSuffix("AmbiguityResolutionAlg")),
3740
m_fitting(fitting_config, mr, *m_copy,
3841
logger->cloneWithSuffix("TrackFittingAlg")),
3942
m_finder_config(finder_config),
@@ -78,10 +81,14 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
7881
const finding_algorithm::output_type track_candidates = m_finding(
7982
*m_detector, m_field, measurements_view, track_params_view);
8083

84+
const ambiguity_solving_algorithm::output_type
85+
resolved_track_candidates = m_ambiguity_solving(
86+
{vecmem::get_data(track_candidates), measurements_view});
87+
8188
// Run the track fitting, and return its results.
82-
return m_fitting(
83-
*m_detector, m_field,
84-
{vecmem::get_data(track_candidates), measurements_view})
89+
return m_fitting(*m_detector, m_field,
90+
{vecmem::get_data(resolved_track_candidates),
91+
measurements_view})
8592
.tracks;
8693
}
8794
// If not, just return an empty object.

examples/run/cpu/full_chain_algorithm.hpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Project include(s).
11+
#include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
1112
#include "traccc/bfield/magnetic_field.hpp"
1213
#include "traccc/clusterization/clusterization_algorithm.hpp"
1314
#include "traccc/edm/silicon_cell_collection.hpp"
@@ -56,6 +57,9 @@ class full_chain_algorithm
5657
/// Track finding algorithm type
5758
using finding_algorithm =
5859
traccc::host::combinatorial_kalman_filter_algorithm;
60+
/// Ambiguity solving algorithm type
61+
using ambiguity_solving_algorithm =
62+
traccc::host::greedy_ambiguity_resolution_algorithm;
5963
/// Track fitting algorithm type
6064
using fitting_algorithm = traccc::host::kalman_fitting_algorithm;
6165

@@ -68,17 +72,19 @@ class full_chain_algorithm
6872
/// @param dummy This is not used anywhere. Allows templating CPU/Device
6973
/// algorithm.
7074
///
71-
full_chain_algorithm(vecmem::memory_resource& mr,
72-
const clustering_algorithm::config_type& dummy,
73-
const seedfinder_config& finder_config,
74-
const spacepoint_grid_config& grid_config,
75-
const seedfilter_config& filter_config,
76-
const finding_algorithm::config_type& finding_config,
77-
const fitting_algorithm::config_type& fitting_config,
78-
const silicon_detector_description::host& det_descr,
79-
const magnetic_field& field,
80-
const host_detector* detector,
81-
std::unique_ptr<const traccc::Logger> logger);
75+
76+
full_chain_algorithm(
77+
vecmem::memory_resource& mr,
78+
const clustering_algorithm::config_type& dummy,
79+
const seedfinder_config& finder_config,
80+
const spacepoint_grid_config& grid_config,
81+
const seedfilter_config& filter_config,
82+
const finding_algorithm::config_type& finding_config,
83+
const ambiguity_solving_algorithm::config_type& resolution_config,
84+
const fitting_algorithm::config_type& fitting_config,
85+
const silicon_detector_description::host& det_descr,
86+
const magnetic_field& field, const host_detector* detector,
87+
std::unique_ptr<const traccc::Logger> logger);
8288

8389
/// Reconstruct track parameters in the entire detector
8490
///
@@ -126,6 +132,8 @@ class full_chain_algorithm
126132

127133
/// Track finding algorithm
128134
finding_algorithm m_finding;
135+
/// Ambiguity solving algorithm
136+
ambiguity_solving_algorithm m_ambiguity_solving;
129137
/// Track fitting algorithm
130138
fitting_algorithm m_fitting;
131139

@@ -143,6 +151,8 @@ class full_chain_algorithm
143151

144152
/// Configuration for the track finding
145153
finding_algorithm::config_type m_finding_config;
154+
/// Configuration for the ambiguity solving
155+
ambiguity_solving_algorithm::config_type m_resolution_config;
146156
/// Configuration for the track fitting
147157
fitting_algorithm::config_type m_fitting_config;
148158

examples/run/cuda/full_chain_algorithm.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ full_chain_algorithm::full_chain_algorithm(
3737
const spacepoint_grid_config& grid_config,
3838
const seedfilter_config& filter_config,
3939
const finding_algorithm::config_type& finding_config,
40+
const ambiguity_solving_algorithm::config_type& resolution_config,
4041
const fitting_algorithm::config_type& fitting_config,
4142
const silicon_detector_description::host& det_descr,
4243
const magnetic_field& field, host_detector* detector,
@@ -73,13 +74,17 @@ full_chain_algorithm::full_chain_algorithm(
7374
logger->cloneWithSuffix("TrackParEstAlg")),
7475
m_finding(finding_config, {m_cached_device_mr, &m_cached_pinned_host_mr},
7576
m_copy, m_stream, logger->cloneWithSuffix("TrackFindingAlg")),
77+
m_ambiguity_solving(
78+
resolution_config, {m_cached_device_mr, &m_cached_pinned_host_mr},
79+
m_copy, m_stream, logger->cloneWithSuffix("AmbiguityResolutionAlg")),
7680
m_fitting(fitting_config, {m_cached_device_mr, &m_cached_pinned_host_mr},
7781
m_copy, m_stream, logger->cloneWithSuffix("TrackFittingAlg")),
7882
m_clustering_config(clustering_config),
7983
m_finder_config(finder_config),
8084
m_grid_config(grid_config),
8185
m_filter_config(filter_config),
8286
m_finding_config(finding_config),
87+
m_resolution_config(resolution_config),
8388
m_fitting_config(fitting_config) {
8489

8590
// Tell the user what device is being used.
@@ -134,6 +139,10 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
134139
m_finding(parent.m_finding_config,
135140
{m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy,
136141
m_stream, parent.logger().cloneWithSuffix("TrackFindingAlg")),
142+
m_ambiguity_solving(
143+
parent.m_resolution_config,
144+
{m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy, m_stream,
145+
parent.logger().cloneWithSuffix("AmbiguityResolutionAlg")),
137146
m_fitting(parent.m_fitting_config,
138147
{m_cached_device_mr, &m_cached_pinned_host_mr}, m_copy,
139148
m_stream, parent.logger().cloneWithSuffix("TrackFittingAlg")),
@@ -142,6 +151,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
142151
m_grid_config(parent.m_grid_config),
143152
m_filter_config(parent.m_filter_config),
144153
m_finding_config(parent.m_finding_config),
154+
m_resolution_config(parent.m_resolution_config),
145155
m_fitting_config(parent.m_fitting_config) {
146156

147157
// Copy the detector (description) to the device.
@@ -180,9 +190,15 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
180190
const finding_algorithm::output_type track_candidates =
181191
m_finding(m_device_detector, m_field, measurements, track_params);
182192

193+
// Run the ambiguity solver (asynchronously).
194+
const ambiguity_solving_algorithm::output_type
195+
resolved_track_candidates =
196+
m_ambiguity_solving({track_candidates, measurements});
197+
183198
// Run the track fitting (asynchronously).
184-
const fitting_algorithm::output_type track_states = m_fitting(
185-
m_device_detector, m_field, {track_candidates, measurements});
199+
const fitting_algorithm::output_type track_states =
200+
m_fitting(m_device_detector, m_field,
201+
{resolved_track_candidates, measurements});
186202

187203
// Copy a limited amount of result data back to the host.
188204
const auto host_tracks =

examples/run/cuda/full_chain_algorithm.hpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
// Project include(s).
1111
#include "traccc/clusterization/clustering_config.hpp"
12+
#include "traccc/cuda/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
1213
#include "traccc/cuda/clusterization/clusterization_algorithm.hpp"
1314
#include "traccc/cuda/clusterization/measurement_sorting_algorithm.hpp"
1415
#include "traccc/cuda/finding/combinatorial_kalman_filter_algorithm.hpp"
@@ -61,6 +62,9 @@ class full_chain_algorithm
6162
/// Track finding algorithm type
6263
using finding_algorithm =
6364
traccc::cuda::combinatorial_kalman_filter_algorithm;
65+
/// Ambiguity solving algorithm type
66+
using ambiguity_solving_algorithm =
67+
traccc::cuda::greedy_ambiguity_resolution_algorithm;
6468
/// Track fitting algorithm type
6569
using fitting_algorithm = traccc::cuda::kalman_fitting_algorithm;
6670

@@ -71,16 +75,19 @@ class full_chain_algorithm
7175
/// @param mr The memory resource to use for the intermediate and result
7276
/// objects
7377
///
74-
full_chain_algorithm(vecmem::memory_resource& host_mr,
75-
const clustering_config& clustering_config,
76-
const seedfinder_config& finder_config,
77-
const spacepoint_grid_config& grid_config,
78-
const seedfilter_config& filter_config,
79-
const finding_algorithm::config_type& finding_config,
80-
const fitting_algorithm::config_type& fitting_config,
81-
const silicon_detector_description::host& det_descr,
82-
const magnetic_field& field, host_detector* detector,
83-
std::unique_ptr<const traccc::Logger> logger);
78+
79+
full_chain_algorithm(
80+
vecmem::memory_resource& host_mr,
81+
const clustering_config& clustering_config,
82+
const seedfinder_config& finder_config,
83+
const spacepoint_grid_config& grid_config,
84+
const seedfilter_config& filter_config,
85+
const finding_algorithm::config_type& finding_config,
86+
const ambiguity_solving_algorithm::config_type& resolution_config,
87+
const fitting_algorithm::config_type& fitting_config,
88+
const silicon_detector_description::host& det_descr,
89+
const magnetic_field& field, host_detector* detector,
90+
std::unique_ptr<const traccc::Logger> logger);
8491

8592
/// Copy constructor
8693
///
@@ -157,6 +164,8 @@ class full_chain_algorithm
157164

158165
/// Track finding algorithm
159166
finding_algorithm m_finding;
167+
/// Ambiguity solving algorithm
168+
ambiguity_solving_algorithm m_ambiguity_solving;
160169
/// Track fitting algorithm
161170
fitting_algorithm m_fitting;
162171

@@ -176,6 +185,8 @@ class full_chain_algorithm
176185

177186
/// Configuration for the track finding
178187
finding_algorithm::config_type m_finding_config;
188+
/// Configuration for the ambiguity solving
189+
ambiguity_solving_algorithm::config_type m_resolution_config;
179190
/// Configuration for the track fitting
180191
fitting_algorithm::config_type m_fitting_config;
181192

examples/run/sycl/full_chain_algorithm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Project include(s).
11+
#include "traccc/ambiguity_resolution/ambiguity_resolution_config.hpp"
1112
#include "traccc/edm/silicon_cell_collection.hpp"
1213
#include "traccc/edm/track_parameters.hpp"
1314
#include "traccc/geometry/detector.hpp"
@@ -77,6 +78,7 @@ class full_chain_algorithm
7778
const spacepoint_grid_config& grid_config,
7879
const seedfilter_config& filter_config,
7980
const finding_algorithm::config_type& finding_config,
81+
const ambiguity_resolution_config& resolution_config,
8082
const fitting_algorithm::config_type& fitting_config,
8183
const silicon_detector_description::host& det_descr,
8284
const magnetic_field& field, host_detector* detector,

examples/run/sycl/full_chain_algorithm.sycl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ full_chain_algorithm::full_chain_algorithm(
5858
const spacepoint_grid_config& grid_config,
5959
const seedfilter_config& filter_config,
6060
const finding_algorithm::config_type& finding_config,
61+
[[maybe_unused]] const ambiguity_resolution_config& resolution_config,
6162
const fitting_algorithm::config_type& fitting_config,
6263
const silicon_detector_description::host& det_descr,
6364
const magnetic_field& field, host_detector* detector,

0 commit comments

Comments
 (0)