Skip to content

Commit 9a8e8df

Browse files
committed
Allow fitting algorithms to use prefitted tracks
This commit enables our fitting algorithms to use prefitted tracks, such that we can reuse the fitting information from the CKF. Note that we also preserve the ability to work with unfitted tracks.
1 parent a7aa08d commit 9a8e8df

File tree

16 files changed

+504
-185
lines changed

16 files changed

+504
-185
lines changed

device/alpaka/include/traccc/alpaka/fitting/kalman_fitting_algorithm.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class kalman_fitting_algorithm
3434
: public algorithm<edm::track_fit_container<default_algebra>::buffer(
3535
const detector_buffer&, const magnetic_field&,
3636
const edm::track_candidate_container<default_algebra>::const_view&)>,
37+
public algorithm<edm::track_fit_container<default_algebra>::buffer(
38+
const detector_buffer&, const magnetic_field&,
39+
edm::track_fit_container<default_algebra>::buffer&&,
40+
const measurement_collection_types::const_view&)>,
3741
public messaging {
3842

3943
public:
@@ -55,7 +59,7 @@ class kalman_fitting_algorithm
5559
vecmem::copy& copy, queue& q,
5660
std::unique_ptr<const Logger> logger = getDummyLogger().clone());
5761

58-
/// Execute the algorithm
62+
/// Execute the algorithm using unfitted tracks
5963
///
6064
/// @param det The detector object
6165
/// @param bfield The magnetic field object
@@ -68,6 +72,20 @@ class kalman_fitting_algorithm
6872
const edm::track_candidate_container<default_algebra>::const_view&
6973
track_candidates) const override;
7074

75+
/// Execute the algorithm using fitted tracks
76+
///
77+
/// @param det The detector object
78+
/// @param bfield The magnetic field object
79+
/// @param track_candidates All track candidates to fit
80+
///
81+
/// @return A container of the fitted track states
82+
///
83+
output_type operator()(
84+
const detector_buffer& det, const magnetic_field& bfield,
85+
edm::track_fit_container<default_algebra>::buffer&& track_states,
86+
const measurement_collection_types::const_view& measurements)
87+
const override;
88+
7189
private:
7290
/// Algorithm configuration
7391
config_type m_config;

device/alpaka/src/fitting/kalman_fitting.hpp

Lines changed: 125 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,13 @@ struct fill_fitting_sort_keys {
3535
template <typename TAcc>
3636
ALPAKA_FN_ACC void operator()(
3737
TAcc const& acc,
38-
edm::track_candidate_collection<default_algebra>::const_view
39-
track_candidates_view,
38+
edm::track_fit_collection<default_algebra>::const_view track_fit_view,
4039
vecmem::data::vector_view<device::sort_key> keys_view,
4140
vecmem::data::vector_view<unsigned int> ids_view) const {
4241

4342
const device::global_index_t globalThreadIdx =
4443
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
45-
device::fill_fitting_sort_keys(globalThreadIdx, track_candidates_view,
44+
device::fill_fitting_sort_keys(globalThreadIdx, track_fit_view,
4645
keys_view, ids_view);
4746
}
4847
};
@@ -52,17 +51,15 @@ struct fit_prelude {
5251
template <typename TAcc>
5352
ALPAKA_FN_ACC void operator()(
5453
TAcc const& acc,
55-
vecmem::data::vector_view<const unsigned int> param_ids_view,
5654
edm::track_candidate_container<default_algebra>::const_view
5755
track_candidates_view,
58-
edm::track_fit_container<default_algebra>::view track_states_view,
59-
vecmem::data::vector_view<unsigned int> param_liveness_view) const {
56+
edm::track_fit_container<default_algebra>::view track_states_view)
57+
const {
6058

6159
const device::global_index_t globalThreadIdx =
6260
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
6361
device::fit_prelude<default_algebra>(
64-
globalThreadIdx, param_ids_view, track_candidates_view,
65-
track_states_view, param_liveness_view);
62+
globalThreadIdx, track_candidates_view, track_states_view);
6663
}
6764
};
6865

@@ -96,14 +93,15 @@ struct fit_backward {
9693

9794
} // namespace kernels
9895

