Skip to content

Commit 1a7489d

Browse files
committed
Allow fitting algorithms to use prefitted tracks
This commit enables our fitting algorithms to use prefitted tracks, such that we can reuse the fitting information from the CKF. Note that we also preserve the ability to work with unfitted tracks.
1 parent a7aa08d commit 1a7489d

File tree

16 files changed

+501
-185
lines changed

16 files changed

+501
-185
lines changed

device/alpaka/include/traccc/alpaka/fitting/kalman_fitting_algorithm.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class kalman_fitting_algorithm
3434
: public algorithm<edm::track_fit_container<default_algebra>::buffer(
3535
const detector_buffer&, const magnetic_field&,
3636
const edm::track_candidate_container<default_algebra>::const_view&)>,
37+
public algorithm<edm::track_fit_container<default_algebra>::buffer(
38+
const detector_buffer&, const magnetic_field&,
39+
edm::track_fit_container<default_algebra>::buffer&&,
40+
const measurement_collection_types::const_view&)>,
3741
public messaging {
3842

3943
public:
@@ -55,7 +59,7 @@ class kalman_fitting_algorithm
5559
vecmem::copy& copy, queue& q,
5660
std::unique_ptr<const Logger> logger = getDummyLogger().clone());
5761

58-
/// Execute the algorithm
62+
/// Execute the algorithm using unfitted tracks
5963
///
6064
/// @param det The detector object
6165
/// @param bfield The magnetic field object
@@ -68,6 +72,20 @@ class kalman_fitting_algorithm
6872
const edm::track_candidate_container<default_algebra>::const_view&
6973
track_candidates) const override;
7074

75+
/// Execute the algorithm using fitted tracks
76+
///
77+
/// @param det The detector object
78+
/// @param bfield The magnetic field object
79+
/// @param track_candidates All track candidates to fit
80+
///
81+
/// @return A container of the fitted track states
82+
///
83+
output_type operator()(
84+
const detector_buffer& det, const magnetic_field& bfield,
85+
edm::track_fit_container<default_algebra>::buffer&& track_states,
86+
const measurement_collection_types::const_view& measurements)
87+
const override;
88+
7189
private:
7290
/// Algorithm configuration
7391
config_type m_config;

device/alpaka/src/fitting/kalman_fitting.hpp

Lines changed: 124 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,13 @@ struct fill_fitting_sort_keys {
3535
template <typename TAcc>
3636
ALPAKA_FN_ACC void operator()(
3737
TAcc const& acc,
38-
edm::track_candidate_collection<default_algebra>::const_view
39-
track_candidates_view,
38+
edm::track_fit_collection<default_algebra>::const_view track_fit_view,
4039
vecmem::data::vector_view<device::sort_key> keys_view,
4140
vecmem::data::vector_view<unsigned int> ids_view) const {
4241

4342
const device::global_index_t globalThreadIdx =
4443
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
45-
device::fill_fitting_sort_keys(globalThreadIdx, track_candidates_view,
44+
device::fill_fitting_sort_keys(globalThreadIdx, track_fit_view,
4645
keys_view, ids_view);
4746
}
4847
};
@@ -52,17 +51,15 @@ struct fit_prelude {
5251
template <typename TAcc>
5352
ALPAKA_FN_ACC void operator()(
5453
TAcc const& acc,
55-
vecmem::data::vector_view<const unsigned int> param_ids_view,
5654
edm::track_candidate_container<default_algebra>::const_view
5755
track_candidates_view,
58-
edm::track_fit_container<default_algebra>::view track_states_view,
59-
vecmem::data::vector_view<unsigned int> param_liveness_view) const {
56+
edm::track_fit_container<default_algebra>::view track_states_view)
57+
const {
6058

6159
const device::global_index_t globalThreadIdx =
6260
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
6361
device::fit_prelude<default_algebra>(
64-
globalThreadIdx, param_ids_view, track_candidates_view,
65-
track_states_view, param_liveness_view);
62+
globalThreadIdx, track_candidates_view, track_states_view);
6663
}
6764
};
6865

@@ -96,14 +93,15 @@ struct fit_backward {
9693

9794
} // namespace kernels
9895

