Skip to content

Commit a4a6900

Browse files
committed
Introduced track_candidates_container::device and track_fit_container::device.
To make it a bit easier to deal with the slightly complex EDMs of track finding and fitting in the code.
1 parent 4204d47 commit a4a6900

File tree

11 files changed

+115
-87
lines changed

11 files changed

+115
-87
lines changed

core/include/traccc/edm/track_candidate_container.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Local include(s).
11+
#include "traccc/definitions/qualifiers.hpp"
1112
#include "traccc/edm/measurement.hpp"
1213
#include "traccc/edm/track_candidate_collection.hpp"
1314

@@ -71,6 +72,28 @@ struct track_candidate_container {
7172
measurement_collection_types::const_view measurements;
7273
};
7374

75+
struct device {
76+
/// Constructor from a view
77+
TRACCC_HOST_DEVICE
78+
explicit device(const view& v)
79+
: tracks{v.tracks}, measurements{v.measurements} {}
80+
/// The track candidates
81+
track_candidate_collection<ALGEBRA>::device tracks;
82+
/// Measurements referenced by the tracks
83+
measurement_collection_types::const_device measurements;
84+
};
85+
86+
struct const_device {
87+
/// Constructor from a view
88+
TRACCC_HOST_DEVICE
89+
explicit const_device(const const_view& v)
90+
: tracks{v.tracks}, measurements{v.measurements} {}
91+
/// The track candidates
92+
track_candidate_collection<ALGEBRA>::const_device tracks;
93+
/// Measurements referenced by the tracks
94+
measurement_collection_types::const_device measurements;
95+
};
96+
7497
}; // struct track_candidate_container
7598

7699
} // namespace traccc::edm

core/include/traccc/edm/track_fit_container.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Local include(s).
11+
#include "traccc/definitions/qualifiers.hpp"
1112
#include "traccc/edm/measurement.hpp"
1213
#include "traccc/edm/track_fit_collection.hpp"
1314
#include "traccc/edm/track_state_collection.hpp"
@@ -70,6 +71,36 @@ struct track_fit_container {
7071
track_state_collection<ALGEBRA>::const_data states;
7172
};
7273

74+
struct device {
75+
/// Constructor from a view
76+
TRACCC_HOST_DEVICE
77+
explicit device(const view& v)
78+
: tracks{v.tracks},
79+
states{v.states},
80+
measurements{v.measurements} {}
81+
/// The fitted tracks
82+
track_fit_collection<ALGEBRA>::device tracks;
83+
/// The track states used for the fit
84+
track_state_collection<ALGEBRA>::device states;
85+
/// The measurements used for the fit
86+
measurement_collection_types::const_device measurements;
87+
};
88+
89+
struct const_device {
90+
/// Constructor from a view
91+
TRACCC_HOST_DEVICE
92+
explicit const_device(const const_view& v)
93+
: tracks{v.tracks},
94+
states{v.states},
95+
measurements{v.measurements} {}
96+
/// The fitted tracks
97+
track_fit_collection<ALGEBRA>::const_device tracks;
98+
/// The track states used for the fit
99+
track_state_collection<ALGEBRA>::const_device states;
100+
/// The measurements used for the fit
101+
measurement_collection_types::const_device measurements;
102+
};
103+
73104
}; // struct track_fit_container
74105

75106
} // namespace traccc::edm

