@@ -35,14 +35,13 @@ struct fill_fitting_sort_keys {
3535 template <typename TAcc>
3636 ALPAKA_FN_ACC void operator ()(
3737 TAcc const & acc,
38- edm::track_candidate_collection<default_algebra>::const_view
39- track_candidates_view,
38+ edm::track_fit_collection<default_algebra>::const_view track_fit_view,
4039 vecmem::data::vector_view<device::sort_key> keys_view,
4140 vecmem::data::vector_view<unsigned int > ids_view) const {
4241
4342 const device::global_index_t globalThreadIdx =
4443 ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0 ];
45- device::fill_fitting_sort_keys (globalThreadIdx, track_candidates_view ,
44+ device::fill_fitting_sort_keys (globalThreadIdx, track_fit_view ,
4645 keys_view, ids_view);
4746 }
4847};
@@ -52,17 +51,15 @@ struct fit_prelude {
5251 template <typename TAcc>
5352 ALPAKA_FN_ACC void operator ()(
5453 TAcc const & acc,
55- vecmem::data::vector_view<const unsigned int > param_ids_view,
5654 edm::track_candidate_container<default_algebra>::const_view
5755 track_candidates_view,
58- edm::track_fit_container<default_algebra>::view track_states_view,
59- vecmem::data::vector_view< unsigned int > param_liveness_view) const {
56+ edm::track_fit_container<default_algebra>::view track_states_view)
57+ const {
6058
6159 const device::global_index_t globalThreadIdx =
6260 ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0 ];
6361 device::fit_prelude<default_algebra>(
64- globalThreadIdx, param_ids_view, track_candidates_view,
65- track_states_view, param_liveness_view);
62+ globalThreadIdx, track_candidates_view, track_states_view);
6663 }
6764};
6865
@@ -96,14 +93,15 @@ struct fit_backward {
9693
9794} // namespace kernels
9895
99- // / Templated implementation of the Alpaka track fitting algorithm.
96+ // / Templated implementation of the Alpaka track fitting algorithm for
97+ // / fitted tracks.
10098// /
10199// / @tparam detector_t The (device) detector type to use
102100// / @tparam bfield_t The magnetic field type to use
103101// /
104102// / @param[in] det_view A view of the detector geometry
105103// / @param[in] field_view A view of the magnetic field
106- // / @param[in] track_candidates_view All track candidates to fit
104+ // / @param[in] track_fit_buffer Buffer holding the tracks to fit; it is returned as the result
107105// / @param[in] config The fitting configuration
108106// / @param[in] mr Memory resource(s) to use
109107// / @param[in] copy The copy object to use for memory transfers
@@ -116,41 +114,32 @@ typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
116114kalman_fitting (
117115 const typename detector_t ::const_view_type& det_view,
118116 const bfield_t & field_view,
119- const typename edm::track_candidate_container<
120- typename detector_t ::algebra_type>::const_view& track_candidates_view,
117+ typename edm::track_fit_container<
118+ typename detector_t ::algebra_type>::buffer&& track_fit_buffer,
119+ const measurement_collection_types::const_view& measurements,
121120 const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
122- Queue& queue) {
121+ Queue& queue, bool forward_on_first_iteration = false ) {
123122
124123 // Number of threads per block to use.
125124 const Idx threadsPerBlock = getWarpSize<Acc>() * 2 ;
126125
126+ typename edm::track_fit_container<
127+ typename detector_t ::algebra_type>::const_view track_fit_view{
128+ vecmem::get_data (track_fit_buffer.tracks ),
129+ vecmem::get_data (track_fit_buffer.states ), measurements};
130+
127131 // Get the number of tracks.
128132 const edm::track_candidate_collection<
129133 default_algebra>::const_device::size_type n_tracks =
130- copy.get_size (track_candidates_view .tracks );
134+ copy.get_size (track_fit_view .tracks );
131135
132136 // Get the sizes of the track candidates in each track.
133137 const std::vector<unsigned int > candidate_sizes =
134- copy.get_sizes (track_candidates_view.tracks );
135- const unsigned int n_states =
136- std::accumulate (candidate_sizes.begin (), candidate_sizes.end (), 0u );
137-
138- // Create the result buffer.
139- typename edm::track_fit_container<typename detector_t ::algebra_type>::buffer
140- track_states_buffer{
141- {candidate_sizes, mr.main , mr.host ,
142- vecmem::data::buffer_type::resizable},
143- {n_states, mr.main , vecmem::data::buffer_type::resizable}};
144- vecmem::copy::event_type tracks_setup_event =
145- copy.setup (track_states_buffer.tracks );
146- vecmem::copy::event_type track_states_setup_event =
147- copy.setup (track_states_buffer.states );
138+ copy.get_sizes (track_fit_view.tracks );
148139
149140 // Return early, if there are no tracks.
150141 if (n_tracks == 0 ) {
151- tracks_setup_event->wait ();
152- track_states_setup_event->wait ();
153- return track_states_buffer;
142+ return track_fit_buffer;
154143 }
155144
156145 std::vector<unsigned int > seqs_sizes (candidate_sizes.size ());
@@ -178,6 +167,7 @@ kalman_fitting(
178167 keys_setup_event->wait ();
179168 param_ids_setup_event->wait ();
180169 param_liveness_setup_event->wait ();
170+ copy.memset (param_liveness_buffer, 1 )->ignore ();
181171
182172 // The execution range for the kernels of the function.
183173 const Idx blocksPerGrid =
@@ -186,8 +176,7 @@ kalman_fitting(
186176
187177 // Fill the keys and param_ids buffers.
188178 ::alpaka::exec<Acc>(queue, workDiv, kernels::fill_fitting_sort_keys{},
189- track_candidates_view.tracks ,
190- vecmem::get_data (keys_buffer),
179+ track_fit_view.tracks , vecmem::get_data (keys_buffer),
191180 vecmem::get_data (param_ids_buffer));
192181 ::alpaka::wait (queue);
193182
@@ -197,28 +186,15 @@ kalman_fitting(
197186 details::sort_by_key (queue, mr, keys_device.begin (), keys_device.end (),
198187 param_ids_device.begin ());
199188
200- // Run the fitting, using the sorted parameter IDs.
201- typename edm::track_fit_container<typename detector_t ::algebra_type>::view
202- track_states_view{track_states_buffer.tracks ,
203- track_states_buffer.states ,
204- track_candidates_view.measurements };
205- tracks_setup_event->wait ();
206- track_states_setup_event->wait ();
207-
208- ::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
209- vecmem::get_data (param_ids_buffer),
210- track_candidates_view, track_states_view,
211- vecmem::get_data (param_liveness_buffer));
212- ::alpaka::wait (queue);
213-
214189 // Allocate the fitting kernels' payload in host memory.
215190 using fitter_t = traccc::details::kalman_fitter_t <detector_t , bfield_t >;
216191 device::fit_payload<fitter_t > host_payload{
217192 .det_data = det_view,
218193 .field_data = field_view,
219194 .param_ids_view = param_ids_buffer,
220195 .param_liveness_view = param_liveness_buffer,
221- .tracks_view = track_states_view,
196+ .tracks_view = {track_fit_buffer.tracks , track_fit_buffer.states ,
197+ measurements},
222198 .barcodes_view = seqs_buffer};
223199 // Now copy it to device memory.
224200 vecmem::data::vector_buffer<device::fit_payload<fitter_t >> device_payload (
@@ -231,16 +207,113 @@ kalman_fitting(
231207
232208 for (std::size_t i = 0 ; i < config.n_iterations ; ++i) {
233209 // Run the track fitting
234- ::alpaka::exec<Acc>(queue, workDiv, kernels::fit_forward<fitter_t >{},
235- config, device_payload.ptr ());
236- ::alpaka::wait (queue);
210+ if (i > 0 || forward_on_first_iteration) {
211+ ::alpaka::exec<Acc>(queue, workDiv,
212+ kernels::fit_forward<fitter_t >{}, config,
213+ device_payload.ptr ());
214+ ::alpaka::wait (queue);
215+ }
237216 ::alpaka::exec<Acc>(queue, workDiv, kernels::fit_backward<fitter_t >{},
238217 config, device_payload.ptr ());
239218 ::alpaka::wait (queue);
240219 }
241220
242221 // Return the fitted tracks.
243- return track_states_buffer;
222+ return track_fit_buffer;
223+ }
224+
225+ // / Templated implementation of the Alpaka track fitting algorithm for
226+ // / unfitted tracks.
227+ // / Allocates the result buffer, runs the prelude kernel, then calls the overload taking a fit buffer.
228+ // / @tparam detector_t The (device) detector type to use
229+ // / @tparam bfield_t The magnetic field type to use
230+ // /
231+ // / @param[in] det_view A view of the detector geometry
232+ // / @param[in] field_view A view of the magnetic field
233+ // / @param[in] track_candidates_view All track candidates to fit
234+ // / @param[in] config The fitting configuration
235+ // / @param[in] mr Memory resource(s) to use
236+ // / @param[in] copy The copy object to use for memory transfers
237+ // / @param[in] queue The Alpaka queue to use for execution
238+ // /
239+ // / @return A container of the fitted track states
240+ // /
241+ template <typename detector_t , typename bfield_t >
242+ typename edm::track_fit_container<typename detector_t ::algebra_type>::buffer
243+ kalman_fitting (
244+ const typename detector_t ::const_view_type& det_view,
245+ const bfield_t & field_view,
246+ const typename edm::track_candidate_container<
247+ typename detector_t ::algebra_type>::const_view& track_candidates_view,
248+ const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
249+ Queue& queue) {
250+ 
251+ // Number of threads per block to use.
252+ const Idx threadsPerBlock = getWarpSize<Acc>() * 2 ;
253+ 
254+ // Get the number of tracks.
255+ const edm::track_candidate_collection<
256+ default_algebra>::const_device::size_type n_tracks =
257+ copy.get_size (track_candidates_view.tracks );
258+ 
259+ // Get the sizes of the track candidates in each track.
260+ const std::vector<unsigned int > candidate_sizes =
261+ copy.get_sizes (track_candidates_view.tracks );
262+ const unsigned int n_states =
263+ std::accumulate (candidate_sizes.begin (), candidate_sizes.end (), 0u );
264+ 
265+ // Create the result buffer.
266+ typename edm::track_fit_container<typename detector_t ::algebra_type>::buffer
267+ track_states_buffer{
268+ {candidate_sizes, mr.main , mr.host ,
269+ vecmem::data::buffer_type::resizable},
270+ {n_states, mr.main , vecmem::data::buffer_type::resizable}};
271+ vecmem::copy::event_type tracks_setup_event =
272+ copy.setup (track_states_buffer.tracks );
273+ vecmem::copy::event_type track_states_setup_event =
274+ copy.setup (track_states_buffer.states );
275+ 
276+ // Return early, if there are no tracks.
277+ if (n_tracks == 0 ) {
278+ tracks_setup_event->wait ();
279+ track_states_setup_event->wait ();
280+ return track_states_buffer;
281+ }
282+ // NOTE(review): seqs_buffer below is not referenced again in this overload -- confirm it is still needed here.
283+ std::vector<unsigned int > seqs_sizes (candidate_sizes.size ());
284+ std::transform (candidate_sizes.begin (), candidate_sizes.end (),
285+ seqs_sizes.begin (), [&config](const unsigned int sz) {
286+ return std::max (sz * config.barcode_sequence_size_factor ,
287+ config.min_barcode_sequence_capacity );
288+ });
289+ vecmem::data::jagged_vector_buffer<detray::geometry::barcode> seqs_buffer{
290+ seqs_sizes, mr.main , mr.host , vecmem::data::buffer_type::resizable};
291+ copy.setup (seqs_buffer)->wait ();
292+ 
293+ // The execution range for the prelude kernel launched below.
294+ const Idx blocksPerGrid =
295+ (n_tracks + threadsPerBlock - 1 ) / threadsPerBlock;
296+ const auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
297+ 
298+ // Build a view of the result buffers and run the prelude kernel on it (no parameter sorting happens here).
299+ typename edm::track_fit_container<typename detector_t ::algebra_type>::view
300+ track_states_view{track_states_buffer.tracks ,
301+ track_states_buffer.states ,
302+ track_candidates_view.measurements };
303+ tracks_setup_event->wait ();
304+ track_states_setup_event->wait ();
305+ 
306+ ::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
307+ track_candidates_view, track_states_view);
308+ ::alpaka::wait (queue);
309+ // Hand the prepared buffers to the fit-buffer overload; the trailing `true` is its forward_on_first_iteration flag.
310+ return kalman_fitting<detector_t , bfield_t >(
311+ det_view, field_view,
312+ typename edm::track_fit_container<
313+ typename detector_t ::algebra_type>::buffer{
314+ std::move (track_states_buffer.tracks ),
315+ std::move (track_states_buffer.states )},
316+ track_candidates_view.measurements , config, mr, copy, queue, true );
244317}
245318
246319} // namespace traccc::alpaka::details
0 commit comments