Skip to content

Commit 4204d47

Browse files
committed
Adapted the examples to the SoA track fit EDM.
1 parent 9f7f77f commit 4204d47

19 files changed

+192
-177
lines changed

examples/run/alpaka/full_chain_algorithm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
182182
m_device_detector_view, m_field, {track_candidates, measurements});
183183

184184
// Copy a limited amount of result data back to the host.
185-
output_type result{&m_host_mr};
186-
m_vecmem_objects.async_copy()(track_states.headers, result)->wait();
185+
output_type result{m_host_mr};
186+
m_vecmem_objects.async_copy()(track_states.tracks, result)->wait();
187187
return result;
188188

189189
}
@@ -196,7 +196,7 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
196196
m_vecmem_objects.async_copy()(measurements, measurements_host)->wait();
197197

198198
// Return an empty object.
199-
return {};
199+
return output_type{m_host_mr};
200200
}
201201
}
202202

examples/run/alpaka/full_chain_algorithm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include "traccc/bfield/magnetic_field.hpp"
2121
#include "traccc/clusterization/clustering_config.hpp"
2222
#include "traccc/edm/silicon_cell_collection.hpp"
23+
#include "traccc/edm/track_fit_collection.hpp"
2324
#include "traccc/edm/track_parameters.hpp"
24-
#include "traccc/edm/track_state.hpp"
2525
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
2626
#include "traccc/geometry/detector.hpp"
2727
#include "traccc/geometry/silicon_detector_description.hpp"
@@ -44,7 +44,7 @@ namespace traccc::alpaka {
4444
/// At least as much as is implemented in the project at any given moment.
4545
///
4646
class full_chain_algorithm
47-
: public algorithm<vecmem::vector<fitting_result<default_algebra>>(
47+
: public algorithm<edm::track_fit_collection<default_algebra>::host(
4848
const edm::silicon_cell_collection::host&)>,
4949
public messaging {
5050

examples/run/alpaka/seeding_example_alpaka.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,6 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
137137
vecmem::copy& copy = vo.copy();
138138
vecmem::copy& async_copy = vo.async_copy();
139139

140-
traccc::device::container_d2h_copy_alg<traccc::track_state_container_types>
141-
track_state_d2h{mr, copy, logger().clone("TrackStateD2HCopyAlg")};
142-
143140
// Seeding algorithms
144141
const traccc::seedfinder_config seedfinder_config(seeding_opts);
145142
const traccc::seedfilter_config seedfilter_config(seeding_opts);
@@ -198,7 +195,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
198195
traccc::host::track_params_estimation::output_type params;
199196
traccc::edm::track_candidate_collection<traccc::default_algebra>::host
200197
track_candidates{host_mr};
201-
traccc::track_state_container_types::host track_states;
198+
traccc::edm::track_fit_container<traccc::default_algebra>::host
199+
track_states{host_mr};
202200

203201
traccc::edm::seed_collection::buffer seeds_alpaka_buffer;
204202
traccc::bound_track_parameters_collection_types::buffer
@@ -207,8 +205,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
207205
traccc::edm::track_candidate_collection<traccc::default_algebra>::buffer
208206
track_candidates_alpaka_buffer;
209207

210-
traccc::track_state_container_types::buffer track_states_alpaka_buffer{
211-
{{}, *(mr.host)}, {{}, *(mr.host), mr.host}};
208+
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
209+
track_states_alpaka_buffer;
212210

213211
{ // Start measuring wall time
214212
traccc::performance::timer wall_t("Wall time", elapsedTimes);
@@ -354,8 +352,14 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
354352
copy(track_candidates_alpaka_buffer, track_candidates_alpaka)->wait();
355353

356354
// Copy track states from device to host
357-
traccc::track_state_container_types::host track_states_alpaka =
358-
track_state_d2h(track_states_alpaka_buffer);
355+
traccc::edm::track_fit_container<traccc::default_algebra>::host
356+
track_states_alpaka{host_mr};
357+
async_copy(track_states_alpaka_buffer.tracks,
358+
track_states_alpaka.tracks)
359+
->wait();
360+
async_copy(track_states_alpaka_buffer.states,
361+
track_states_alpaka.states)
362+
->wait();
359363

360364
if (accelerator_opts.compare_with_cpu) {
361365
// Show which event we are currently presenting the results for.
@@ -402,8 +406,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
402406
n_seeds += seeds.size();
403407
n_found_tracks_alpaka += track_candidates_alpaka.size();
404408
n_found_tracks += track_candidates.size();
405-
n_fitted_tracks_alpaka += track_states_alpaka.size();
406-
n_fitted_tracks += track_states.size();
409+
n_fitted_tracks_alpaka += track_states_alpaka.tracks.size();
410+
n_fitted_tracks += track_states.tracks.size();
407411

408412
/*------------
409413
Writer

examples/run/alpaka/seq_example_alpaka.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,6 @@ int seq_run(const traccc::opts::detector& detector_opts,
184184
device_fitting_algorithm fitting_alg_alpaka(
185185
fitting_cfg, mr, copy, queue, logger().clone("AlpakaFittingAlg"));
186186

187-
traccc::device::container_d2h_copy_alg<traccc::track_state_container_types>
188-
copy_track_states(mr, copy, logger().clone("TrackStateD2HCopyAlg"));
189-
190187
// performance writer
191188
traccc::seeding_performance_writer sd_performance_writer(
192189
traccc::seeding_performance_writer::config{},
@@ -206,7 +203,7 @@ int seq_run(const traccc::opts::detector& detector_opts,
206203
traccc::host::seeding_algorithm::output_type seeds{host_mr};
207204
traccc::host::track_params_estimation::output_type params{&host_mr};
208205
host_finding_algorithm::output_type track_candidates{host_mr};
209-
host_fitting_algorithm::output_type track_states;
206+
host_fitting_algorithm::output_type track_states{host_mr};
210207

211208
// Instantiate alpaka containers/collections
212209
traccc::measurement_collection_types::buffer measurements_alpaka_buffer(
@@ -217,7 +214,8 @@ int seq_run(const traccc::opts::detector& detector_opts,
217214
params_alpaka_buffer(0, *mr.host);
218215
traccc::edm::track_candidate_collection<traccc::default_algebra>::buffer
219216
track_candidates_buffer;
220-
traccc::track_state_container_types::buffer track_states_buffer;
217+
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
218+
track_states_buffer;
221219

222220
{
223221
traccc::performance::timer wall_t("Wall time", elapsedTimes);
@@ -371,13 +369,16 @@ int seq_run(const traccc::opts::detector& detector_opts,
371369
&host_mr};
372370
traccc::edm::track_candidate_collection<traccc::default_algebra>::host
373371
track_candidates_alpaka{host_mr};
372+
traccc::edm::track_fit_container<traccc::default_algebra>::host
373+
track_states_alpaka{host_mr};
374374

375375
copy(measurements_alpaka_buffer, measurements_per_event_alpaka)->wait();
376376
copy(spacepoints_alpaka_buffer, spacepoints_per_event_alpaka)->wait();
377377
copy(seeds_alpaka_buffer, seeds_alpaka)->wait();
378378
copy(params_alpaka_buffer, params_alpaka)->wait();
379379
copy(track_candidates_buffer, track_candidates_alpaka)->wait();
380-
auto track_states_alpaka = copy_track_states(track_states_buffer);
380+
copy(track_states_buffer.tracks, track_states_alpaka.tracks)->wait();
381+
copy(track_states_buffer.states, track_states_alpaka.states)->wait();
381382
queue.synchronize();
382383

383384
if (accelerator_opts.compare_with_cpu) {
@@ -429,12 +430,20 @@ int seq_run(const traccc::opts::detector& detector_opts,
429430
vecmem::get_data(track_candidates_alpaka));
430431

431432
// Compare tracks fitted on the host and on the device.
432-
traccc::collection_comparator<
433-
traccc::track_state_container_types::host::header_type>
434-
compare_track_states{"track states"};
435-
compare_track_states(
436-
vecmem::get_data(track_states.get_headers()),
437-
vecmem::get_data(track_states_alpaka.get_headers()));
433+
traccc::soa_comparator<
434+
traccc::edm::track_fit_collection<traccc::default_algebra>>
435+
compare_track_fits{
436+
"track fits",
437+
traccc::details::comparator_factory<
438+
traccc::edm::track_fit_collection<
439+
traccc::default_algebra>::const_device::
440+
const_proxy_type>{
441+
vecmem::get_data(measurements_per_event),
442+
vecmem::get_data(measurements_per_event_alpaka),
443+
vecmem::get_data(track_states.states),
444+
vecmem::get_data(track_states_alpaka.states)}};
445+
compare_track_fits(vecmem::get_data(track_states.tracks),
446+
vecmem::get_data(track_states_alpaka.tracks));
438447
}
439448
/// Statistics
440449
n_measurements += measurements_per_event.size();
@@ -445,8 +454,8 @@ int seq_run(const traccc::opts::detector& detector_opts,
445454
n_seeds_alpaka += seeds_alpaka.size();
446455
n_found_tracks += track_candidates.size();
447456
n_found_tracks_alpaka += track_candidates_alpaka.size();
448-
n_fitted_tracks += track_states.size();
449-
n_fitted_tracks_alpaka += track_states_alpaka.size();
457+
n_fitted_tracks += track_states.tracks.size();
458+
n_fitted_tracks_alpaka += track_states_alpaka.tracks.size();
450459

451460
if (performance_opts.run) {
452461

examples/run/cpu/full_chain_algorithm.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ full_chain_algorithm::full_chain_algorithm(
2121
const magnetic_field& field, detector_type* detector,
2222
std::unique_ptr<const traccc::Logger> logger)
2323
: messaging(logger->clone()),
24+
m_mr(mr),
2425
m_copy{std::make_unique<vecmem::copy>()},
2526
m_field_vec{0.f, 0.f, finder_config.bFieldInZ},
2627
m_field(field),
@@ -79,14 +80,15 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
7980

8081
// Run the track fitting, and return its results.
8182
return m_fitting(
82-
*m_detector, m_field,
83-
{vecmem::get_data(track_candidates), measurements_view});
83+
*m_detector, m_field,
84+
{vecmem::get_data(track_candidates), measurements_view})
85+
.tracks;
8486
}
8587
// If not, just return an empty object.
8688
else {
8789

8890
// Return an empty object.
89-
return {};
91+
return output_type{m_mr.get()};
9092
}
9193
}
9294

examples/run/cpu/full_chain_algorithm.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include "traccc/bfield/magnetic_field.hpp"
1212
#include "traccc/clusterization/clusterization_algorithm.hpp"
1313
#include "traccc/edm/silicon_cell_collection.hpp"
14+
#include "traccc/edm/track_fit_collection.hpp"
1415
#include "traccc/edm/track_parameters.hpp"
15-
#include "traccc/edm/track_state.hpp"
1616
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
1717
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
1818
#include "traccc/geometry/detector.hpp"
@@ -38,9 +38,10 @@ namespace traccc {
3838
///
3939
/// At least as much as is implemented in the project at any given moment.
4040
///
41-
class full_chain_algorithm : public algorithm<track_state_container_types::host(
42-
const edm::silicon_cell_collection::host&)>,
43-
public messaging {
41+
class full_chain_algorithm
42+
: public algorithm<edm::track_fit_collection<default_algebra>::host(
43+
const edm::silicon_cell_collection::host&)>,
44+
public messaging {
4445

4546
public:
4647
/// @name Type declaration(s)
@@ -97,6 +98,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
9798
const edm::silicon_cell_collection::host& cells) const;
9899

99100
private:
101+
/// Memory resource
102+
std::reference_wrapper<vecmem::memory_resource> m_mr;
100103
/// Vecmem copy object
101104
std::unique_ptr<vecmem::copy> m_copy;
102105
/// Constant B field for the (seed) track parameter estimation

examples/run/cpu/misaligned_truth_fitting_example.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,21 +166,17 @@ int main(int argc, char* argv[]) {
166166
{vecmem::get_data(truth_track_candidates.tracks),
167167
vecmem::get_data(truth_track_candidates.measurements)});
168168

169-
print_fitted_tracks_statistics(track_states);
169+
// print_fitted_tracks_statistics(track_states);
170170

171-
const decltype(track_states)::size_type n_fitted_tracks =
172-
track_states.size();
171+
const std::size_t n_fitted_tracks = track_states.tracks.size();
173172

174173
if (performance_opts.run) {
175174

176175
for (unsigned int i = 0; i < n_fitted_tracks; i++) {
177-
const auto& trk_states_per_track = track_states.at(i).items;
178-
179-
const auto& fit_res = track_states[i].header;
180-
181-
fit_performance_writer.write(trk_states_per_track, fit_res,
182-
host_det, evt_data,
183-
fit_cfg0.propagation.context);
176+
fit_performance_writer.write(
177+
track_states.tracks.at(i), track_states.states,
178+
truth_track_candidates.measurements, host_det, evt_data,
179+
fit_cfg0.propagation.context);
184180
}
185181
}
186182
} else {
@@ -195,21 +191,17 @@ int main(int argc, char* argv[]) {
195191
{vecmem::get_data(truth_track_candidates.tracks),
196192
vecmem::get_data(truth_track_candidates.measurements)});
197193

198-
print_fitted_tracks_statistics(track_states);
194+
// print_fitted_tracks_statistics(track_states);
199195

200-
const decltype(track_states)::size_type n_fitted_tracks =
201-
track_states.size();
196+
const std::size_t n_fitted_tracks = track_states.tracks.size();
202197

203198
if (performance_opts.run) {
204199

205200
for (unsigned int i = 0; i < n_fitted_tracks; i++) {
206-
const auto& trk_states_per_track = track_states.at(i).items;
207-
208-
const auto& fit_res = track_states[i].header;
209-
210-
fit_performance_writer.write(trk_states_per_track, fit_res,
211-
host_det, evt_data,
212-
fit_cfg1.propagation.context);
201+
fit_performance_writer.write(
202+
track_states.tracks.at(i), track_states.states,
203+
truth_track_candidates.measurements, host_det, evt_data,
204+
fit_cfg1.propagation.context);
213205
}
214206
}
215207
}

examples/run/cpu/seeding_example.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
190190
track_candidates{host_mr};
191191
traccc::edm::track_candidate_collection<traccc::default_algebra>::host
192192
track_candidates_ar{host_mr};
193-
traccc::track_state_container_types::host track_states;
193+
traccc::edm::track_fit_container<traccc::default_algebra>::host
194+
track_states{host_mr};
194195

195196
/*------------------------
196197
Track Finding with CKF
@@ -217,7 +218,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
217218
track_states = host_fitting(detector, field,
218219
{vecmem::get_data(track_candidates_ar),
219220
vecmem::get_data(measurements_per_event)});
220-
n_fitted_tracks += track_states.size();
221+
n_fitted_tracks += track_states.tracks.size();
221222

222223
/*------------
223224
Statistics
@@ -249,13 +250,10 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
249250
vecmem::get_data(track_candidates_ar),
250251
vecmem::get_data(measurements_per_event), evt_data);
251252

252-
for (unsigned int i = 0; i < track_states.size(); i++) {
253-
const auto& trk_states_per_track = track_states.at(i).items;
254-
255-
const auto& fit_res = track_states[i].header;
256-
257-
fit_performance_writer.write(trk_states_per_track, fit_res,
258-
detector, evt_data);
253+
for (unsigned int i = 0; i < track_states.tracks.size(); i++) {
254+
fit_performance_writer.write(
255+
track_states.tracks.at(i), track_states.states,
256+
measurements_per_event, detector, evt_data);
259257
}
260258
}
261259
}

examples/run/cpu/seq_example.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ int seq_run(const traccc::opts::input_data& input_opts,
199199
finding_algorithm::output_type track_candidates{host_mr};
200200
traccc::host::greedy_ambiguity_resolution_algorithm::output_type
201201
resolved_track_candidates{host_mr};
202-
fitting_algorithm::output_type track_states{&host_mr};
202+
fitting_algorithm::output_type track_states{host_mr};
203203

204204
{ // Start measuring wall time.
205205
traccc::performance::timer timer_wall{"Wall time", elapsedTimes};
@@ -320,7 +320,7 @@ int seq_run(const traccc::opts::input_data& input_opts,
320320
n_seeds += seeds.size();
321321
n_found_tracks += track_candidates.size();
322322
n_ambiguity_free_tracks += resolved_track_candidates.size();
323-
n_fitted_tracks += track_states.size();
323+
n_fitted_tracks += track_states.tracks.size();
324324

325325
} // Stop measuring Wall time.
326326

@@ -347,13 +347,10 @@ int seq_run(const traccc::opts::input_data& input_opts,
347347
vecmem::get_data(resolved_track_candidates),
348348
vecmem::get_data(measurements_per_event), evt_data);
349349

350-
for (unsigned int i = 0; i < track_states.size(); i++) {
351-
const auto& trk_states_per_track = track_states.at(i).items;
352-
353-
const auto& fit_res = track_states[i].header;
354-
355-
fit_performance_writer.write(trk_states_per_track, fit_res,
356-
detector, evt_data);
350+
for (unsigned int i = 0; i < track_states.tracks.size(); i++) {
351+
fit_performance_writer.write(
352+
track_states.tracks.at(i), track_states.states,
353+
measurements_per_event, detector, evt_data);
357354
}
358355
}
359356
}

examples/run/cpu/truth_finding_example.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,22 +163,19 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
163163
{vecmem::get_data(track_candidates),
164164
vecmem::get_data(measurements_per_event)});
165165

166-
print_fitted_tracks_statistics(track_states);
166+
// print_fitted_tracks_statistics(track_states);
167167

168-
const std::size_t n_fitted_tracks = track_states.size();
168+
const std::size_t n_fitted_tracks = track_states.tracks.size();
169169

170170
if (performance_opts.run) {
171171
find_performance_writer.write(
172172
vecmem::get_data(track_candidates),
173173
vecmem::get_data(measurements_per_event), evt_data);
174174

175175
for (std::size_t i = 0; i < n_fitted_tracks; i++) {
176-
const auto& trk_states_per_track = track_states.at(i).items;
177-
178-
const auto& fit_res = track_states[i].header;
179-
180-
fit_performance_writer.write(trk_states_per_track, fit_res,
181-
detector, evt_data);
176+
fit_performance_writer.write(
177+
track_states.tracks.at(i), track_states.states,
178+
measurements_per_event, detector, evt_data);
182179
}
183180
}
184181
}

0 commit comments

Comments
 (0)