99-
/// Templated implementation of the Alpaka track fitting algorithm.
96+
/// Templated implementation of the Alpaka track fitting algorithm for
97+
/// fitted tracks.
10098
///
10199
/// @tparam detector_t The (device) detector type to use
102100
/// @tparam bfield_t The magnetic field type to use
103101
///
104102
/// @param[in] det_view A view of the detector geometry
105103
/// @param[in] field_view A view of the magnetic field
106-
/// @param[in] track_candidates_view All track candidates to fit
104+
/// @param[in] track_fit_view All track candidates to fit
107105
/// @param[in] config The fitting configuration
108106
/// @param[in] mr Memory resource(s) to use
109107
/// @param[in] copy The copy object to use for memory transfers
@@ -116,41 +114,32 @@ typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
116114
kalman_fitting(
117115
const typename detector_t::const_view_type& det_view,
118116
const bfield_t& field_view,
119-
const typename edm::track_candidate_container<
120-
typename detector_t::algebra_type>::const_view& track_candidates_view,
117+
typename edm::track_fit_container<
118+
typename detector_t::algebra_type>::buffer&& track_fit_buffer,
119+
const measurement_collection_types::const_view& measurements,
121120
const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
122-
Queue& queue) {
121+
Queue& queue, bool forward_on_first_iteration = false) {
123122

124123
// Number of threads per block to use.
125124
const Idx threadsPerBlock = getWarpSize<Acc>() * 2;
126125

126+
typename edm::track_fit_container<
127+
typename detector_t::algebra_type>::const_view track_fit_view{
128+
vecmem::get_data(track_fit_buffer.tracks),
129+
vecmem::get_data(track_fit_buffer.states), measurements};
130+
127131
// Get the number of tracks.
128132
const edm::track_candidate_collection<
129133
default_algebra>::const_device::size_type n_tracks =
130-
copy.get_size(track_candidates_view.tracks);
134+
copy.get_size(track_fit_view.tracks);
131135

132136
// Get the sizes of the track candidates in each track.
133137
const std::vector<unsigned int> candidate_sizes =
134-
copy.get_sizes(track_candidates_view.tracks);
135-
const unsigned int n_states =
136-
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
137-
138-
// Create the result buffer.
139-
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
140-
track_states_buffer{
141-
{candidate_sizes, mr.main, mr.host,
142-
vecmem::data::buffer_type::resizable},
143-
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
144-
vecmem::copy::event_type tracks_setup_event =
145-
copy.setup(track_states_buffer.tracks);
146-
vecmem::copy::event_type track_states_setup_event =
147-
copy.setup(track_states_buffer.states);
138+
copy.get_sizes(track_fit_view.tracks);
148139

149140
// Return early, if there are no tracks.
150141
if (n_tracks == 0) {
151-
tracks_setup_event->wait();
152-
track_states_setup_event->wait();
153-
return track_states_buffer;
142+
return track_fit_buffer;
154143
}
155144

156145
std::vector<unsigned int> seqs_sizes(candidate_sizes.size());
@@ -178,6 +167,7 @@ kalman_fitting(
178167
keys_setup_event->wait();
179168
param_ids_setup_event->wait();
180169
param_liveness_setup_event->wait();
170+
copy.memset(param_liveness_buffer, 1)->ignore();
181171

182172
// The execution range for the two kernels of the function.
183173
const Idx blocksPerGrid =
@@ -186,8 +176,7 @@ kalman_fitting(
186176

187177
// Fill the keys and param_ids buffers.
188178
::alpaka::exec<Acc>(queue, workDiv, kernels::fill_fitting_sort_keys{},
189-
track_candidates_view.tracks,
190-
vecmem::get_data(keys_buffer),
179+
track_fit_view.tracks, vecmem::get_data(keys_buffer),
191180
vecmem::get_data(param_ids_buffer));
192181
::alpaka::wait(queue);
193182

@@ -197,28 +186,15 @@ kalman_fitting(
197186
details::sort_by_key(queue, mr, keys_device.begin(), keys_device.end(),
198187
param_ids_device.begin());
199188

200-
// Run the fitting, using the sorted parameter IDs.
201-
typename edm::track_fit_container<typename detector_t::algebra_type>::view
202-
track_states_view{track_states_buffer.tracks,
203-
track_states_buffer.states,
204-
track_candidates_view.measurements};
205-
tracks_setup_event->wait();
206-
track_states_setup_event->wait();
207-
208-
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
209-
vecmem::get_data(param_ids_buffer),
210-
track_candidates_view, track_states_view,
211-
vecmem::get_data(param_liveness_buffer));
212-
::alpaka::wait(queue);
213-
214189
// Allocate the fitting kernels's payload in host memory.
215190
using fitter_t = traccc::details::kalman_fitter_t<detector_t, bfield_t>;
216191
device::fit_payload<fitter_t> host_payload{
217192
.det_data = det_view,
218193
.field_data = field_view,
219194
.param_ids_view = param_ids_buffer,
220195
.param_liveness_view = param_liveness_buffer,
221-
.tracks_view = track_states_view,
196+
.tracks_view = {track_fit_buffer.tracks, track_fit_buffer.states,
197+
measurements},
222198
.barcodes_view = seqs_buffer};
223199
// Now copy it to device memory.
224200
vecmem::data::vector_buffer<device::fit_payload<fitter_t>> device_payload(
@@ -231,16 +207,113 @@ kalman_fitting(
231207

232208
for (std::size_t i = 0; i < config.n_iterations; ++i) {
233209
// Run the track fitting
234-
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_forward<fitter_t>{},
235-
config, device_payload.ptr());
236-
::alpaka::wait(queue);
210+
if (i > 0 || forward_on_first_iteration) {
211+
::alpaka::exec<Acc>(queue, workDiv,
212+
kernels::fit_forward<fitter_t>{}, config,
213+
device_payload.ptr());
214+
::alpaka::wait(queue);
215+
}
237216
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_backward<fitter_t>{},
238217
config, device_payload.ptr());
239218
::alpaka::wait(queue);
240219
}
241220

242221
// Return the fitted tracks.
243-
return track_states_buffer;
222+
return track_fit_buffer;
223+
}
224+
225+
/// Templated implementation of the Alpaka track fitting algorithm for
226+
/// unfitted tracks.
227+
///
228+
/// @tparam detector_t The (device) detector type to use
229+
/// @tparam bfield_t The magnetic field type to use
230+
///
231+
/// @param[in] det_view A view of the detector geometry
232+
/// @param[in] field_view A view of the magnetic field
233+
/// @param[in] track_candidates_view All track candidates to fit
234+
/// @param[in] config The fitting configuration
235+
/// @param[in] mr Memory resource(s) to use
236+
/// @param[in] copy The copy object to use for memory transfers
237+
/// @param[in] queue The Alpaka queue to use for execution
238+
///
239+
/// @return A container of the fitted track states
240+
///
241+
template <typename detector_t, typename bfield_t>
242+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
243+
kalman_fitting(
244+
const typename detector_t::const_view_type& det_view,
245+
const bfield_t& field_view,
246+
const typename edm::track_candidate_container<
247+
typename detector_t::algebra_type>::const_view& track_candidates_view,
248+
const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
249+
Queue& queue) {
250+
251+
// Number of threads per block to use.
252+
const Idx threadsPerBlock = getWarpSize<Acc>() * 2;
253+
254+
// Get the number of tracks.
255+
const edm::track_candidate_collection<
256+
default_algebra>::const_device::size_type n_tracks =
257+
copy.get_size(track_candidates_view.tracks);
258+
259+
// Get the sizes of the track candidates in each track.
260+
const std::vector<unsigned int> candidate_sizes =
261+
copy.get_sizes(track_candidates_view.tracks);
262+
const unsigned int n_states =
263+
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
264+
265+
// Create the result buffer.
266+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
267+
track_states_buffer{
268+
{candidate_sizes, mr.main, mr.host,
269+
vecmem::data::buffer_type::resizable},
270+
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
271+
vecmem::copy::event_type tracks_setup_event =
272+
copy.setup(track_states_buffer.tracks);
273+
vecmem::copy::event_type track_states_setup_event =
274+
copy.setup(track_states_buffer.states);
275+
276+
// Return early, if there are no tracks.
277+
if (n_tracks == 0) {
278+
tracks_setup_event->wait();
279+
track_states_setup_event->wait();
280+
return track_states_buffer;
281+
}
282+
283+
std::vector<unsigned int> seqs_sizes(candidate_sizes.size());
284+
std::transform(candidate_sizes.begin(), candidate_sizes.end(),
285+
seqs_sizes.begin(), [&config](const unsigned int sz) {
286+
return std::max(sz * config.barcode_sequence_size_factor,
287+
config.min_barcode_sequence_capacity);
288+
});
289+
vecmem::data::jagged_vector_buffer<detray::geometry::barcode> seqs_buffer{
290+
seqs_sizes, mr.main, mr.host, vecmem::data::buffer_type::resizable};
291+
copy.setup(seqs_buffer)->wait();
292+
293+
// The execution range for the two kernels of the function.
294+
const Idx blocksPerGrid =
295+
(n_tracks + threadsPerBlock - 1) / threadsPerBlock;
296+
const auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
297+
298+
// Run the fitting, using the sorted parameter IDs.
299+
typename edm::track_fit_container<typename detector_t::algebra_type>::view
300+
track_states_view{track_states_buffer.tracks,
301+
track_states_buffer.states,
302+
track_candidates_view.measurements};
303+
tracks_setup_event->wait();
304+
track_states_setup_event->wait();
305+
306+
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
307+
track_candidates_view, track_states_view);
308+
::alpaka::wait(queue);
309+
310+
return kalman_fitting<detector_t, bfield_t>(
311+
det_view, field_view,
312+
typename edm::track_fit_container<
313+
typename detector_t::algebra_type>::buffer{
314+
std::move(track_states_buffer.tracks),
315+
std::move(track_states_buffer.states)},
316+
track_candidates_view.measurements, config, mr, copy, queue, true);
244317
}
245318

246319
} // namespace traccc::alpaka::details

device/alpaka/src/fitting/kalman_fitting_algorithm.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,23 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
4444
});
4545
}
4646

