Skip to content

Commit 2ee7c61

Browse files
committed
Adapted traccc::device_common and traccc::cuda to the SoA track fit EDM.
1 parent c4fdf68 commit 2ee7c61

File tree

9 files changed

+111
-56
lines changed

9 files changed

+111
-56
lines changed

device/common/include/traccc/finding/device/impl/find_tracks.ipp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,10 @@ TRACCC_HOST_DEVICE inline void find_tracks(
178178

179179
barrier.blockBarrier();
180180

181-
std::optional<std::tuple<track_state<typename detector_t::algebra_type>,
182-
unsigned int, unsigned int>>
181+
std::optional<std::tuple<
182+
typename edm::track_state_collection<
183+
typename detector_t::algebra_type>::device::object_type,
184+
unsigned int, unsigned int>>
183185
result = std::nullopt;
184186

185187
/*
@@ -216,17 +218,21 @@ TRACCC_HOST_DEVICE inline void find_tracks(
216218
}
217219

218220
if (use_measurement) {
219-
const auto& meas = measurements.at(meas_idx);
220221

221-
track_state<typename detector_t::algebra_type> trk_state(meas);
222+
typename edm::track_state_collection<
223+
typename detector_t::algebra_type>::device::object_type
224+
trk_state(0u, 0.f, 0.f, 0.f, {}, {}, meas_idx);
225+
trk_state.filtered_params().set_surface_link(
226+
measurements.at(meas_idx).surface_link);
227+
222228
const detray::tracking_surface sf{det, in_par.surface_link()};
223229

224230
const bool is_line = sf.template visit_mask<is_line_visitor>();
225231

226232
// Run the Kalman update
227233
const kalman_fitter_status res =
228234
gain_matrix_updater<typename detector_t::algebra_type>{}(
229-
trk_state, in_par, is_line);
235+
trk_state, measurements, in_par, is_line);
230236

231237
/*
232238
* The $\chi^2$ value from the Kalman update should be less than
@@ -246,8 +252,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
246252
}
247253

248254
if (use_measurement) {
249-
result.emplace(std::move(trk_state), meas_idx,
250-
owner_local_thread_id);
255+
result.emplace(trk_state, meas_idx, owner_local_thread_id);
251256
}
252257
}
253258
}
@@ -441,10 +446,12 @@ TRACCC_HOST_DEVICE inline void find_tracks(
441446
.chi2_sum = prev_chi2_sum + chi2,
442447
.ndf_sum =
443448
prev_ndf_sum +
444-
std::get<0>(*result).get_measurement().meas_dim};
449+
measurements
450+
.at(std::get<0>(*result).measurement_index())
451+
.meas_dim};
445452

446453
tmp_params.at(p_offset + l_pos) =
447-
std::get<0>(*result).filtered();
454+
std::get<0>(*result).filtered_params();
448455

449456
/*
450457
* Reset the temporary state storage, as this is no longer

device/common/include/traccc/fitting/device/fit.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
// Project include(s).
1414
#include "traccc/definitions/qualifiers.hpp"
15-
#include "traccc/edm/track_state.hpp"
15+
#include "traccc/edm/track_fit_container.hpp"
1616

1717
// VecMem include(s).
1818
#include <vecmem/containers/data/jagged_vector_view.hpp>
@@ -44,9 +44,10 @@ struct fit_payload {
4444
vecmem::data::vector_view<unsigned int> param_liveness_view;
4545

4646
/**
47-
* @brief View object to the output track states
47+
* @brief View object to the output tracks
4848
*/
49-
track_state_container_types::view track_states_view;
49+
typename edm::track_fit_container<
50+
typename fitter_t::detector_type::algebra_type>::view tracks_view;
5051

5152
/**
5253
* @brief View object to the output barcode sequence

device/common/include/traccc/fitting/device/fit_backward.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,28 @@ TRACCC_HOST_DEVICE inline void fit_backward(
2222
vecmem::device_vector<const unsigned int> param_ids(payload.param_ids_view);
2323
vecmem::device_vector<unsigned int> param_liveness(
2424
payload.param_liveness_view);
25-
track_state_container_types::device track_states(payload.track_states_view);
26-
27-
if (globalIndex >= track_states.size()) {
25+
typename edm::track_fit_collection<
26+
typename fitter_t::detector_type::algebra_type>::device
27+
tracks(payload.tracks_view.tracks);
28+
typename edm::track_state_collection<
29+
typename fitter_t::detector_type::algebra_type>::device
30+
track_states(payload.tracks_view.states);
31+
measurement_collection_types::const_device measurements{
32+
payload.tracks_view.measurements};
33+
34+
if (globalIndex >= tracks.size()) {
2835
return;
2936
}
3037

3138
const unsigned int param_id = param_ids.at(globalIndex);
39+
auto track = tracks.at(param_id);
3240

3341
// Run fitting
3442
fitter_t fitter(det, payload.field_data, cfg);
3543

3644
if (param_liveness.at(param_id) > 0u) {
3745
typename fitter_t::state fitter_state(
38-
track_states.at(param_id).items,
46+
track, track_states, measurements,
3947
*(payload.barcodes_view.ptr() + param_id));
4048

4149
kalman_fitter_status fit_status = fitter.smooth(fitter_state);
@@ -48,7 +56,7 @@ TRACCC_HOST_DEVICE inline void fit_backward(
4856

4957
assert(fit_status == kalman_fitter_status::SUCCESS);
5058

51-
track_states.at(param_id).header = fitter_state.m_fit_res;
59+
track = fitter_state.m_fit_res;
5260
} else {
5361
param_liveness.at(param_id) = 0u;
5462
}

device/common/include/traccc/fitting/device/fit_forward.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,31 @@ TRACCC_HOST_DEVICE inline void fit_forward(
2222
vecmem::device_vector<const unsigned int> param_ids(payload.param_ids_view);
2323
vecmem::device_vector<unsigned int> param_liveness(
2424
payload.param_liveness_view);
25-
track_state_container_types::device track_states(payload.track_states_view);
26-
27-
if (globalIndex >= track_states.size()) {
25+
typename edm::track_fit_collection<
26+
typename fitter_t::detector_type::algebra_type>::device
27+
tracks(payload.tracks_view.tracks);
28+
typename edm::track_state_collection<
29+
typename fitter_t::detector_type::algebra_type>::device
30+
track_states(payload.tracks_view.states);
31+
measurement_collection_types::const_device measurements{
32+
payload.tracks_view.measurements};
33+
34+
if (globalIndex >= tracks.size()) {
2835
return;
2936
}
3037

3138
const unsigned int param_id = param_ids.at(globalIndex);
3239

3340
fitter_t fitter(det, payload.field_data, cfg);
3441

35-
auto params = track_states.at(param_id).header.fit_params;
42+
auto track = tracks.at(param_id);
43+
auto params = track.params();
3644

3745
// TODO: Merge into filter?
3846
inflate_covariance(params, fitter.config().covariance_inflation_factor);
3947

4048
typename fitter_t::state fitter_state(
41-
track_states.at(param_id).items,
49+
track, track_states, measurements,
4250
*(payload.barcodes_view.ptr() + param_id));
4351

4452
kalman_fitter_status fit_status = fitter.filter(params, fitter_state);

device/common/include/traccc/fitting/device/fit_prelude.hpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,45 +7,66 @@
77

88
#pragma once
99

10+
// Local include(s).
11+
#include "traccc/device/global_index.hpp"
12+
13+
// Project include(s).
1014
#include "traccc/edm/track_candidate_container.hpp"
11-
#include "traccc/fitting/device/fit.hpp"
15+
#include "traccc/edm/track_fit_container.hpp"
1216
#include "traccc/fitting/status_codes.hpp"
1317

18+
// VecMem include(s).
19+
#include <vecmem/containers/data/vector_view.hpp>
20+
1421
namespace traccc::device {
1522

23+
template <typename algebra_t>
1624
TRACCC_HOST_DEVICE inline void fit_prelude(
1725
const global_index_t globalIndex,
1826
vecmem::data::vector_view<const unsigned int> param_ids_view,
19-
edm::track_candidate_container<default_algebra>::const_view
27+
typename edm::track_candidate_container<algebra_t>::const_view
2028
track_candidates_view,
21-
track_state_container_types::view track_states_view,
29+
typename edm::track_fit_container<algebra_t>::view tracks_view,
2230
vecmem::data::vector_view<unsigned int> param_liveness_view) {
2331

24-
track_state_container_types::device track_states(track_states_view);
32+
typename edm::track_fit_collection<algebra_t>::device tracks(
33+
tracks_view.tracks);
2534

26-
if (globalIndex >= track_states.size()) {
35+
if (globalIndex >= tracks.size()) {
2736
return;
2837
}
2938

39+
typename edm::track_state_collection<algebra_t>::device track_states(
40+
tracks_view.states);
41+
3042
vecmem::device_vector<const unsigned int> param_ids(param_ids_view);
3143
vecmem::device_vector<unsigned int> param_liveness(param_liveness_view);
3244

3345
const unsigned int param_id = param_ids.at(globalIndex);
3446

35-
auto track_states_per_track = track_states.at(param_id).items;
47+
auto track = tracks.at(param_id);
3648

37-
const edm::track_candidate_collection<default_algebra>::const_device
49+
const typename edm::track_candidate_collection<algebra_t>::const_device
3850
track_candidates{track_candidates_view.tracks};
51+
const auto track_candidate = track_candidates.at(param_id);
52+
const auto track_candidate_measurement_indices =
53+
track_candidate.measurement_indices();
3954
const measurement_collection_types::const_device measurements{
4055
track_candidates_view.measurements};
41-
for (unsigned int meas_idx :
42-
track_candidates.measurement_indices().at(param_id)) {
43-
track_states_per_track.emplace_back(measurements.at(meas_idx));
56+
for (unsigned int meas_idx : track_candidate_measurement_indices) {
57+
const unsigned int track_state_index =
58+
track_states.push_back({0, 0.f, 0.f, 0.f, {}, {}, meas_idx});
59+
auto state = track_states.at(track_state_index);
60+
state.set_hole(true);
61+
state.set_smoothed(false);
62+
const auto surface_link = measurements.at(meas_idx).surface_link;
63+
state.filtered_params().set_surface_link(surface_link);
64+
state.smoothed_params().set_surface_link(surface_link);
65+
track.state_indices().push_back(track_state_index);
4466
}
4567

4668
// TODO: Set other stuff in the header?
47-
track_states.at(param_id).header.fit_params =
48-
track_candidates.at(param_id).params();
69+
track.params() = track_candidate.params();
4970
param_liveness.at(param_id) = 1u;
5071
}
5172

device/cuda/include/traccc/cuda/fitting/kalman_fitting_algorithm.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// Project include(s).
1414
#include "traccc/bfield/magnetic_field.hpp"
1515
#include "traccc/edm/track_candidate_container.hpp"
16-
#include "traccc/edm/track_state.hpp"
16+
#include "traccc/edm/track_fit_container.hpp"
1717
#include "traccc/fitting/fitting_config.hpp"
1818
#include "traccc/geometry/detector.hpp"
1919
#include "traccc/utils/algorithm.hpp"
@@ -30,10 +30,10 @@ namespace traccc::cuda {
3030

3131
/// Kalman filter based track fitting algorithm
3232
class kalman_fitting_algorithm
33-
: public algorithm<track_state_container_types::buffer(
33+
: public algorithm<edm::track_fit_container<default_algebra>::buffer(
3434
const default_detector::view&, const magnetic_field&,
3535
const edm::track_candidate_container<default_algebra>::const_view&)>,
36-
public algorithm<track_state_container_types::buffer(
36+
public algorithm<edm::track_fit_container<default_algebra>::buffer(
3737
const telescope_detector::view&, const magnetic_field&,
3838
const edm::track_candidate_container<default_algebra>::const_view&)>,
3939
public messaging {
@@ -42,7 +42,7 @@ class kalman_fitting_algorithm
4242
/// Configuration type
4343
using config_type = fitting_config;
4444
/// Output type
45-
using output_type = track_state_container_types::buffer;
45+
using output_type = edm::track_fit_container<default_algebra>::buffer;
4646

4747
/// Constructor with the algorithm's configuration
4848
///

device/cuda/src/fitting/kalman_fitting.cuh

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
// Project include(s).
2020
#include "traccc/edm/device/sort_key.hpp"
2121
#include "traccc/edm/track_candidate_container.hpp"
22-
#include "traccc/edm/track_state.hpp"
22+
#include "traccc/edm/track_fit_container.hpp"
2323
#include "traccc/fitting/details/kalman_fitting_types.hpp"
2424
#include "traccc/fitting/device/fill_fitting_sort_keys.hpp"
2525
#include "traccc/fitting/fitting_config.hpp"
@@ -32,6 +32,9 @@
3232
#include <thrust/execution_policy.h>
3333
#include <thrust/sort.h>
3434

35+
// System include(s).
36+
#include <numeric>
37+
3538
namespace traccc::cuda::details {
3639

3740
/// Templated implementation of the CUDA track fitting algorithm.
@@ -50,7 +53,8 @@ namespace traccc::cuda::details {
5053
/// @return A container of the fitted track states
5154
///
5255
template <typename detector_t, typename bfield_t>
53-
track_state_container_types::buffer kalman_fitting(
56+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
57+
kalman_fitting(
5458
const typename detector_t::view_type& det_view, const bfield_t& field_view,
5559
const typename edm::track_candidate_container<
5660
typename detector_t::algebra_type>::const_view& track_candidates_view,
@@ -68,14 +72,17 @@ track_state_container_types::buffer kalman_fitting(
6872
// Get the sizes of the track candidates in each track.
6973
const std::vector<unsigned int> candidate_sizes =
7074
copy.get_sizes(track_candidates_view.tracks);
75+
const unsigned int n_states =
76+
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
7177

7278
// Create the result buffer.
73-
track_state_container_types::buffer track_states_buffer{
74-
{n_tracks, mr.main},
75-
{candidate_sizes, mr.main, mr.host,
76-
vecmem::data::buffer_type::resizable}};
77-
copy.setup(track_states_buffer.headers)->ignore();
78-
copy.setup(track_states_buffer.items)->ignore();
79+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
80+
track_states_buffer{
81+
{candidate_sizes, mr.main, mr.host,
82+
vecmem::data::buffer_type::resizable},
83+
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
84+
copy.setup(track_states_buffer.tracks)->ignore();
85+
copy.setup(track_states_buffer.states)->ignore();
7986

8087
// Return early, if there are no tracks.
8188
if (n_tracks == 0) {
@@ -127,7 +134,9 @@ track_state_container_types::buffer kalman_fitting(
127134

128135
// Run the fitting, using the sorted parameter IDs.
129136
fit_prelude(nBlocks, nThreads, 0, stream, param_ids_buffer,
130-
track_candidates_view, track_states_buffer,
137+
track_candidates_view,
138+
{track_states_buffer.tracks, track_states_buffer.states,
139+
track_candidates_view.measurements},
131140
param_liveness_buffer);
132141
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
133142
str.synchronize();
@@ -139,7 +148,8 @@ track_state_container_types::buffer kalman_fitting(
139148
.field_data = field_view,
140149
.param_ids_view = param_ids_buffer,
141150
.param_liveness_view = param_liveness_buffer,
142-
.track_states_view = track_states_buffer,
151+
.tracks_view = {track_states_buffer.tracks, track_states_buffer.states,
152+
track_candidates_view.measurements},
143153
.barcodes_view = seqs_buffer};
144154

145155
for (std::size_t i = 0; i < config.n_iterations; ++i) {

device/cuda/src/fitting/kernels/fit_prelude.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ __global__ void fit_prelude(
1515
vecmem::data::vector_view<const unsigned int> param_ids_view,
1616
edm::track_candidate_container<default_algebra>::const_view
1717
track_candidates_view,
18-
track_state_container_types::view track_states_view,
18+
edm::track_fit_container<default_algebra>::view tracks_view,
1919
vecmem::data::vector_view<unsigned int> param_liveness_view) {
20-
device::fit_prelude(details::global_index1(), param_ids_view,
21-
track_candidates_view, track_states_view,
22-
param_liveness_view);
20+
device::fit_prelude<default_algebra>(details::global_index1(),
21+
param_ids_view, track_candidates_view,
22+
tracks_view, param_liveness_view);
2323
}
2424
} // namespace kernels
2525

@@ -28,10 +28,10 @@ void fit_prelude(const dim3& grid_size, const dim3& block_size,
2828
vecmem::data::vector_view<const unsigned int> param_ids_view,
2929
edm::track_candidate_container<default_algebra>::const_view
3030
track_candidates_view,
31-
track_state_container_types::view track_states_view,
31+
edm::track_fit_container<default_algebra>::view tracks_view,
3232
vecmem::data::vector_view<unsigned int> param_liveness_view) {
3333
kernels::fit_prelude<<<grid_size, block_size, shared_mem_size, stream>>>(
34-
param_ids_view, track_candidates_view, track_states_view,
34+
param_ids_view, track_candidates_view, tracks_view,
3535
param_liveness_view);
3636
}
3737
} // namespace traccc::cuda

device/cuda/src/fitting/kernels/fit_prelude.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
#include <cuda_runtime.h>
1111

1212
#include "traccc/edm/track_candidate_container.hpp"
13-
#include "traccc/edm/track_state.hpp"
13+
#include "traccc/edm/track_fit_container.hpp"
1414

1515
namespace traccc::cuda {
1616
void fit_prelude(const dim3& grid_size, const dim3& block_size,
1717
std::size_t shared_mem_size, const cudaStream_t& stream,
1818
vecmem::data::vector_view<const unsigned int> param_ids_view,
1919
edm::track_candidate_container<default_algebra>::const_view
2020
track_candidates_view,
21-
track_state_container_types::view track_states_view,
21+
edm::track_fit_container<default_algebra>::view tracks_view,
2222
vecmem::data::vector_view<unsigned int> param_liveness_view);
2323
}

0 commit comments

Comments
 (0)