Skip to content

Commit adcb664

Browse files
committed
Add the ambiguity resolution to the full chain algorithm
1 parent f59431e commit adcb664

File tree

11 files changed

+86
-27
lines changed

11 files changed

+86
-27
lines changed

examples/run/alpaka/full_chain_algorithm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ full_chain_algorithm::full_chain_algorithm(
2121
const spacepoint_grid_config& grid_config,
2222
const seedfilter_config& filter_config,
2323
const finding_algorithm::config_type& finding_config,
24+
[[maybe_unused]] const ambiguity_resolution_config& resolution_config,
2425
const fitting_algorithm::config_type& fitting_config,
2526
const silicon_detector_description::host& det_descr, const bfield& field,
2627
host_detector_type* detector, std::unique_ptr<const traccc::Logger> logger)

examples/run/alpaka/full_chain_algorithm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "traccc/alpaka/seeding/track_params_estimation.hpp"
1818
#include "traccc/alpaka/utils/get_device_info.hpp"
1919
#include "traccc/alpaka/utils/vecmem_objects.hpp"
20+
#include "traccc/ambiguity_resolution/ambiguity_resolution_config.hpp"
2021
#include "traccc/clusterization/clustering_config.hpp"
2122
#include "traccc/edm/silicon_cell_collection.hpp"
2223
#include "traccc/edm/track_state.hpp"
@@ -81,6 +82,7 @@ class full_chain_algorithm
8182
const spacepoint_grid_config& grid_config,
8283
const seedfilter_config& filter_config,
8384
const finding_algorithm::config_type& finding_config,
85+
const ambiguity_resolution_config& resolution_config,
8486
const fitting_algorithm::config_type& fitting_config,
8587
const silicon_detector_description::host& det_descr,
8688
const bfield& field, host_detector_type* detector,

examples/run/common/throughput_mt.ipp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "traccc/options/track_finding.hpp"
2525
#include "traccc/options/track_fitting.hpp"
2626
#include "traccc/options/track_propagation.hpp"
27+
#include "traccc/options/track_resolution.hpp"
2728
#include "traccc/options/track_seeding.hpp"
2829

2930
// I/O include(s).
@@ -75,6 +76,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
7576
opts::track_seeding seeding_opts;
7677
opts::track_finding finding_opts;
7778
opts::track_propagation propagation_opts;
79+
opts::track_resolution resolution_opts;
7880
opts::track_fitting fitting_opts;
7981
opts::throughput throughput_opts;
8082
opts::threading threading_opts;
@@ -159,6 +161,8 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
159161
finding_opts);
160162
finding_cfg.propagation = propagation_config;
161163

164+
ambiguity_resolution_config resolution_cfg(resolution_opts);
165+
162166
typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg(
163167
fitting_opts);
164168
fitting_cfg.propagation = propagation_config;
@@ -180,6 +184,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
180184
{seeding_opts.seedfinder},
181185
seeding_opts.seedfilter,
182186
finding_cfg,
187+
resolution_cfg,
183188
fitting_cfg,
184189
det_descr,
185190
field,

examples/run/common/throughput_st.ipp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "traccc/options/track_finding.hpp"
2424
#include "traccc/options/track_fitting.hpp"
2525
#include "traccc/options/track_propagation.hpp"
26+
#include "traccc/options/track_resolution.hpp"
2627
#include "traccc/options/track_seeding.hpp"
2728

2829
// I/O include(s).
@@ -64,13 +65,14 @@ int throughput_st(std::string_view description, int argc, char* argv[],
6465
opts::track_seeding seeding_opts;
6566
opts::track_finding finding_opts;
6667
opts::track_propagation propagation_opts;
68+
opts::track_resolution resolution_opts;
6769
opts::track_fitting fitting_opts;
6870
opts::throughput throughput_opts;
6971
opts::program_options program_opts{
7072
description,
7173
{detector_opts, bfield_opts, input_opts, clusterization_opts,
72-
seeding_opts, finding_opts, propagation_opts, fitting_opts,
73-
throughput_opts},
74+
seeding_opts, finding_opts, propagation_opts, resolution_opts,
75+
fitting_opts, throughput_opts},
7476
argc,
7577
argv,
7678
logger->cloneWithSuffix("Options")};
@@ -132,6 +134,8 @@ int throughput_st(std::string_view description, int argc, char* argv[],
132134
finding_opts);
133135
finding_cfg.propagation = propagation_config;
134136