47+
kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
48+
const detector_buffer& det, const magnetic_field& bfield,
49+
edm::track_fit_container<default_algebra>::buffer&& track_states,
50+
const measurement_collection_types::const_view& measurements) const {
51+
52+
// Run the track fitting.
53+
return detector_buffer_magnetic_field_visitor<
54+
detector_type_list, alpaka::bfield_type_list<scalar>>(
55+
det, bfield,
56+
[&]<typename detector_t, typename bfield_view_t>(
57+
const typename detector_t::view& detector,
58+
const bfield_view_t& field) {
59+
return details::kalman_fitting<typename detector_t::device>(
60+
detector, field, std::move(track_states), measurements,
61+
m_config, m_mr, m_copy.get(),
62+
details::get_queue(m_queue.get()));
63+
});
64+
}
65+
4766
} // namespace traccc::alpaka

device/common/include/traccc/fitting/device/fill_fitting_sort_keys.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
#include "traccc/edm/device/sort_key.hpp"
1313

1414
// Project include(s).
15-
#include "traccc/edm/track_candidate_collection.hpp"
15+
#include "traccc/edm/track_fit_collection.hpp"
1616

1717
namespace traccc::device {
1818

1919
/// Function used to fill key container
2020
///
2121
/// @param[in] globalIndex The index of the current thread
22-
/// @param[in] track_candidates_view The input track candidates
22+
/// @param[in] track_fit_view The input track states
2323
/// @param[out] keys_view The key values
2424
/// @param[out] ids_view The param ids
2525
///
2626
TRACCC_HOST_DEVICE inline void fill_fitting_sort_keys(
2727
global_index_t globalIndex,
28-
const edm::track_candidate_collection<default_algebra>::const_view&
29-
track_candidates_view,
28+
const edm::track_fit_collection<default_algebra>::const_view&
29+
track_fit_view,
3030
vecmem::data::vector_view<device::sort_key> keys_view,
3131
vecmem::data::vector_view<unsigned int> ids_view);
3232

0 commit comments

Comments
 (0)