99-
/// Templated implementation of the Alpaka track fitting algorithm.
96+
/// Templated implementation of the Alpaka track fitting algorithm for
97+
/// fitted tracks.
10098
///
10199
/// @tparam detector_t The (device) detector type to use
102100
/// @tparam bfield_t The magnetic field type to use
103101
///
104102
/// @param[in] det_view A view of the detector geometry
105103
/// @param[in] field_view A view of the magnetic field
106-
/// @param[in] track_candidates_view All track candidates to fit
104+
/// @param[in] track_fit_view All track candidates to fit
107105
/// @param[in] config The fitting configuration
108106
/// @param[in] mr Memory resource(s) to use
109107
/// @param[in] copy The copy object to use for memory transfers
@@ -116,41 +114,32 @@ typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
116114
kalman_fitting(
117115
const typename detector_t::const_view_type& det_view,
118116
const bfield_t& field_view,
119-
const typename edm::track_candidate_container<
120-
typename detector_t::algebra_type>::const_view& track_candidates_view,
117+
typename edm::track_fit_container<
118+
typename detector_t::algebra_type>::buffer&& track_fit_buffer,
119+
const measurement_collection_types::const_view& measurements,
121120
const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
122-
Queue& queue) {
121+
Queue& queue, bool forward_on_first_iteration = false) {
123122

124123
// Number of threads per block to use.
125124
const Idx threadsPerBlock = getWarpSize<Acc>() * 2;
126125

126+
typename edm::track_fit_container<
127+
typename detector_t::algebra_type>::const_view track_fit_view{
128+
vecmem::get_data(track_fit_buffer.tracks),
129+
vecmem::get_data(track_fit_buffer.states), measurements};
130+
127131
// Get the number of tracks.
128132
const edm::track_candidate_collection<
129133
default_algebra>::const_device::size_type n_tracks =
130-
copy.get_size(track_candidates_view.tracks);
134+
copy.get_size(track_fit_view.tracks);
131135

132136
// Get the sizes of the track candidates in each track.
133137
const std::vector<unsigned int> candidate_sizes =
134-
copy.get_sizes(track_candidates_view.tracks);
135-
const unsigned int n_states =
136-
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
137-
138-
// Create the result buffer.
139-
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
140-
track_states_buffer{
141-
{candidate_sizes, mr.main, mr.host,
142-
vecmem::data::buffer_type::resizable},
143-
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
144-
vecmem::copy::event_type tracks_setup_event =
145-
copy.setup(track_states_buffer.tracks);
146-
vecmem::copy::event_type track_states_setup_event =
147-
copy.setup(track_states_buffer.states);
138+
copy.get_sizes(track_fit_view.tracks);
148139

149140
// Return early, if there are no tracks.
150141
if (n_tracks == 0) {
151-
tracks_setup_event->wait();
152-
track_states_setup_event->wait();
153-
return track_states_buffer;
142+
return track_fit_buffer;
154143
}
155144

156145
std::vector<unsigned int> seqs_sizes(candidate_sizes.size());
@@ -186,8 +175,7 @@ kalman_fitting(
186175

187176
// Fill the keys and param_ids buffers.
188177
::alpaka::exec<Acc>(queue, workDiv, kernels::fill_fitting_sort_keys{},
189-
track_candidates_view.tracks,
190-
vecmem::get_data(keys_buffer),
178+
track_fit_view.tracks, vecmem::get_data(keys_buffer),
191179
vecmem::get_data(param_ids_buffer));
192180
::alpaka::wait(queue);
193181

@@ -197,28 +185,15 @@ kalman_fitting(
197185
details::sort_by_key(queue, mr, keys_device.begin(), keys_device.end(),
198186
param_ids_device.begin());
199187

200-
// Run the fitting, using the sorted parameter IDs.
201-
typename edm::track_fit_container<typename detector_t::algebra_type>::view
202-
track_states_view{track_states_buffer.tracks,
203-
track_states_buffer.states,
204-
track_candidates_view.measurements};
205-
tracks_setup_event->wait();
206-
track_states_setup_event->wait();
207-
208-
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
209-
vecmem::get_data(param_ids_buffer),
210-
track_candidates_view, track_states_view,
211-
vecmem::get_data(param_liveness_buffer));
212-
::alpaka::wait(queue);
213-
214188
// Allocate the fitting kernels's payload in host memory.
215189
using fitter_t = traccc::details::kalman_fitter_t<detector_t, bfield_t>;
216190
device::fit_payload<fitter_t> host_payload{
217191
.det_data = det_view,
218192
.field_data = field_view,
219193
.param_ids_view = param_ids_buffer,
220194
.param_liveness_view = param_liveness_buffer,
221-
.tracks_view = track_states_view,
195+
.tracks_view = {track_fit_buffer.tracks, track_fit_buffer.states,
196+
measurements},
222197
.barcodes_view = seqs_buffer};
223198
// Now copy it to device memory.
224199
vecmem::data::vector_buffer<device::fit_payload<fitter_t>> device_payload(
@@ -231,16 +206,113 @@ kalman_fitting(
231206

232207
for (std::size_t i = 0; i < config.n_iterations; ++i) {
233208
// Run the track fitting
234-
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_forward<fitter_t>{},
235-
config, device_payload.ptr());
236-
::alpaka::wait(queue);
209+
if (i > 0 || forward_on_first_iteration) {
210+
::alpaka::exec<Acc>(queue, workDiv,
211+
kernels::fit_forward<fitter_t>{}, config,
212+
device_payload.ptr());
213+
::alpaka::wait(queue);
214+
}
237215
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_backward<fitter_t>{},
238216
config, device_payload.ptr());
239217
::alpaka::wait(queue);
240218
}
241219

242220
// Return the fitted tracks.
243-
return track_states_buffer;
221+
return track_fit_buffer;
222+
}
223+
224+
/// Templated implementation of the Alpaka track fitting algorithm for
225+
/// unfitted tracks.
226+
///
227+
/// @tparam detector_t The (device) detector type to use
228+
/// @tparam bfield_t The magnetic field type to use
229+
///
230+
/// @param[in] det_view A view of the detector geometry
231+
/// @param[in] field_view A view of the magnetic field
232+
/// @param[in] track_candidates_view All track candidates to fit
233+
/// @param[in] config The fitting configuration
234+
/// @param[in] mr Memory resource(s) to use
235+
/// @param[in] copy The copy object to use for memory transfers
236+
/// @param[in] queue The Alpaka queue to use for execution
237+
///
238+
/// @return A container of the fitted track states
239+
///
240+
template <typename detector_t, typename bfield_t>
241+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
242+
kalman_fitting(
243+
const typename detector_t::const_view_type& det_view,
244+
const bfield_t& field_view,
245+
const typename edm::track_candidate_container<
246+
typename detector_t::algebra_type>::const_view& track_candidates_view,
247+
const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
248+
Queue& queue) {
249+
250+
// Number of threads per block to use.
251+
const Idx threadsPerBlock = getWarpSize<Acc>() * 2;
252+
253+
// Get the number of tracks.
254+
const edm::track_candidate_collection<
255+
default_algebra>::const_device::size_type n_tracks =
256+
copy.get_size(track_candidates_view.tracks);
257+
258+
// Get the sizes of the track candidates in each track.
259+
const std::vector<unsigned int> candidate_sizes =
260+
copy.get_sizes(track_candidates_view.tracks);
261+
const unsigned int n_states =
262+
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
263+
264+
// Create the result buffer.
265+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
266+
track_states_buffer{
267+
{candidate_sizes, mr.main, mr.host,
268+
vecmem::data::buffer_type::resizable},
269+
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
270+
vecmem::copy::event_type tracks_setup_event =
271+
copy.setup(track_states_buffer.tracks);
272+
vecmem::copy::event_type track_states_setup_event =
273+
copy.setup(track_states_buffer.states);
274+
275+
// Return early, if there are no tracks.
276+
if (n_tracks == 0) {
277+
tracks_setup_event->wait();
278+
track_states_setup_event->wait();
279+
return track_states_buffer;
280+
}
281+
282+
std::vector<unsigned int> seqs_sizes(candidate_sizes.size());
283+
std::transform(candidate_sizes.begin(), candidate_sizes.end(),
284+
seqs_sizes.begin(), [&config](const unsigned int sz) {
285+
return std::max(sz * config.barcode_sequence_size_factor,
286+
config.min_barcode_sequence_capacity);
287+
});
288+
vecmem::data::jagged_vector_buffer<detray::geometry::barcode> seqs_buffer{
289+
seqs_sizes, mr.main, mr.host, vecmem::data::buffer_type::resizable};
290+
copy.setup(seqs_buffer)->wait();
291+
292+
// The execution range for the two kernels of the function.
293+
const Idx blocksPerGrid =
294+
(n_tracks + threadsPerBlock - 1) / threadsPerBlock;
295+
const auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
296+
297+
// Run the fitting, using the sorted parameter IDs.
298+
typename edm::track_fit_container<typename detector_t::algebra_type>::view
299+
track_states_view{track_states_buffer.tracks,
300+
track_states_buffer.states,
301+
track_candidates_view.measurements};
302+
tracks_setup_event->wait();
303+
track_states_setup_event->wait();
304+
305+
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
306+
track_candidates_view, track_states_view);
307+
::alpaka::wait(queue);
308+
309+
return kalman_fitting<detector_t, bfield_t>(
310+
det_view, field_view,
311+
typename edm::track_fit_container<
312+
typename detector_t::algebra_type>::buffer{
313+
std::move(track_states_buffer.tracks),
314+
std::move(track_states_buffer.states)},
315+
track_candidates_view.measurements, config, mr, copy, queue, true);
244316
}
245317

246318
} // namespace traccc::alpaka::details

device/alpaka/src/fitting/kalman_fitting_algorithm.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,23 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
4444
});
4545
}
4646