137+
ambiguity_resolution_config resolution_cfg(resolution_opts);
138+
135139
typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg(
136140
fitting_opts);
137141
fitting_cfg.propagation = propagation_config;
@@ -140,7 +144,8 @@ int throughput_st(std::string_view description, int argc, char* argv[],
140144
std::unique_ptr<FULL_CHAIN_ALG> alg = std::make_unique<FULL_CHAIN_ALG>(
141145
alg_host_mr, clustering_cfg, seeding_opts.seedfinder,
142146
spacepoint_grid_config{seeding_opts.seedfinder},
143-
seeding_opts.seedfilter, finding_cfg, fitting_cfg, det_descr, field,
147+
seeding_opts.seedfilter, finding_cfg, resolution_cfg, fitting_cfg,
148+
det_descr, field,
144149
(detector_opts.use_detray_detector ? &detector : nullptr),
145150
logger->clone("FullChainAlg"));
146151

examples/run/cpu/full_chain_algorithm.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ full_chain_algorithm::full_chain_algorithm(
1616
const spacepoint_grid_config& grid_config,
1717
const seedfilter_config& filter_config,
1818
const finding_algorithm::config_type& finding_config,
19+
const ambiguity_solving_algorithm::config_type& resolution_config,
1920
const fitting_algorithm::config_type& fitting_config,
2021
const silicon_detector_description::host& det_descr, const bfield& field,
2122
detector_type* detector, std::unique_ptr<const traccc::Logger> logger)
@@ -32,6 +33,8 @@ full_chain_algorithm::full_chain_algorithm(
3233
m_track_parameter_estimation(mr,
3334
logger->cloneWithSuffix("TrackParamEstAlg")),
3435
m_finding(finding_config, mr, logger->cloneWithSuffix("TrackFindingAlg")),
36+
m_ambiguity_solving(resolution_config, mr,
37+
logger->cloneWithSuffix("AmbiguityResolutionAlg")),
3538
m_fitting(fitting_config, mr, *m_copy,
3639
logger->cloneWithSuffix("TrackFittingAlg")),
3740
m_finder_config(finder_config),
@@ -76,10 +79,14 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
7679
const finding_algorithm::output_type track_candidates = m_finding(
7780
*m_detector, m_field, measurements_view, track_params_view);
7881

82+
const ambiguity_solving_algorithm::output_type
83+
resolved_track_candidates = m_ambiguity_solving(
84+
{vecmem::get_data(track_candidates), measurements_view});
85+
7986
// Run the track fitting, and return its results.
8087
return m_fitting(
8188
*m_detector, m_field,
82-
{vecmem::get_data(track_candidates), measurements_view});
89+
{vecmem::get_data(resolved_track_candidates), measurements_view});
8390
}
8491
// If not, just return an empty object.
8592
else {

examples/run/cpu/full_chain_algorithm.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Project include(s).
11+
#include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
1112
#include "traccc/clusterization/clusterization_algorithm.hpp"
1213
#include "traccc/edm/silicon_cell_collection.hpp"
1314
#include "traccc/edm/track_state.hpp"
@@ -56,6 +57,9 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
5657
/// Track finding algorithm type
5758
using finding_algorithm =
5859
traccc::host::combinatorial_kalman_filter_algorithm;
60+
/// Ambiguity solving algorithm type
61+
using ambiguity_solving_algorithm =
62+
traccc::host::greedy_ambiguity_resolution_algorithm;
5963
/// Track fitting algorithm type
6064
using fitting_algorithm = traccc::host::kalman_fitting_algorithm;
6165

@@ -68,16 +72,18 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
6872
/// @param dummy This is not used anywhere. Allows templating CPU/Device
6973
/// algorithm.
7074
///
71-
full_chain_algorithm(vecmem::memory_resource& mr,
72-
const clustering_algorithm::config_type& dummy,
73-
const seedfinder_config& finder_config,
74-
const spacepoint_grid_config& grid_config,
75-
const seedfilter_config& filter_config,
76-
const finding_algorithm::config_type& finding_config,
77-
const fitting_algorithm::config_type& fitting_config,
78-
const silicon_detector_description::host& det_descr,
79-
const bfield& field, detector_type* detector,
80-
std::unique_ptr<const traccc::Logger> logger);
75+
full_chain_algorithm(
76+
vecmem::memory_resource& mr,
77+
const clustering_algorithm::config_type& dummy,
78+
const seedfinder_config& finder_config,
79+
const spacepoint_grid_config& grid_config,
80+
const seedfilter_config& filter_config,
81+
const finding_algorithm::config_type& finding_config,
82+
const ambiguity_solving_algorithm::config_type& resolution_config,
83+
const fitting_algorithm::config_type& fitting_config,
84+
const silicon_detector_description::host& det_descr,
85+
const bfield& field, detector_type* detector,
86+
std::unique_ptr<const traccc::Logger> logger);
8187

8288
/// Reconstruct track parameters in the entire detector
8389
///
@@ -115,6 +121,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
115121

116122
/// Track finding algorithm
117123
finding_algorithm m_finding;
124+
/// Ambiguity solving algorithm
125+
ambiguity_solving_algorithm m_ambiguity_solving;
118126
/// Track fitting algorithm
119127
fitting_algorithm m_fitting;
120128

@@ -132,6 +140,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
132140

133141
/// Configuration for the track finding
134142
finding_algorithm::config_type m_finding_config;
143+
/// Configuration for the ambiguity solving
144+
ambiguity_solving_algorithm::config_type m_resolution_config;
135145
/// Configuration for the track fitting
136146
fitting_algorithm::config_type m_fitting_config;
137147

examples/run/cuda/full_chain_algorithm.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ full_chain_algorithm::full_chain_algorithm(
3737
const spacepoint_grid_config& grid_config,
3838
const seedfilter_config& filter_config,
3939
const finding_algorithm::config_type& finding_config,
40+
const ambiguity_solving_algorithm::config_type& resolution_config,
4041
const fitting_algorithm::config_type& fitting_config,
4142
const silicon_detector_description::host& det_descr, const bfield& field,
4243
host_detector_type* detector, std::unique_ptr<const traccc::Logger> logger)
@@ -72,6 +73,9 @@ full_chain_algorithm::full_chain_algorithm(
7273
m_finding(finding_config,
7374
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
7475
m_stream, logger->cloneWithSuffix("TrackFindingAlg")),
76+
m_ambiguity_solving(
77+
resolution_config, memory_resource{*m_cached_device_mr, &m_host_mr},
78+
m_copy, m_stream, logger->cloneWithSuffix("AmbiguityResolutionAlg")),
7579
m_fitting(fitting_config,
7680
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
7781
m_stream, logger->cloneWithSuffix("TrackFittingAlg")),
@@ -80,6 +84,7 @@ full_chain_algorithm::full_chain_algorithm(
8084
m_grid_config(grid_config),
8185
m_filter_config(filter_config),
8286
m_finding_config(finding_config),
87+
m_resolution_config(resolution_config),
8388
m_fitting_config(fitting_config) {
8489

8590
// Tell the user what device is being used.
@@ -134,6 +139,10 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
134139
m_finding(parent.m_finding_config,
135140
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
136141
m_stream, parent.logger().cloneWithSuffix("TrackFindingAlg")),
142+
m_ambiguity_solving(
143+
parent.m_resolution_config,
144+
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream,
145+
parent.logger().cloneWithSuffix("AmbiguityResolutionAlg")),
137146
m_fitting(parent.m_fitting_config,
138147
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
139148
m_stream, parent.logger().cloneWithSuffix("TrackFittingAlg")),
@@ -142,6 +151,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
142151
m_grid_config(parent.m_grid_config),
143152
m_filter_config(parent.m_filter_config),
144153
m_finding_config(parent.m_finding_config),
154+
m_resolution_config(parent.m_resolution_config),
145155
m_fitting_config(parent.m_fitting_config) {
146156

147157
// Copy the detector (description) to the device.
@@ -182,9 +192,15 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
182192
const finding_algorithm::output_type track_candidates = m_finding(
183193
m_device_detector_view, m_field, measurements, track_params);
184194

195+
// Run the ambiguity solver (asynchronously).
196+
const ambiguity_solving_algorithm::output_type
197+
resolved_track_candidates =
198+
m_ambiguity_solving({track_candidates, measurements});
199+
185200
// Run the track fitting (asynchronously).
186-
const fitting_algorithm::output_type track_states = m_fitting(
187-
m_device_detector_view, m_field, {track_candidates, measurements});
201+
const fitting_algorithm::output_type track_states =
202+
m_fitting(m_device_detector_view, m_field,
203+
{resolved_track_candidates, measurements});
188204

189205
// Copy a limited amount of result data back to the host.
190206
output_type result{&m_host_mr};

examples/run/cuda/full_chain_algorithm.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
// Project include(s).
1111
#include "traccc/clusterization/clustering_config.hpp"
12+
#include "traccc/cuda/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
1213
#include "traccc/cuda/clusterization/clusterization_algorithm.hpp"
1314
#include "traccc/cuda/clusterization/measurement_sorting_algorithm.hpp"
1415
#include "traccc/cuda/finding/combinatorial_kalman_filter_algorithm.hpp"
@@ -66,6 +67,9 @@ class full_chain_algorithm
6667
/// Track finding algorithm type
6768
using finding_algorithm =
6869
traccc::cuda::combinatorial_kalman_filter_algorithm;
70+
/// Ambiguity solving algorithm type
71+
using ambiguity_solving_algorithm =
72+
traccc::cuda::greedy_ambiguity_resolution_algorithm;
6973
/// Track fitting algorithm type
7074
using fitting_algorithm = traccc::cuda::kalman_fitting_algorithm;
7175

@@ -76,16 +80,18 @@ class full_chain_algorithm
7680
/// @param mr The memory resource to use for the intermediate and result
7781
/// objects
7882
///
79-
full_chain_algorithm(vecmem::memory_resource& host_mr,
80-
const clustering_config& clustering_config,
81-
const seedfinder_config& finder_config,
82-
const spacepoint_grid_config& grid_config,
83-
const seedfilter_config& filter_config,
84-
const finding_algorithm::config_type& finding_config,
85-
const fitting_algorithm::config_type& fitting_config,
86-
const silicon_detector_description::host& det_descr,
87-
const bfield& field, host_detector_type* detector,
88-
std::unique_ptr<const traccc::Logger> logger);
83+
full_chain_algorithm(
84+
vecmem::memory_resource& host_mr,
85+
const clustering_config& clustering_config,
86+
const seedfinder_config& finder_config,
87+
const spacepoint_grid_config& grid_config,
88+
const seedfilter_config& filter_config,
89+
const finding_algorithm::config_type& finding_config,
90+
const ambiguity_solving_algorithm::config_type& resolution_config,
91+
const fitting_algorithm::config_type& fitting_config,
92+
const silicon_detector_description::host& det_descr,
93+
const bfield& field, host_detector_type* detector,
94+
std::unique_ptr<const traccc::Logger> logger);
8995

9096
/// Copy constructor
9197
///
@@ -153,6 +159,8 @@ class full_chain_algorithm
153159

154160
/// Track finding algorithm
155161
finding_algorithm m_finding;
162+
/// Ambiguity solving algorithm
163+
ambiguity_solving_algorithm m_ambiguity_solving;
156164
/// Track fitting algorithm
157165
fitting_algorithm m_fitting;
158166

@@ -172,6 +180,8 @@ class full_chain_algorithm
172180

173181
/// Configuration for the track finding
174182
finding_algorithm::config_type m_finding_config;
183+
/// Configuration for the ambiguity solving
184+
ambiguity_solving_algorithm::config_type m_resolution_config;
175185
/// Configuration for the track fitting
176186
fitting_algorithm::config_type m_fitting_config;
177187

examples/run/sycl/full_chain_algorithm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Project include(s).
11+
#include "traccc/ambiguity_resolution/ambiguity_resolution_config.hpp"
1112
#include "traccc/edm/silicon_cell_collection.hpp"
1213
#include "traccc/geometry/detector.hpp"
1314
#include "traccc/geometry/silicon_detector_description.hpp"
@@ -78,6 +79,7 @@ class full_chain_algorithm
7879
const spacepoint_grid_config& grid_config,
7980
const seedfilter_config& filter_config,
8081
const finding_algorithm::config_type& finding_config,
82+
const ambiguity_resolution_config& resolution_config,
8183
const fitting_algorithm::config_type& fitting_config,
8284
const silicon_detector_description::host& det_descr,
8385
const bfield& field, host_detector_type* detector,

examples/run/sycl/full_chain_algorithm.sycl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ full_chain_algorithm::full_chain_algorithm(
5858
const spacepoint_grid_config& grid_config,
5959
const seedfilter_config& filter_config,
6060
const finding_algorithm::config_type& finding_config,
61+
[[maybe_unused]] const ambiguity_resolution_config& resolution_config,
6162
const fitting_algorithm::config_type& fitting_config,
6263
const silicon_detector_description::host& det_descr, const bfield& field,
6364
host_detector_type* detector, std::unique_ptr<const traccc::Logger> logger)

0 commit comments

Comments
 (0)