Skip to content

Commit 9b98ec3

Browse files
committed
Adapted the examples to the SoA track fit EDM.
1 parent c728f7d commit 9b98ec3

19 files changed

+192
-177
lines changed

examples/run/alpaka/full_chain_algorithm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
182182
m_device_detector_view, m_field, {track_candidates, measurements});
183183

184184
// Copy a limited amount of result data back to the host.
185-
output_type result{&m_host_mr};
186-
m_vecmem_objects.async_copy()(track_states.headers, result)->wait();
185+
output_type result{m_host_mr};
186+
m_vecmem_objects.async_copy()(track_states.tracks, result)->wait();
187187
return result;
188188

189189
}
@@ -196,7 +196,7 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
196196
m_vecmem_objects.async_copy()(measurements, measurements_host)->wait();
197197

198198
// Return an empty object.
199-
return {};
199+
return output_type{m_host_mr};
200200
}
201201
}
202202

examples/run/alpaka/full_chain_algorithm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include "traccc/bfield/magnetic_field.hpp"
2121
#include "traccc/clusterization/clustering_config.hpp"
2222
#include "traccc/edm/silicon_cell_collection.hpp"
23+
#include "traccc/edm/track_fit_collection.hpp"
2324
#include "traccc/edm/track_parameters.hpp"
24-
#include "traccc/edm/track_state.hpp"
2525
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
2626
#include "traccc/geometry/detector.hpp"
2727
#include "traccc/geometry/silicon_detector_description.hpp"
@@ -44,7 +44,7 @@ namespace traccc::alpaka {
4444
/// At least as much as is implemented in the project at any given moment.
4545
///
4646
class full_chain_algorithm
47-
: public algorithm<vecmem::vector<fitting_result<default_algebra>>(
47+
: public algorithm<edm::track_fit_collection<default_algebra>::host(
4848
const edm::silicon_cell_collection::host&)>,
4949
public messaging {
5050

examples/run/alpaka/seeding_example_alpaka.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,6 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
131131
vecmem::copy& copy = vo.copy();
132132
vecmem::copy& async_copy = vo.async_copy();
133133

134-
traccc::device::container_d2h_copy_alg<traccc::track_state_container_types>
135-
track_state_d2h{mr, copy, logger().clone("TrackStateD2HCopyAlg")};
136-
137134
// Seeding algorithms
138135
const traccc::seedfinder_config seedfinder_config(seeding_opts);
139136
const traccc::seedfilter_config seedfilter_config(seeding_opts);
@@ -192,7 +189,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
192189
traccc::host::track_params_estimation::output_type params;
193190
traccc::edm::track_candidate_collection<traccc::default_algebra>::host
194191
track_candidates{host_mr};
195-
traccc::track_state_container_types::host track_states;
192+
traccc::edm::track_fit_container<traccc::default_algebra>::host
193+
track_states{host_mr};
196194

197195
traccc::edm::seed_collection::buffer seeds_alpaka_buffer;
198196
traccc::bound_track_parameters_collection_types::buffer
@@ -201,8 +199,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
201199
traccc::edm::track_candidate_collection<traccc::default_algebra>::buffer
202200
track_candidates_alpaka_buffer;
203201

204-
traccc::track_state_container_types::buffer track_states_alpaka_buffer{
205-
{{}, *(mr.host)}, {{}, *(mr.host), mr.host}};
202+
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
203+
track_states_alpaka_buffer;
206204

207205
{ // Start measuring wall time
208206
traccc::performance::timer wall_t("Wall time", elapsedTimes);
@@ -348,8 +346,14 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
348346
copy(track_candidates_alpaka_buffer, track_candidates_alpaka)->wait();
349347

350348
// Copy track states from device to host
351-
traccc::track_state_container_types::host track_states_alpaka =
352-
track_state_d2h(track_states_alpaka_buffer);
349+
traccc::edm::track_fit_container<traccc::default_algebra>::host
350+
track_states_alpaka{host_mr};
351+
async_copy(track_states_alpaka_buffer.tracks,
352+
track_states_alpaka.tracks)
353+
->wait();
354+
async_copy(track_states_alpaka_buffer.states,
355+
track_states_alpaka.states)
356+
->wait();
353357

354358
if (accelerator_opts.compare_with_cpu) {
355359
// Show which event we are currently presenting the results for.
@@ -396,8 +400,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
396400
n_seeds += seeds.size();
397401
n_found_tracks_alpaka += track_candidates_alpaka.size();
398402
n_found_tracks += track_candidates.size();
399-
n_fitted_tracks_alpaka += track_states_alpaka.size();
400-
n_fitted_tracks += track_states.size();
403+
n_fitted_tracks_alpaka += track_states_alpaka.tracks.size();
404+
n_fitted_tracks += track_states.tracks.size();
401405

402406
/*------------
403407
Writer

examples/run/alpaka/seq_example_alpaka.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,6 @@ int seq_run(const traccc::opts::detector& detector_opts,
184184
device_fitting_algorithm fitting_alg_alpaka(
185185
fitting_cfg, mr, copy, queue, logger().clone("AlpakaFittingAlg"));
186186

187-
traccc::device::container_d2h_copy_alg<traccc::track_state_container_types>
188-
copy_track_states(mr, copy, logger().clone("TrackStateD2HCopyAlg"));
189-
190187
// performance writer
191188
traccc::seeding_performance_writer sd_performance_writer(
192189
traccc::seeding_performance_writer::config{},
@@ -206,7 +203,7 @@ int seq_run(const traccc::opts::detector& detector_opts,
206203
traccc::host::seeding_algorithm::output_type seeds{host_mr};
207204
traccc::host::track_params_estimation::output_type params{&host_mr};
208205
host_finding_algorithm::output_type track_candidates{host_mr};
209-
host_fitting_algorithm::output_type track_states;
206+
host_fitting_algorithm::output_type track_states{host_mr};
210207

211208
// Instantiate alpaka containers/collections
212209
traccc::measurement_collection_types::buffer measurements_alpaka_buffer(
@@ -217,7 +214,8 @@ int seq_run(const traccc::opts::detector& detector_opts,
217214
params_alpaka_buffer(0, *mr.host);
218215
traccc::edm::track_candidate_collection<traccc::default_algebra>::buffer
219216
track_candidates_buffer;
220-
traccc::track_state_container_types::buffer track_states_buffer;
217+
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
218+
track_states_buffer;
221219

222220
{
223221
traccc::performance::timer wall_t("Wall time", elapsedTimes);
@@ -371,13 +369,16 @@ int seq_run(const traccc::opts::detector& detector_opts,
371369
&host_mr};
372370
traccc::edm::track_candidate_collection<traccc::default_algebra>::host
373371
track_candidates_alpaka{host_mr};
372+
traccc::edm::track_fit_container<traccc::default_algebra>::host
373+
track_states_alpaka{host_mr};
374374

375375
copy(measurements_alpaka_buffer, measurements_per_event_alpaka)->wait();
376376
copy(spacepoints_alpaka_buffer, spacepoints_per_event_alpaka)->wait();
377377
copy(seeds_alpaka_buffer, seeds_alpaka)->wait();
378378
copy(params_alpaka_buffer, params_alpaka)->wait();
379379
copy(track_candidates_buffer, track_candidates_alpaka)->wait();
380-
auto track_states_alpaka = copy_track_states(track_states_buffer);
380+
copy(track_states_buffer.tracks, track_states_alpaka.tracks)->wait();
381+
copy(track_states_buffer.states, track_states_alpaka.states)->wait();
381382
queue.synchronize();
382383

383384
if (accelerator_opts.compare_with_cpu) {
@@ -429,12 +430,20 @@ int seq_run(const traccc::opts::detector& detector_opts,
429430
vecmem::get_data(track_candidates_alpaka));
430431

431432
// Compare tracks fitted on the host and on the device.
432-
traccc::collection_comparator<
433-
traccc::track_state_container_types::host::header_type>
434-
compare_track_states{"track states"};
435-
compare_track_states(
436-
vecmem::get_data(track_states.get_headers()),
437-
vecmem::get_data(track_states_alpaka.get_headers()));
433+
traccc::soa_comparator<
434+
traccc::edm::track_fit_collection<traccc::default_algebra>>
435+
compare_track_fits{
436+
"track fits",
437+
traccc::details::comparator_factory<
438+
traccc::edm::track_fit_collection<
439+
traccc::default_algebra>::const_device::
440+
const_proxy_type>{
441+
vecmem::get_data(measurements_per_event),
442+
vecmem::get_data(measurements_per_event_alpaka),
443+
vecmem::get_data(track_states.states),
444+
vecmem::get_data(track_states_alpaka.states)}};
445+
compare_track_fits(vecmem::get_data(track_states.tracks),
446+
vecmem::get_data(track_states_alpaka.tracks));
438447
}
439448
/// Statistics
440449
n_measurements += measurements_per_event.size();
@@ -445,8 +454,8 @@ int seq_run(const traccc::opts::detector& detector_opts,
445454
n_seeds_alpaka += seeds_alpaka.size();
446455
n_found_tracks += track_candidates.size();
447456
n_found_tracks_alpaka += track_candidates_alpaka.size();
448-
n_fitted_tracks += track_states.size();
449-
n_fitted_tracks_alpaka += track_states_alpaka.size();
457+
n_fitted_tracks += track_states.tracks.size();
458+
n_fitted_tracks_alpaka += track_states_alpaka.tracks.size();
450459

451460
if (performance_opts.run) {
452461

examples/run/cpu/full_chain_algorithm.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ full_chain_algorithm::full_chain_algorithm(
2121
const magnetic_field& field, detector_type* detector,
2222
std::unique_ptr<const traccc::Logger> logger)
2323
: messaging(logger->clone()),
24+
m_mr(mr),
2425
m_copy{std::make_unique<vecmem::copy>()},
2526
m_field_vec{0.f, 0.f, finder_config.bFieldInZ},
2627
m_field(field),
@@ -79,14 +80,15 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
7980

8081
// Run the track fitting, and return its results.
8182
return m_fitting(
82-
*m_detector, m_field,
83-
{vecmem::get_data(track_candidates), measurements_view});
83+
*m_detector, m_field,
84+
{vecmem::get_data(track_candidates), measurements_view})
85+
.tracks;
8486
}
8587
// If not, just return an empty object.
8688
else {
8789

8890
// Return an empty object.
89-
return {};
91+
return output_type{m_mr.get()};
9092
}
9193
}
9294

examples/run/cpu/full_chain_algorithm.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include "traccc/bfield/magnetic_field.hpp"
1212
#include "traccc/clusterization/clusterization_algorithm.hpp"
1313
#include "traccc/edm/silicon_cell_collection.hpp"
14+
#include "traccc/edm/track_fit_collection.hpp"
1415
#include "traccc/edm/track_parameters.hpp"
15-
#include "traccc/edm/track_state.hpp"
1616
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
1717
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
1818
#include "traccc/geometry/detector.hpp"
@@ -38,9 +38,10 @@ namespace traccc {
3838
///
3939
/// At least as much as is implemented in the project at any given moment.
4040
///
41-
class full_chain_algorithm : public algorithm<track_state_container_types::host(
42-
const edm::silicon_cell_collection::host&)>,
43-
public messaging {
41+
class full_chain_algorithm
42+
: public algorithm<edm::track_fit_collection<default_algebra>::host(
43+
const edm::silicon_cell_collection::host&)>,
44+
public messaging {
4445

4546
public:
4647
/// @name Type declaration(s)
@@ -97,6 +98,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
9798
const edm::silicon_cell_collection::host& cells) const;
9899

99100
private:
101+
/// Memory resource
102+
std::reference_wrapper<vecmem::memory_resource> m_mr;
100103
/// Vecmem copy object
101104
std::unique_ptr<vecmem::copy> m_copy;
102105
/// Constant B field for the (seed) track parameter estimation

examples/run/cpu/misaligned_truth_fitting_example.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,21 +166,17 @@ int main(int argc, char* argv[]) {
166166
{vecmem::get_data(truth_track_candidates.tracks),
167167
vecmem::get_data(truth_track_candidates.measurements)});
168168

169-
print_fitted_tracks_statistics(track_states);
169+
// print_fitted_tracks_statistics(track_states);
170170

171-
const decltype(track_states)::size_type n_fitted_tracks =
172-
track_states.size();
171+
const std::size_t n_fitted_tracks = track_states.tracks.size();
173172

174173
if (performance_opts.run) {
175174

176175
for (unsigned int i = 0; i < n_fitted_tracks; i++) {
177-
const auto& trk_states_per_track = track_states.at(i).items;
178-
179-
const auto& fit_res = track_states[i].header;
180-
181-
fit_performance_writer.write(trk_states_per_track, fit_res,
182-
host_det, evt_data,
183-
fit_cfg0.propagation.context);
176+
fit_performance_writer.write(
177+
track_states.tracks.at(i), track_states.states,
178+
truth_track_candidates.measurements, host_det, evt_data,
179+
fit_cfg0.propagation.context);
184180
}
185181
}
186182
} else {
@@ -195,21 +191,17 @@ int main(int argc, char* argv[]) {
195191
{vecmem::get_data(truth_track_candidates.tracks),
196192
vecmem::get_data(truth_track_candidates.measurements)});
197193

198-
print_fitted_tracks_statistics(track_states);
194+
// print_fitted_tracks_statistics(track_states);
199195

200-
const decltype(track_states)::size_type n_fitted_tracks =
201-
track_states.size();
196+
const std::size_t n_fitted_tracks = track_states.tracks.size();
202197

203198
if (performance_opts.run) {
204199

205200
for (unsigned int i = 0; i < n_fitted_tracks; i++) {
206-
const auto& trk_states_per_track = track_states.at(i).items;
207-
208-
const auto& fit_res = track_states[i].header;
209-
210-
fit_performance_writer.write(trk_states_per_track, fit_res,
211-
host_det, evt_data,
212-
fit_cfg1.propagation.context);
201+
fit_performance_writer.write(
202+
track_states.tracks.at(i), track_states.states,
203+
truth_track_candidates.measurements, host_det, evt_data,
204+
fit_cfg1.propagation.context);
213205
}
214206
}
215207
}

examples/run/cpu/seeding_example.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
184184
track_candidates{host_mr};
185185
traccc::edm::track_candidate_collection<traccc::default_algebra>::host
186186
track_candidates_ar{host_mr};
187-
traccc::track_state_container_types::host track_states;
187+
traccc::edm::track_fit_container<traccc::default_algebra>::host
188+
track_states{host_mr};
188189

189190
/*------------------------
190191
Track Finding with CKF
@@ -211,7 +212,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
211212
track_states = host_fitting(detector, field,
212213
{vecmem::get_data(track_candidates_ar),
213214
vecmem::get_data(measurements_per_event)});
214-
n_fitted_tracks += track_states.size();
215+
n_fitted_tracks += track_states.tracks.size();
215216

216217
/*------------
217218
Statistics
@@ -243,13 +244,10 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
243244
vecmem::get_data(track_candidates_ar),
244245
vecmem::get_data(measurements_per_event), evt_data);
245246

246-
for (unsigned int i = 0; i < track_states.size(); i++) {
247-
const auto& trk_states_per_track = track_states.at(i).items;
248-
249-
const auto& fit_res = track_states[i].header;
250-
251-
fit_performance_writer.write(trk_states_per_track, fit_res,
252-
detector, evt_data);
247+
for (unsigned int i = 0; i < track_states.tracks.size(); i++) {
248+
fit_performance_writer.write(
249+
track_states.tracks.at(i), track_states.states,
250+
measurements_per_event, detector, evt_data);
253251
}
254252
}
255253
}

examples/run/cpu/seq_example.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ int seq_run(const traccc::opts::input_data& input_opts,
193193
finding_algorithm::output_type track_candidates{host_mr};
194194
traccc::host::greedy_ambiguity_resolution_algorithm::output_type
195195
resolved_track_candidates{host_mr};
196-
fitting_algorithm::output_type track_states{&host_mr};
196+
fitting_algorithm::output_type track_states{host_mr};
197197

198198
{ // Start measuring wall time.
199199
traccc::performance::timer timer_wall{"Wall time", elapsedTimes};
@@ -314,7 +314,7 @@ int seq_run(const traccc::opts::input_data& input_opts,
314314
n_seeds += seeds.size();
315315
n_found_tracks += track_candidates.size();
316316
n_ambiguity_free_tracks += resolved_track_candidates.size();
317-
n_fitted_tracks += track_states.size();
317+
n_fitted_tracks += track_states.tracks.size();
318318

319319
} // Stop measuring Wall time.
320320

@@ -341,13 +341,10 @@ int seq_run(const traccc::opts::input_data& input_opts,
341341
vecmem::get_data(resolved_track_candidates),
342342
vecmem::get_data(measurements_per_event), evt_data);
343343

344-
for (unsigned int i = 0; i < track_states.size(); i++) {
345-
const auto& trk_states_per_track = track_states.at(i).items;
346-
347-
const auto& fit_res = track_states[i].header;
348-
349-
fit_performance_writer.write(trk_states_per_track, fit_res,
350-
detector, evt_data);
344+
for (unsigned int i = 0; i < track_states.tracks.size(); i++) {
345+
fit_performance_writer.write(
346+
track_states.tracks.at(i), track_states.states,
347+
measurements_per_event, detector, evt_data);
351348
}
352349
}
353350
}

examples/run/cpu/truth_finding_example.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,19 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
160160
{vecmem::get_data(track_candidates),
161161
vecmem::get_data(measurements_per_event)});
162162

163-
print_fitted_tracks_statistics(track_states);
163+
// print_fitted_tracks_statistics(track_states);
164164

165-
const std::size_t n_fitted_tracks = track_states.size();
165+
const std::size_t n_fitted_tracks = track_states.tracks.size();
166166

167167
if (performance_opts.run) {
168168
find_performance_writer.write(
169169
vecmem::get_data(track_candidates),
170170
vecmem::get_data(measurements_per_event), evt_data);
171171

172172
for (std::size_t i = 0; i < n_fitted_tracks; i++) {
173-
const auto& trk_states_per_track = track_states.at(i).items;
174-
175-
const auto& fit_res = track_states[i].header;
176-
177-
fit_performance_writer.write(trk_states_per_track, fit_res,
178-
detector, evt_data);
173+
fit_performance_writer.write(
174+
track_states.tracks.at(i), track_states.states,
175+
measurements_per_event, detector, evt_data);
179176
}
180177
}
181178
}

0 commit comments

Comments
 (0)