47+
kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
48+
const detector_buffer& det, const magnetic_field& bfield,
49+
edm::track_fit_container<default_algebra>::buffer&& track_states,
50+
const measurement_collection_types::const_view& measurements) const {
51+
52+
// Run the track fitting.
53+
return detector_buffer_magnetic_field_visitor<
54+
detector_type_list, alpaka::bfield_type_list<scalar>>(
55+
det, bfield,
56+
[&]<typename detector_t, typename bfield_view_t>(
57+
const typename detector_t::view& detector,
58+
const bfield_view_t& field) {
59+
return details::kalman_fitting<typename detector_t::device>(
60+
detector, field, std::move(track_states), measurements,
61+
m_config, m_mr, m_copy.get(),
62+
details::get_queue(m_queue.get()));
63+
});
64+
}
65+
4766
} // namespace traccc::alpaka

device/common/include/traccc/fitting/device/fill_fitting_sort_keys.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
#include "traccc/edm/device/sort_key.hpp"
1313

1414
// Project include(s).
15-
#include "traccc/edm/track_candidate_collection.hpp"
15+
#include "traccc/edm/track_fit_collection.hpp"
1616

1717
namespace traccc::device {
1818

1919
/// Function used to fill key container
2020
///
2121
/// @param[in] globalIndex The index of the current thread
22-
/// @param[in] track_candidates_view The input track candidates
22+
/// @param[in] track_fit_view The input track states
2323
/// @param[out] keys_view The key values
2424
/// @param[out] ids_view The param ids
2525
///
2626
TRACCC_HOST_DEVICE inline void fill_fitting_sort_keys(
2727
global_index_t globalIndex,
28-
const edm::track_candidate_collection<default_algebra>::const_view&
29-
track_candidates_view,
28+
const edm::track_fit_collection<default_algebra>::const_view&
29+
track_fit_view,
3030
vecmem::data::vector_view<device::sort_key> keys_view,
3131
vecmem::data::vector_view<unsigned int> ids_view);
3232

device/common/include/traccc/fitting/device/fit_prelude.hpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ namespace traccc::device {
2424
template <typename algebra_t>
2525
TRACCC_HOST_DEVICE inline void fit_prelude(
2626
const global_index_t globalIndex,
27-
vecmem::data::vector_view<const unsigned int> param_ids_view,
2827
typename edm::track_candidate_container<algebra_t>::const_view
2928
track_candidates_view,
30-
typename edm::track_fit_container<algebra_t>::view tracks_view,
31-
vecmem::data::vector_view<unsigned int> param_liveness_view) {
29+
typename edm::track_fit_container<algebra_t>::view tracks_view) {
3230

3331
typename edm::track_fit_collection<algebra_t>::device tracks(
3432
tracks_view.tracks);
@@ -40,16 +38,11 @@ TRACCC_HOST_DEVICE inline void fit_prelude(
4038
typename edm::track_state_collection<algebra_t>::device track_states(
4139
tracks_view.states);
4240

43-
vecmem::device_vector<const unsigned int> param_ids(param_ids_view);
44-
vecmem::device_vector<unsigned int> param_liveness(param_liveness_view);
45-
46-
const unsigned int param_id = param_ids.at(globalIndex);
47-
48-
auto track = tracks.at(param_id);
41+
auto track = tracks.at(globalIndex);
4942

5043
const typename edm::track_candidate_collection<algebra_t>::const_device
5144
track_candidates{track_candidates_view.tracks};
52-
const auto track_candidate = track_candidates.at(param_id);
45+
const auto track_candidate = track_candidates.at(globalIndex);
5346
const auto track_candidate_measurement_indices =
5447
track_candidate.measurement_indices();
5548
const measurement_collection_types::const_device measurements{
@@ -62,7 +55,6 @@ TRACCC_HOST_DEVICE inline void fit_prelude(
6255

6356
// TODO: Set other stuff in the header?
6457
track.params() = track_candidate.params();
65-
param_liveness.at(param_id) = 1u;
6658
}
6759

6860
} // namespace traccc::device

0 commit comments

Comments
 (0)