Skip to content

Commit 022693e

Browse files
committed
Add the greedy ambiguity solver to throughput examples
1 parent 0444ba3 commit 022693e

File tree

6 files changed

+81
-26
lines changed

6 files changed

+81
-26
lines changed

examples/run/common/throughput_mt.ipp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "traccc/options/track_finding.hpp"
2525
#include "traccc/options/track_fitting.hpp"
2626
#include "traccc/options/track_propagation.hpp"
27+
#include "traccc/options/track_resolution.hpp"
2728
#include "traccc/options/track_seeding.hpp"
2829

2930
// I/O include(s).
@@ -75,6 +76,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
7576
opts::track_seeding seeding_opts;
7677
opts::track_finding finding_opts;
7778
opts::track_propagation propagation_opts;
79+
opts::track_resolution resolution_opts;
7880
opts::track_fitting fitting_opts;
7981
opts::throughput throughput_opts;
8082
opts::threading threading_opts;
@@ -159,6 +161,9 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
159161
finding_opts);
160162
finding_cfg.propagation = propagation_config;
161163

164+
typename FULL_CHAIN_ALG::ambiguity_solving_algorithm::config_type
165+
resolution_cfg(resolution_opts);
166+
162167
typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg(
163168
fitting_opts);
164169
fitting_cfg.propagation = propagation_config;
@@ -180,6 +185,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
180185
{seeding_opts.seedfinder},
181186
seeding_opts.seedfilter,
182187
finding_cfg,
188+
resolution_cfg,
183189
fitting_cfg,
184190
det_descr,
185191
field,

examples/run/common/throughput_st.ipp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "traccc/options/track_finding.hpp"
2424
#include "traccc/options/track_fitting.hpp"
2525
#include "traccc/options/track_propagation.hpp"
26+
#include "traccc/options/track_resolution.hpp"
2627
#include "traccc/options/track_seeding.hpp"
2728

2829
// I/O include(s).
@@ -64,13 +65,14 @@ int throughput_st(std::string_view description, int argc, char* argv[],
6465
opts::track_seeding seeding_opts;
6566
opts::track_finding finding_opts;
6667
opts::track_propagation propagation_opts;
68+
opts::track_resolution resolution_opts;
6769
opts::track_fitting fitting_opts;
6870
opts::throughput throughput_opts;
6971
opts::program_options program_opts{
7072
description,
7173
{detector_opts, bfield_opts, input_opts, clusterization_opts,
72-
seeding_opts, finding_opts, propagation_opts, fitting_opts,
73-
throughput_opts},
74+
seeding_opts, finding_opts, propagation_opts, resolution_opts,
75+
fitting_opts, throughput_opts},
7476
argc,
7577
argv,
7678
logger->cloneWithSuffix("Options")};
@@ -132,6 +134,9 @@ int throughput_st(std::string_view description, int argc, char* argv[],
132134
finding_opts);
133135
finding_cfg.propagation = propagation_config;
134136