device/common/include/traccc/fitting/device/fit_backward.hpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,23 @@ TRACCC_HOST_DEVICE inline void fit_backward(
2222
vecmem::device_vector<const unsigned int> param_ids(payload.param_ids_view);
2323
vecmem::device_vector<unsigned int> param_liveness(
2424
payload.param_liveness_view);
25-
typename edm::track_fit_collection<
25+
typename edm::track_fit_container<
2626
typename fitter_t::detector_type::algebra_type>::device
27-
tracks(payload.tracks_view.tracks);
28-
typename edm::track_state_collection<
29-
typename fitter_t::detector_type::algebra_type>::device
30-
track_states(payload.tracks_view.states);
31-
measurement_collection_types::const_device measurements{
32-
payload.tracks_view.measurements};
27+
tracks(payload.tracks_view);
3328

34-
if (globalIndex >= tracks.size()) {
29+
if (globalIndex >= tracks.tracks.size()) {
3530
return;
3631
}
3732

3833
const unsigned int param_id = param_ids.at(globalIndex);
39-
auto track = tracks.at(param_id);
34+
auto track = tracks.tracks.at(param_id);
4035

4136
// Run fitting
4237
fitter_t fitter(det, payload.field_data, cfg);
4338

4439
if (param_liveness.at(param_id) > 0u) {
4540
typename fitter_t::state fitter_state(
46-
track, track_states, measurements,
41+
track, tracks.states, tracks.measurements,
4742
*(payload.barcodes_view.ptr() + param_id));
4843

4944
kalman_fitter_status fit_status = fitter.smooth(fitter_state);

device/common/include/traccc/fitting/device/fit_forward.hpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,26 @@ TRACCC_HOST_DEVICE inline void fit_forward(
2222
vecmem::device_vector<const unsigned int> param_ids(payload.param_ids_view);
2323
vecmem::device_vector<unsigned int> param_liveness(
2424
payload.param_liveness_view);
25-
typename edm::track_fit_collection<
25+
typename edm::track_fit_container<
2626
typename fitter_t::detector_type::algebra_type>::device
27-
tracks(payload.tracks_view.tracks);
28-
typename edm::track_state_collection<
29-
typename fitter_t::detector_type::algebra_type>::device
30-
track_states(payload.tracks_view.states);
31-
measurement_collection_types::const_device measurements{
32-
payload.tracks_view.measurements};
27+
tracks(payload.tracks_view);
3328

34-
if (globalIndex >= tracks.size()) {
29+
if (globalIndex >= tracks.tracks.size()) {
3530
return;
3631
}
3732

3833
const unsigned int param_id = param_ids.at(globalIndex);
3934

4035
fitter_t fitter(det, payload.field_data, cfg);
4136

42-
auto track = tracks.at(param_id);
37+
auto track = tracks.tracks.at(param_id);
4338
auto params = track.params();
4439

4540
// TODO: Merge into filter?
4641
inflate_covariance(params, fitter.config().covariance_inflation_factor);
4742

4843
typename fitter_t::state fitter_state(
49-
track, track_states, measurements,
44+
track, tracks.states, tracks.measurements,
5045
*(payload.barcodes_view.ptr() + param_id));
5146

5247
kalman_fitter_status fit_status = fitter.filter(params, fitter_state);

examples/run/cpu/seeding_example.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,14 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
243243
vecmem::get_data(measurements_per_event), evt_data);
244244

245245
find_performance_writer.write(
246-
vecmem::get_data(track_candidates),
247-
vecmem::get_data(measurements_per_event), evt_data);
246+
{vecmem::get_data(track_candidates),
247+
vecmem::get_data(measurements_per_event)},
248+
evt_data);
248249

249250
ar_performance_writer.write(
250-
vecmem::get_data(track_candidates_ar),
251-
vecmem::get_data(measurements_per_event), evt_data);
251+
{vecmem::get_data(track_candidates_ar),
252+
vecmem::get_data(measurements_per_event)},
253+
evt_data);
252254

253255
for (unsigned int i = 0; i < track_states.tracks.size(); i++) {
254256
fit_performance_writer.write(

examples/run/cpu/seq_example.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,11 +341,13 @@ int seq_run(const traccc::opts::input_data& input_opts,
341341
vecmem::get_data(spacepoints_per_event),
342342
vecmem::get_data(measurements_per_event), evt_data);
343343
find_performance_writer.write(
344-
vecmem::get_data(track_candidates),
345-
vecmem::get_data(measurements_per_event), evt_data);
344+
{vecmem::get_data(track_candidates),
345+
vecmem::get_data(measurements_per_event)},
346+
evt_data);
346347
ar_performance_writer.write(
347-
vecmem::get_data(resolved_track_candidates),
348-
vecmem::get_data(measurements_per_event), evt_data);
348+
{vecmem::get_data(resolved_track_candidates),
349+
vecmem::get_data(measurements_per_event)},
350+
evt_data);
349351

350352
for (unsigned int i = 0; i < track_states.tracks.size(); i++) {
351353
fit_performance_writer.write(

examples/run/cpu/truth_finding_example.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,9 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
169169

170170
if (performance_opts.run) {
171171
find_performance_writer.write(
172-
vecmem::get_data(track_candidates),
173-
vecmem::get_data(measurements_per_event), evt_data);
172+
{vecmem::get_data(track_candidates),
173+
vecmem::get_data(measurements_per_event)},
174+
evt_data);
174175

175176
for (std::size_t i = 0; i < n_fitted_tracks; i++) {
176177
fit_performance_writer.write(

examples/run/cuda/seeding_example_cuda.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,9 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
432432
vecmem::get_data(measurements_per_event), evt_data);
433433

434434
find_performance_writer.write(
435-
vecmem::get_data(track_candidates_cuda),
436-
vecmem::get_data(measurements_per_event), evt_data);
435+
{vecmem::get_data(track_candidates_cuda),
436+
vecmem::get_data(measurements_per_event)},
437+
evt_data);
437438

438439
for (unsigned int i = 0; i < track_states_cuda.tracks.size(); i++) {
439440
fit_performance_writer.write(

examples/run/cuda/truth_finding_example_cuda.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,9 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
299299

300300
if (performance_opts.run) {
301301
find_performance_writer.write(
302-
vecmem::get_data(track_candidates_cuda),
303-
vecmem::get_data(measurements_per_event), evt_data);
302+
{vecmem::get_data(track_candidates_cuda),
303+
vecmem::get_data(measurements_per_event)},
304+
evt_data);
304305

305306
for (unsigned int i = 0; i < track_states_cuda.tracks.size(); i++) {
306307
fit_performance_writer.write(

performance/include/traccc/efficiency/finding_performance_writer.hpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
#include "traccc/utils/truth_matching_config.hpp"
1616

1717
// Project include(s).
18-
#include "traccc/edm/measurement.hpp"
19-
#include "traccc/edm/track_candidate_collection.hpp"
20-
#include "traccc/edm/track_fit_collection.hpp"
21-
#include "traccc/edm/track_state_collection.hpp"
18+
#include "traccc/edm/track_candidate_container.hpp"
19+
#include "traccc/edm/track_fit_container.hpp"
2220
#include "traccc/utils/event_data.hpp"
2321

2422
// System include(s).
@@ -71,19 +69,13 @@ class finding_performance_writer : public messaging {
7169
/// Destructor
7270
~finding_performance_writer();
7371

74-
void write(
75-
const edm::track_candidate_collection<default_algebra>::const_view&
76-
track_candidates_view,
77-
const measurement_collection_types::const_view& measurements_view,
78-
const event_data& evt_data);
79-
80-
void write(
81-
const edm::track_fit_collection<default_algebra>::const_view&
82-
track_fit_view,
83-
const edm::track_state_collection<default_algebra>::const_view&
84-
track_states_view,
85-
const measurement_collection_types::const_view& measurements_view,
86-
const event_data& evt_data);
72+
void write(const edm::track_candidate_container<
73+
default_algebra>::const_view& track_candidates_view,
74+
const event_data& evt_data);
75+
76+
void write(const edm::track_fit_container<default_algebra>::const_view&
77+
track_fit_view,
78+
const event_data& evt_data);
8779

8880
void finalize();
8981

0 commit comments

Comments
 (0)