1414// Project include(s).
1515#include " traccc/edm/device/sort_key.hpp"
1616#include " traccc/edm/track_candidate_container.hpp"
17- #include " traccc/edm/track_state .hpp"
17+ #include " traccc/edm/track_fit_container .hpp"
1818#include " traccc/fitting/details/kalman_fitting_types.hpp"
1919#include " traccc/fitting/device/fill_fitting_sort_keys.hpp"
2020#include " traccc/fitting/device/fit.hpp"
@@ -55,14 +55,14 @@ struct fit_prelude {
5555 vecmem::data::vector_view<const unsigned int > param_ids_view,
5656 edm::track_candidate_container<default_algebra>::const_view
5757 track_candidates_view,
58- track_state_container_types ::view track_states_view,
58+ edm::track_fit_container<default_algebra> ::view track_states_view,
5959 vecmem::data::vector_view<unsigned int > param_liveness_view) const {
6060
6161 const device::global_index_t globalThreadIdx =
6262 ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0 ];
63- device::fit_prelude (globalThreadIdx, param_ids_view,
64- track_candidates_view, track_states_view ,
65- param_liveness_view);
63+ device::fit_prelude<default_algebra>(
64+ globalThreadIdx, param_ids_view, track_candidates_view ,
65+ track_states_view, param_liveness_view);
6666 }
6767};
6868
@@ -112,7 +112,8 @@ struct fit_backward {
112112// / @return A container of the fitted track states
113113// /
114114template <typename detector_t , typename bfield_t >
115- track_state_container_types::buffer kalman_fitting (
115+ typename edm::track_fit_container<typename detector_t ::algebra_type>::buffer
116+ kalman_fitting (
116117 const typename detector_t ::view_type& det_view, const bfield_t & field_view,
117118 const typename edm::track_candidate_container<
118119 typename detector_t ::algebra_type>::const_view& track_candidates_view,
@@ -130,21 +131,24 @@ track_state_container_types::buffer kalman_fitting(
130131 // Get the sizes of the track candidates in each track.
131132 const std::vector<unsigned int > candidate_sizes =
132133 copy.get_sizes (track_candidates_view.tracks );
134+ const unsigned int n_states =
135+ std::accumulate (candidate_sizes.begin (), candidate_sizes.end (), 0u );
133136
134137 // Create the result buffer.
135- track_state_container_types::buffer track_states_buffer{
136- {n_tracks, mr.main },
137- {candidate_sizes, mr.main , mr.host ,
138- vecmem::data::buffer_type::resizable}};
139- vecmem::copy::event_type track_states_headers_setup_event =
140- copy.setup (track_states_buffer.headers );
141- vecmem::copy::event_type track_states_items_setup_event =
142- copy.setup (track_states_buffer.items );
138+ typename edm::track_fit_container<typename detector_t ::algebra_type>::buffer
139+ track_states_buffer{
140+ {candidate_sizes, mr.main , mr.host ,
141+ vecmem::data::buffer_type::resizable},
142+ {n_states, mr.main , vecmem::data::buffer_type::resizable}};
143+ vecmem::copy::event_type tracks_setup_event =
144+ copy.setup (track_states_buffer.tracks );
145+ vecmem::copy::event_type track_states_setup_event =
146+ copy.setup (track_states_buffer.states );
143147
144148 // Return early, if there are no tracks.
145149 if (n_tracks == 0 ) {
146- track_states_headers_setup_event ->wait ();
147- track_states_items_setup_event ->wait ();
150+ tracks_setup_event ->wait ();
151+ track_states_setup_event ->wait ();
148152 return track_states_buffer;
149153 }
150154
@@ -193,9 +197,12 @@ track_state_container_types::buffer kalman_fitting(
193197 param_ids_device.begin ());
194198
195199 // Run the fitting, using the sorted parameter IDs.
196- track_state_container_types::view track_states_view = track_states_buffer;
197- track_states_headers_setup_event->wait ();
198- track_states_items_setup_event->wait ();
200+ typename edm::track_fit_container<typename detector_t ::algebra_type>::view
201+ track_states_view{track_states_buffer.tracks ,
202+ track_states_buffer.states ,
203+ track_candidates_view.measurements };
204+ tracks_setup_event->wait ();
205+ track_states_setup_event->wait ();
199206
200207 ::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
201208 vecmem::get_data (param_ids_buffer),
@@ -210,7 +217,7 @@ track_state_container_types::buffer kalman_fitting(
210217 .field_data = field_view,
211218 .param_ids_view = param_ids_buffer,
212219 .param_liveness_view = param_liveness_buffer,
213- .track_states_view = track_states_view,
220+ .tracks_view = track_states_view,
214221 .barcodes_view = seqs_buffer};
215222 // Now copy it to device memory.
216223 vecmem::data::vector_buffer<device::fit_payload<fitter_t >> device_payload (
0 commit comments