Skip to content

Commit b81e99b

Browse files
committed
Adapted traccc::alpaka and traccc::sycl to the SoA track fit EDM.
1 parent 2ee7c61 commit b81e99b

File tree

4 files changed

+65
-47
lines changed

4 files changed

+65
-47
lines changed

device/alpaka/include/traccc/alpaka/fitting/kalman_fitting_algorithm.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// Project include(s).
1414
#include "traccc/bfield/magnetic_field.hpp"
1515
#include "traccc/edm/track_candidate_container.hpp"
16-
#include "traccc/edm/track_state.hpp"
16+
#include "traccc/edm/track_fit_container.hpp"
1717
#include "traccc/fitting/fitting_config.hpp"
1818
#include "traccc/geometry/detector.hpp"
1919
#include "traccc/utils/algorithm.hpp"
@@ -30,10 +30,10 @@ namespace traccc::alpaka {
3030

3131
/// Kalman filter based track fitting algorithm
3232
class kalman_fitting_algorithm
33-
: public algorithm<track_state_container_types::buffer(
33+
: public algorithm<edm::track_fit_container<default_algebra>::buffer(
3434
const default_detector::view&, const magnetic_field&,
3535
const edm::track_candidate_container<default_algebra>::const_view&)>,
36-
public algorithm<track_state_container_types::buffer(
36+
public algorithm<edm::track_fit_container<default_algebra>::buffer(
3737
const telescope_detector::view&, const magnetic_field&,
3838
const edm::track_candidate_container<default_algebra>::const_view&)>,
3939
public messaging {
@@ -42,7 +42,7 @@ class kalman_fitting_algorithm
4242
/// Configuration type
4343
using config_type = fitting_config;
4444
/// Output type
45-
using output_type = track_state_container_types::buffer;
45+
using output_type = edm::track_fit_container<default_algebra>::buffer;
4646

4747
/// Constructor with the algorithm's configuration
4848
///

device/alpaka/src/fitting/kalman_fitting.hpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
// Project include(s).
1515
#include "traccc/edm/device/sort_key.hpp"
1616
#include "traccc/edm/track_candidate_container.hpp"
17-
#include "traccc/edm/track_state.hpp"
17+
#include "traccc/edm/track_fit_container.hpp"
1818
#include "traccc/fitting/details/kalman_fitting_types.hpp"
1919
#include "traccc/fitting/device/fill_fitting_sort_keys.hpp"
2020
#include "traccc/fitting/device/fit.hpp"
@@ -55,14 +55,14 @@ struct fit_prelude {
5555
vecmem::data::vector_view<const unsigned int> param_ids_view,
5656
edm::track_candidate_container<default_algebra>::const_view
5757
track_candidates_view,
58-
track_state_container_types::view track_states_view,
58+
edm::track_fit_container<default_algebra>::view track_states_view,
5959
vecmem::data::vector_view<unsigned int> param_liveness_view) const {
6060

6161
const device::global_index_t globalThreadIdx =
6262
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
63-
device::fit_prelude(globalThreadIdx, param_ids_view,
64-
track_candidates_view, track_states_view,
65-
param_liveness_view);
63+
device::fit_prelude<default_algebra>(
64+
globalThreadIdx, param_ids_view, track_candidates_view,
65+
track_states_view, param_liveness_view);
6666
}
6767
};
6868

@@ -112,7 +112,8 @@ struct fit_backward {
112112
/// @return A container of the fitted track states
113113
///
114114
template <typename detector_t, typename bfield_t>
115-
track_state_container_types::buffer kalman_fitting(
115+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
116+
kalman_fitting(
116117
const typename detector_t::view_type& det_view, const bfield_t& field_view,
117118
const typename edm::track_candidate_container<
118119
typename detector_t::algebra_type>::const_view& track_candidates_view,
@@ -130,21 +131,24 @@ track_state_container_types::buffer kalman_fitting(
130131
// Get the sizes of the track candidates in each track.
131132
const std::vector<unsigned int> candidate_sizes =
132133
copy.get_sizes(track_candidates_view.tracks);
134+
const unsigned int n_states =
135+
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
133136

134137
// Create the result buffer.
135-
track_state_container_types::buffer track_states_buffer{
136-
{n_tracks, mr.main},
137-
{candidate_sizes, mr.main, mr.host,
138-
vecmem::data::buffer_type::resizable}};
139-
vecmem::copy::event_type track_states_headers_setup_event =
140-
copy.setup(track_states_buffer.headers);
141-
vecmem::copy::event_type track_states_items_setup_event =
142-
copy.setup(track_states_buffer.items);
138+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
139+
track_states_buffer{
140+
{candidate_sizes, mr.main, mr.host,
141+
vecmem::data::buffer_type::resizable},
142+
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
143+
vecmem::copy::event_type tracks_setup_event =
144+
copy.setup(track_states_buffer.tracks);
145+
vecmem::copy::event_type track_states_setup_event =
146+
copy.setup(track_states_buffer.states);
143147

144148
// Return early, if there are no tracks.
145149
if (n_tracks == 0) {
146-
track_states_headers_setup_event->wait();
147-
track_states_items_setup_event->wait();
150+
tracks_setup_event->wait();
151+
track_states_setup_event->wait();
148152
return track_states_buffer;
149153
}
150154

@@ -193,9 +197,12 @@ track_state_container_types::buffer kalman_fitting(
193197
param_ids_device.begin());
194198

195199
// Run the fitting, using the sorted parameter IDs.
196-
track_state_container_types::view track_states_view = track_states_buffer;
197-
track_states_headers_setup_event->wait();
198-
track_states_items_setup_event->wait();
200+
typename edm::track_fit_container<typename detector_t::algebra_type>::view
201+
track_states_view{track_states_buffer.tracks,
202+
track_states_buffer.states,
203+
track_candidates_view.measurements};
204+
tracks_setup_event->wait();
205+
track_states_setup_event->wait();
199206

200207
::alpaka::exec<Acc>(queue, workDiv, kernels::fit_prelude{},
201208
vecmem::get_data(param_ids_buffer),
@@ -210,7 +217,7 @@ track_state_container_types::buffer kalman_fitting(
210217
.field_data = field_view,
211218
.param_ids_view = param_ids_buffer,
212219
.param_liveness_view = param_liveness_buffer,
213-
.track_states_view = track_states_view,
220+
.tracks_view = track_states_view,
214221
.barcodes_view = seqs_buffer};
215222
// Now copy it to device memory.
216223
vecmem::data::vector_buffer<device::fit_payload<fitter_t>> device_payload(

device/sycl/include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// Project include(s).
1414
#include "traccc/bfield/magnetic_field.hpp"
1515
#include "traccc/edm/track_candidate_container.hpp"
16-
#include "traccc/edm/track_state.hpp"
16+
#include "traccc/edm/track_fit_container.hpp"
1717
#include "traccc/fitting/fitting_config.hpp"
1818
#include "traccc/geometry/detector.hpp"
1919
#include "traccc/utils/algorithm.hpp"
@@ -30,10 +30,10 @@ namespace traccc::sycl {
3030

3131
/// Kalman filter based track fitting algorithm
3232
class kalman_fitting_algorithm
33-
: public algorithm<track_state_container_types::buffer(
33+
: public algorithm<edm::track_fit_container<default_algebra>::buffer(
3434
const default_detector::view&, const magnetic_field&,
3535
const edm::track_candidate_container<default_algebra>::const_view&)>,
36-
public algorithm<track_state_container_types::buffer(
36+
public algorithm<edm::track_fit_container<default_algebra>::buffer(
3737
const telescope_detector::view&, const magnetic_field&,
3838
const edm::track_candidate_container<default_algebra>::const_view&)>,
3939
public messaging {
@@ -42,7 +42,7 @@ class kalman_fitting_algorithm
4242
/// Configuration type
4343
using config_type = fitting_config;
4444
/// Output type
45-
using output_type = track_state_container_types::buffer;
45+
using output_type = edm::track_fit_container<default_algebra>::buffer;
4646

4747
/// Constructor with the algorithm's configuration
4848
///

device/sycl/src/fitting/kalman_fitting.hpp

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// Project include(s).
1616
#include "traccc/edm/device/sort_key.hpp"
1717
#include "traccc/edm/track_candidate_container.hpp"
18-
#include "traccc/edm/track_state.hpp"
18+
#include "traccc/edm/track_fit_container.hpp"
1919
#include "traccc/fitting/details/kalman_fitting_types.hpp"
2020
#include "traccc/fitting/device/fill_fitting_sort_keys.hpp"
2121
#include "traccc/fitting/device/fit.hpp"
@@ -31,6 +31,9 @@
3131
// SYCL include(s).
3232
#include <sycl/sycl.hpp>
3333

34+
// System include(s).
35+
#include <numeric>
36+
3437
namespace traccc::sycl {
3538
namespace kernels {
3639

@@ -64,7 +67,8 @@ namespace details {
6467
/// @return A container of the fitted track states
6568
///
6669
template <typename kernel_t, typename detector_t, typename bfield_t>
67-
track_state_container_types::buffer kalman_fitting(
70+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
71+
kalman_fitting(
6872
const typename detector_t::view_type& det_view, const bfield_t& field_view,
6973
const typename edm::track_candidate_container<
7074
typename detector_t::algebra_type>::const_view& track_candidates_view,
@@ -79,21 +83,24 @@ track_state_container_types::buffer kalman_fitting(
7983
// Get the sizes of the track candidates in each track.
8084
const std::vector<unsigned int> candidate_sizes =
8185
copy.get_sizes(track_candidates_view.tracks);
86+
const unsigned int n_states =
87+
std::accumulate(candidate_sizes.begin(), candidate_sizes.end(), 0u);
8288

8389
// Create the result buffer.
84-
track_state_container_types::buffer track_states_buffer{
85-
{n_tracks, mr.main},
86-
{candidate_sizes, mr.main, mr.host,
87-
vecmem::data::buffer_type::resizable}};
88-
vecmem::copy::event_type track_states_headers_setup_event =
89-
copy.setup(track_states_buffer.headers);
90-
vecmem::copy::event_type track_states_items_setup_event =
91-
copy.setup(track_states_buffer.items);
90+
typename edm::track_fit_container<typename detector_t::algebra_type>::buffer
91+
track_states_buffer{
92+
{candidate_sizes, mr.main, mr.host,
93+
vecmem::data::buffer_type::resizable},
94+
{n_states, mr.main, vecmem::data::buffer_type::resizable}};
95+
vecmem::copy::event_type tracks_setup_event =
96+
copy.setup(track_states_buffer.tracks);
97+
vecmem::copy::event_type track_states_setup_event =
98+
copy.setup(track_states_buffer.states);
9299

93100
// Return early, if there are no tracks.
94101
if (n_tracks == 0) {
95-
track_states_headers_setup_event->wait();
96-
track_states_items_setup_event->wait();
102+
tracks_setup_event->wait();
103+
track_states_setup_event->wait();
97104
return track_states_buffer;
98105
}
99106

@@ -149,9 +156,12 @@ track_state_container_types::buffer kalman_fitting(
149156
param_ids_device.begin());
150157

151158
// Run the fitting, using the sorted parameter IDs.
152-
track_state_container_types::view track_states_view = track_states_buffer;
153-
track_states_headers_setup_event->wait();
154-
track_states_items_setup_event->wait();
159+
typename edm::track_fit_container<typename detector_t::algebra_type>::view
160+
track_states_view{track_states_buffer.tracks,
161+
track_states_buffer.states,
162+
track_candidates_view.measurements};
163+
tracks_setup_event->wait();
164+
track_states_setup_event->wait();
155165

156166
queue
157167
.submit([&](::sycl::handler& h) {
@@ -160,9 +170,10 @@ track_state_container_types::buffer kalman_fitting(
160170
track_candidates_view, track_states_view,
161171
param_liveness_view = vecmem::get_data(
162172
param_liveness_buffer)](::sycl::nd_item<1> item) {
163-
device::fit_prelude(details::global_index(item),
164-
param_ids_view, track_candidates_view,
165-
track_states_view, param_liveness_view);
173+
device::fit_prelude<typename detector_t::algebra_type>(
174+
details::global_index(item), param_ids_view,
175+
track_candidates_view, track_states_view,
176+
param_liveness_view);
166177
});
167178
})
168179
.wait_and_throw();
@@ -173,7 +184,7 @@ track_state_container_types::buffer kalman_fitting(
173184
.field_data = field_view,
174185
.param_ids_view = param_ids_buffer,
175186
.param_liveness_view = param_liveness_buffer,
176-
.track_states_view = track_states_view,
187+
.tracks_view = track_states_view,
177188
.barcodes_view = seqs_buffer};
178189

179190
for (std::size_t i = 0; i < config.n_iterations; ++i) {

0 commit comments

Comments
 (0)