137+
typename FULL_CHAIN_ALG::ambiguity_solving_algorithm::config_type
138+
resolution_cfg(resolution_opts);
139+
135140
typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg(
136141
fitting_opts);
137142
fitting_cfg.propagation = propagation_config;
@@ -140,7 +145,8 @@ int throughput_st(std::string_view description, int argc, char* argv[],
140145
std::unique_ptr<FULL_CHAIN_ALG> alg = std::make_unique<FULL_CHAIN_ALG>(
141146
alg_host_mr, clustering_cfg, seeding_opts.seedfinder,
142147
spacepoint_grid_config{seeding_opts.seedfinder},
143-
seeding_opts.seedfilter, finding_cfg, fitting_cfg, det_descr, field,
148+
seeding_opts.seedfilter, finding_cfg, resolution_cfg, fitting_cfg,
149+
det_descr, field,
144150
(detector_opts.use_detray_detector ? &detector : nullptr),
145151
logger->clone("FullChainAlg"));
146152

examples/run/cpu/full_chain_algorithm.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ full_chain_algorithm::full_chain_algorithm(
1616
const spacepoint_grid_config& grid_config,
1717
const seedfilter_config& filter_config,
1818
const finding_algorithm::config_type& finding_config,
19+
const ambiguity_solving_algorithm::config_type& resolution_config,
1920
const fitting_algorithm::config_type& fitting_config,
2021
const silicon_detector_description::host& det_descr, const bfield& field,
2122
detector_type* detector, std::unique_ptr<const traccc::Logger> logger)
@@ -32,6 +33,8 @@ full_chain_algorithm::full_chain_algorithm(
3233
m_track_parameter_estimation(mr,
3334
logger->cloneWithSuffix("TrackParamEstAlg")),
3435
m_finding(finding_config, mr, logger->cloneWithSuffix("TrackFindingAlg")),
36+
m_ambiguity_solving(resolution_config, mr,
37+
logger->cloneWithSuffix("AmbiguityResolutionAlg")),
3538
m_fitting(fitting_config, mr, *m_copy,
3639
logger->cloneWithSuffix("TrackFittingAlg")),
3740
m_finder_config(finder_config),
@@ -76,10 +79,14 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
7679
const finding_algorithm::output_type track_candidates = m_finding(
7780
*m_detector, m_field, measurements_view, track_params_view);
7881

82+
const ambiguity_solving_algorithm::output_type
83+
resolved_track_candidates = m_ambiguity_solving(
84+
{vecmem::get_data(track_candidates), measurements_view});
85+
7986
// Run the track fitting, and return its results.
8087
return m_fitting(
8188
*m_detector, m_field,
82-
{vecmem::get_data(track_candidates), measurements_view});
89+
{vecmem::get_data(resolved_track_candidates), measurements_view});
8390
}
8491
// If not, just return an empty object.
8592
else {

examples/run/cpu/full_chain_algorithm.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Project include(s).
11+
#include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
1112
#include "traccc/clusterization/clusterization_algorithm.hpp"
1213
#include "traccc/edm/silicon_cell_collection.hpp"
1314
#include "traccc/edm/track_state.hpp"
@@ -56,6 +57,9 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
5657
/// Track finding algorithm type
5758
using finding_algorithm =
5859
traccc::host::combinatorial_kalman_filter_algorithm;
60+
/// Ambiguity solving algorithm type
61+
using ambiguity_solving_algorithm =
62+
traccc::host::greedy_ambiguity_resolution_algorithm;
5963
/// Track fitting algorithm type
6064
using fitting_algorithm = traccc::host::kalman_fitting_algorithm;
6165

@@ -68,16 +72,18 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
6872
/// @param dummy This is not used anywhere. Allows templating CPU/Device
6973
/// algorithm.
7074
///
71-
full_chain_algorithm(vecmem::memory_resource& mr,
72-
const clustering_algorithm::config_type& dummy,
73-
const seedfinder_config& finder_config,
74-
const spacepoint_grid_config& grid_config,
75-
const seedfilter_config& filter_config,
76-
const finding_algorithm::config_type& finding_config,
77-
const fitting_algorithm::config_type& fitting_config,
78-
const silicon_detector_description::host& det_descr,
79-
const bfield& field, detector_type* detector,
80-
std::unique_ptr<const traccc::Logger> logger);
75+
full_chain_algorithm(
76+
vecmem::memory_resource& mr,
77+
const clustering_algorithm::config_type& dummy,
78+
const seedfinder_config& finder_config,
79+
const spacepoint_grid_config& grid_config,
80+
const seedfilter_config& filter_config,
81+
const finding_algorithm::config_type& finding_config,
82+
const ambiguity_solving_algorithm::config_type& resolution_config,
83+
const fitting_algorithm::config_type& fitting_config,
84+
const silicon_detector_description::host& det_descr,
85+
const bfield& field, detector_type* detector,
86+
std::unique_ptr<const traccc::Logger> logger);
8187

8288
/// Reconstruct track parameters in the entire detector
8389
///
@@ -115,6 +121,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
115121

116122
/// Track finding algorithm
117123
finding_algorithm m_finding;
124+
/// Ambiguity solving algorithm
125+
ambiguity_solving_algorithm m_ambiguity_solving;
118126
/// Track fitting algorithm
119127
fitting_algorithm m_fitting;
120128

@@ -132,6 +140,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
132140

133141
/// Configuration for the track finding
134142
finding_algorithm::config_type m_finding_config;
143+
/// Configuration for the ambiguity solving
144+
ambiguity_solving_algorithm::config_type m_resolution_config;
135145
/// Configuration for the track fitting
136146
fitting_algorithm::config_type m_fitting_config;
137147

examples/run/cuda/full_chain_algorithm.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ full_chain_algorithm::full_chain_algorithm(
3737
const spacepoint_grid_config& grid_config,
3838
const seedfilter_config& filter_config,
3939
const finding_algorithm::config_type& finding_config,
40+
const ambiguity_solving_algorithm::config_type& resolution_config,
4041
const fitting_algorithm::config_type& fitting_config,
4142
const silicon_detector_description::host& det_descr, const bfield& field,
4243
host_detector_type* detector, std::unique_ptr<const traccc::Logger> logger)
@@ -72,6 +73,9 @@ full_chain_algorithm::full_chain_algorithm(
7273
m_finding(finding_config,
7374
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
7475
m_stream, logger->cloneWithSuffix("TrackFindingAlg")),
76+
m_ambiguity_solving(
77+
resolution_config, memory_resource{*m_cached_device_mr, &m_host_mr},
78+
m_copy, m_stream, logger->cloneWithSuffix("AmbiguityResolutionAlg")),
7579
m_fitting(fitting_config,
7680
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
7781
m_stream, logger->cloneWithSuffix("TrackFittingAlg")),
@@ -80,6 +84,7 @@ full_chain_algorithm::full_chain_algorithm(
8084
m_grid_config(grid_config),
8185
m_filter_config(filter_config),
8286
m_finding_config(finding_config),
87+
m_resolution_config(resolution_config),
8388
m_fitting_config(fitting_config) {
8489

8590
// Tell the user what device is being used.
@@ -134,6 +139,10 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
134139
m_finding(parent.m_finding_config,
135140
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
136141
m_stream, parent.logger().cloneWithSuffix("TrackFindingAlg")),
142+
m_ambiguity_solving(
143+
parent.m_resolution_config,
144+
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream,
145+
parent.logger().cloneWithSuffix("AmbiguityResolutionAlg")),
137146
m_fitting(parent.m_fitting_config,
138147
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
139148
m_stream, parent.logger().cloneWithSuffix("TrackFittingAlg")),
@@ -142,6 +151,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
142151
m_grid_config(parent.m_grid_config),
143152
m_filter_config(parent.m_filter_config),
144153
m_finding_config(parent.m_finding_config),
154+
m_resolution_config(parent.m_resolution_config),
145155
m_fitting_config(parent.m_fitting_config) {
146156

147157
// Copy the detector (description) to the device.
@@ -182,9 +192,15 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
182192
const finding_algorithm::output_type track_candidates = m_finding(
183193
m_device_detector_view, m_field, measurements, track_params);
184194

195+
// Run the ambiguity solver (asynchronously).
196+
const ambiguity_solving_algorithm::output_type
197+
resolved_track_candidates =
198+
m_ambiguity_solving({track_candidates, measurements});
199+
185200
// Run the track fitting (asynchronously).
186-
const fitting_algorithm::output_type track_states = m_fitting(
187-
m_device_detector_view, m_field, {track_candidates, measurements});
201+
const fitting_algorithm::output_type track_states =
202+
m_fitting(m_device_detector_view, m_field,
203+
{resolved_track_candidates, measurements});
188204

189205
// Copy a limited amount of result data back to the host.
190206
output_type result{&m_host_mr};

examples/run/cuda/full_chain_algorithm.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
// Project include(s).
1111
#include "traccc/clusterization/clustering_config.hpp"
12+
#include "traccc/cuda/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
1213
#include "traccc/cuda/clusterization/clusterization_algorithm.hpp"
1314
#include "traccc/cuda/clusterization/measurement_sorting_algorithm.hpp"
1415
#include "traccc/cuda/finding/combinatorial_kalman_filter_algorithm.hpp"
@@ -66,6 +67,9 @@ class full_chain_algorithm
6667
/// Track finding algorithm type
6768
using finding_algorithm =
6869
traccc::cuda::combinatorial_kalman_filter_algorithm;
70+
/// Ambiguity solving algorithm type
71+
using ambiguity_solving_algorithm =
72+
traccc::cuda::greedy_ambiguity_resolution_algorithm;
6973
/// Track fitting algorithm type
7074
using fitting_algorithm = traccc::cuda::kalman_fitting_algorithm;
7175

@@ -76,16 +80,18 @@ class full_chain_algorithm
7680
/// @param mr The memory resource to use for the intermediate and result
7781
/// objects
7882
///
79-
full_chain_algorithm(vecmem::memory_resource& host_mr,
80-
const clustering_config& clustering_config,
81-
const seedfinder_config& finder_config,
82-
const spacepoint_grid_config& grid_config,
83-
const seedfilter_config& filter_config,
84-
const finding_algorithm::config_type& finding_config,
85-
const fitting_algorithm::config_type& fitting_config,
86-
const silicon_detector_description::host& det_descr,
87-
const bfield& field, host_detector_type* detector,
88-
std::unique_ptr<const traccc::Logger> logger);
83+
full_chain_algorithm(
84+
vecmem::memory_resource& host_mr,
85+
const clustering_config& clustering_config,
86+
const seedfinder_config& finder_config,
87+
const spacepoint_grid_config& grid_config,
88+
const seedfilter_config& filter_config,
89+
const finding_algorithm::config_type& finding_config,
90+
const ambiguity_solving_algorithm::config_type& resolution_config,
91+
const fitting_algorithm::config_type& fitting_config,
92+
const silicon_detector_description::host& det_descr,
93+
const bfield& field, host_detector_type* detector,
94+
std::unique_ptr<const traccc::Logger> logger);
8995

9096
/// Copy constructor
9197
///
@@ -153,6 +159,8 @@ class full_chain_algorithm
153159

154160
/// Track finding algorithm
155161
finding_algorithm m_finding;
162+
/// Ambiguity solving algorithm
163+
ambiguity_solving_algorithm m_ambiguity_solving;
156164
/// Track fitting algorithm
157165
fitting_algorithm m_fitting;
158166

@@ -172,6 +180,8 @@ class full_chain_algorithm
172180

173181
/// Configuration for the track finding
174182
finding_algorithm::config_type m_finding_config;
183+
/// Configuration for the ambiguity solving
184+
ambiguity_solving_algorithm::config_type m_resolution_config;
175185
/// Configuration for the track fitting
176186
fitting_algorithm::config_type m_fitting_config;
177187

0 commit comments

Comments
 (0)