Skip to content

Commit 6cff356

Browse files
authored
exp: simplify measurement access (#1168)
* Simplify measurement lookup host * Simplify measurement access device * Apply to other backedns and remove kernel
1 parent 96b9f6d commit 6cff356

File tree

13 files changed

+98
-366
lines changed

13 files changed

+98
-366
lines changed

core/include/traccc/edm/measurement.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,20 @@ struct measurement_sort_comp {
9696
}
9797
};
9898

99+
struct measurement_sf_comp {
100+
template <typename sf_descriptor_t>
101+
TRACCC_HOST_DEVICE bool operator()(const sf_descriptor_t sf_desc,
102+
const measurement& rhs) {
103+
return sf_desc.barcode() < rhs.surface_link;
104+
}
105+
106+
template <typename sf_descriptor_t>
107+
TRACCC_HOST_DEVICE bool operator()(const measurement& lhs,
108+
const sf_descriptor_t sf_desc) {
109+
return lhs.surface_link < sf_desc.barcode();
110+
}
111+
};
112+
99113
struct measurement_equal_comp {
100114
TRACCC_HOST_DEVICE
101115
bool operator()(const measurement& lhs, const measurement& rhs) const {

core/include/traccc/finding/details/combinatorial_kalman_filter.hpp

Lines changed: 34 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -87,41 +87,36 @@ combinatorial_kalman_filter(
8787
measurement_collection_types::const_device measurements{measurements_view};
8888

8989
// Check contiguity of the measurements
90-
assert(is_contiguous_on(measurement_module_projection(), measurements));
91-
92-
// Get copy of barcode uniques
93-
std::vector<measurement> uniques;
94-
uniques.resize(measurements.size());
95-
96-
std::vector<measurement>::iterator uniques_end =
97-
std::unique_copy(measurements.begin(), measurements.end(),
98-
uniques.begin(), measurement_equal_comp());
99-
const auto n_modules =
100-
static_cast<unsigned int>(uniques_end - uniques.begin());
101-
102-
// Get upper bounds of unique elements
103-
std::vector<unsigned int> upper_bounds;
104-
upper_bounds.reserve(n_modules);
105-
for (unsigned int i = 0; i < n_modules; i++) {
106-
measurement_collection_types::const_device::iterator up =
107-
std::upper_bound(measurements.begin(), measurements.end(),
108-
uniques[i], measurement_sort_comp());
109-
upper_bounds.push_back(
90+
assert(
91+
host::is_contiguous_on(measurement_module_projection(), measurements));
92+
93+
// Get index ranges in the measurement container per detector surface
94+
std::vector<unsigned int> meas_ranges;
95+
meas_ranges.reserve(det.surfaces().size());
96+
97+
for (const auto& sf_desc : det.surfaces()) {
98+
// Measurements can only be found on sensitive surfaces
99+
if (!sf_desc.is_sensitive()) {
100+
// Lower range index is the upper index of the previous range
101+
// This is guaranteed by the measurement sorting step
102+
const auto sf_idx{sf_desc.index()};
103+
const unsigned int lo{sf_idx == 0u ? 0u : meas_ranges[sf_idx - 1u]};
104+
105+
// Hand the upper index of the previous range through to assign
106+
// the lower index of the next valid range correctly
107+
meas_ranges.push_back(lo);
108+
continue;
109+
}
110+
111+
auto up = std::upper_bound(measurements.begin(), measurements.end(),
112+
sf_desc, measurement_sf_comp());
113+
meas_ranges.push_back(
110114
static_cast<unsigned int>(std::distance(measurements.begin(), up)));
111115
}
116+
112117
const measurement_collection_types::const_device::size_type n_meas =
113118
measurements.size();
114119

115-
// Get the number of measurements of each module
116-
std::vector<unsigned int> sizes(n_modules);
117-
std::adjacent_difference(upper_bounds.begin(), upper_bounds.end(),
118-
sizes.begin());
119-
120-
// Create barcode sequence
121-
std::vector<detray::geometry::barcode> barcodes(n_modules);
122-
std::transform(uniques.begin(), uniques_end, barcodes.begin(),
123-
[](const measurement& m) { return m.surface_link; });
124-
125120
std::vector<std::vector<candidate_link>> links;
126121
links.resize(config.max_track_candidates_per_track);
127122

@@ -225,48 +220,28 @@ combinatorial_kalman_filter(
225220
sf);
226221
}
227222

228-
// Get barcode and measurements range on surface
229-
const auto bcd = in_param.surface_link();
230-
assert(!bcd.is_invalid());
231-
std::pair<unsigned int, unsigned int> range;
232-
233-
// Find the corresponding index of bcd in barcode vector
234-
235-
const auto lo2 =
236-
std::lower_bound(barcodes.begin(), barcodes.end(), bcd);
237-
238-
const auto bcd_id = std::distance(barcodes.begin(), lo2);
239-
240-
if (lo2 == barcodes.begin()) {
241-
range.first = 0u;
242-
range.second = upper_bounds[static_cast<std::size_t>(bcd_id)];
243-
} else if (lo2 == barcodes.end()) {
244-
range.first = 0u;
245-
range.second = 0u;
246-
} else {
247-
range.first =
248-
upper_bounds[static_cast<std::size_t>(bcd_id - 1)];
249-
range.second = upper_bounds[static_cast<std::size_t>(bcd_id)];
250-
}
251-
252223
/*****************************************************************
253224
* Find tracks (CKF)
254225
*****************************************************************/
255226

227+
// Iterate over the measurements for this surface
228+
const auto sf_idx{sf.index()};
229+
const unsigned int lo{sf_idx == 0u ? 0u : meas_ranges[sf_idx - 1]};
230+
const unsigned int up{meas_ranges[sf_idx]};
231+
256232
std::vector<std::tuple<candidate_link,
257233
bound_track_parameters<algebra_type>>>
258234
best_links;
259235

260236
// Iterate over the measurements
261-
for (unsigned int item_id = range.first; item_id < range.second;
262-
item_id++) {
237+
for (unsigned int meas_id = lo; meas_id < up; meas_id++) {
263238

264239
// The measurement on surface to handle.
265-
const measurement& meas = measurements.at(item_id);
240+
const measurement& meas = measurements.at(meas_id);
266241

267242
// Create a standalone track state object.
268243
auto trk_state =
269-
edm::make_track_state<algebra_type>(measurements, item_id);
244+
edm::make_track_state<algebra_type>(measurements, meas_id);
270245

271246
const bool is_line = sf.template visit_mask<is_line_visitor>();
272247

@@ -284,7 +259,7 @@ combinatorial_kalman_filter(
284259
best_links.push_back(
285260
{{.step = step,
286261
.previous_candidate_idx = in_param_id,
287-
.meas_idx = item_id,
262+
.meas_idx = meas_id,
288263
.seed_idx = orig_param_id,
289264
.n_skipped = skip_counter,
290265
.chi2 = chi2,

device/alpaka/src/finding/combinatorial_kalman_filter.hpp

Lines changed: 14 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "traccc/finding/device/fill_finding_duplicate_removal_sort_keys.hpp"
2626
#include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
2727
#include "traccc/finding/device/find_tracks.hpp"
28-
#include "traccc/finding/device/make_barcode_sequence.hpp"
2928
#include "traccc/finding/device/propagate_to_next_surface.hpp"
3029
#include "traccc/finding/device/remove_duplicates.hpp"
3130
#include "traccc/finding/finding_config.hpp"
@@ -40,19 +39,6 @@
4039
namespace traccc::alpaka::details {
4140
namespace kernels {
4241

43-
/// Alpaka kernel functor for @c traccc::device::make_barcode_sequence
44-
struct make_barcode_sequence {
45-
template <typename TAcc>
46-
ALPAKA_FN_ACC void operator()(
47-
TAcc const& acc,
48-
const device::make_barcode_sequence_payload payload) const {
49-
50-
const device::global_index_t globalThreadIdx =
51-
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
52-
device::make_barcode_sequence(globalThreadIdx, payload);
53-
}
54-
};
55-
5642
/// Alpaka kernel functor for @c traccc::device::apply_interaction
5743
template <typename detector_t>
5844
struct apply_interaction {
@@ -222,48 +208,21 @@ combinatorial_kalman_filter(
222208
const measurement_collection_types::const_view::size_type n_measurements =
223209
copy.get_size(measurements);
224210

225-
// Get copy of barcode uniques
226-
measurement_collection_types::buffer uniques_buffer{n_measurements,
227-
mr.main};
228-
copy.setup(uniques_buffer)->wait();
229-
measurement_collection_types::device uniques(uniques_buffer);
230-
231-
measurement_collection_types::device::iterator uniques_end =
232-
details::unique_copy(queue, mr, measurements.ptr(),
233-
measurements.ptr() + n_measurements,
234-
uniques.begin(), measurement_equal_comp());
235-
const unsigned int n_modules =
236-
static_cast<unsigned int>(uniques_end - uniques.begin());
237-
238-
// Get upper bounds of unique elements
239-
vecmem::data::vector_buffer<unsigned int> upper_bounds_buffer{n_modules,
240-
mr.main};
241-
copy.setup(upper_bounds_buffer)->wait();
242-
vecmem::device_vector<unsigned int> upper_bounds(upper_bounds_buffer);
243-
244-
details::upper_bound(queue, mr, measurements.ptr(),
245-
measurements.ptr() + n_measurements, uniques.begin(),
246-
uniques.begin() + n_modules, upper_bounds.begin(),
247-
measurement_sort_comp());
248-
249-
/*****************************************************************
250-
* Kernel1: Create barcode sequence
251-
*****************************************************************/
211+
// Access the detector view as a detector object
212+
detector_t device_det(det);
213+
const unsigned int n_surfaces{device_det.surfaces().size()};
252214

253-
vecmem::data::vector_buffer<detray::geometry::barcode> barcodes_buffer{
254-
n_modules, mr.main};
255-
copy.setup(barcodes_buffer)->wait();
215+
// Get upper bounds of measurement ranges per surface
216+
vecmem::data::vector_buffer<unsigned int> meas_ranges_buffer{n_surfaces,
217+
mr.main};
218+
copy.setup(meas_ranges_buffer)->ignore();
219+
vecmem::device_vector<unsigned int> measurement_ranges(meas_ranges_buffer);
256220

257-
{
258-
Idx blocksPerGrid =
259-
(barcodes_buffer.size() + threadsPerBlock - 1) / threadsPerBlock;
260-
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
261-
262-
::alpaka::exec<Acc>(queue, workDiv, kernels::make_barcode_sequence{},
263-
device::make_barcode_sequence_payload{
264-
uniques_buffer, barcodes_buffer});
265-
::alpaka::wait(queue);
266-
}
221+
// Get upper bounds of measurement ranges
222+
details::upper_bound(
223+
queue, mr, measurements.ptr(), measurements.ptr() + n_measurements,
224+
device_det.surfaces().begin(), device_det.surfaces().end(),
225+
measurement_ranges.begin(), measurement_sf_comp());
267226

268227
const unsigned int n_seeds = copy.get_size(seeds);
269228

@@ -384,8 +343,7 @@ combinatorial_kalman_filter(
384343
.in_params_view = in_params_buffer,
385344
.in_params_liveness_view = param_liveness_buffer,
386345
.n_in_params = n_in_params,
387-
.barcodes_view = barcodes_buffer,
388-
.upper_bounds_view = upper_bounds_buffer,
346+
.measurement_ranges_view = meas_ranges_buffer,
389347
.links_view = links_buffer,
390348
.prev_links_idx =
391349
(step == 0 ? 0 : step_to_link_idx_map[step - 1]),

device/common/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,11 @@ traccc_add_library( traccc_device_common device_common TYPE INTERFACE
6262
"include/traccc/finding/device/build_tracks.hpp"
6363
"include/traccc/finding/device/find_tracks.hpp"
6464
"include/traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
65-
"include/traccc/finding/device/make_barcode_sequence.hpp"
6665
"include/traccc/finding/device/propagate_to_next_surface.hpp"
6766
"include/traccc/finding/device/impl/apply_interaction.ipp"
6867
"include/traccc/finding/device/impl/build_tracks.ipp"
6968
"include/traccc/finding/device/impl/find_tracks.ipp"
7069
"include/traccc/finding/device/impl/fill_finding_propagation_sort_keys.ipp"
71-
"include/traccc/finding/device/impl/make_barcode_sequence.ipp"
7270
"include/traccc/finding/device/impl/propagate_to_next_surface.ipp"
7371
# Track fitting funtions(s).
7472
"include/traccc/fitting/device/fit.hpp"

device/common/include/traccc/finding/device/find_tracks.hpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,9 @@ struct find_tracks_payload {
6060
unsigned int n_in_params;
6161

6262
/**
63-
* @brief View object to the vector of barcodes for each measurement
63+
* @brief View object to the vector of measurement index ranges per surface
6464
*/
65-
vecmem::data::vector_view<const detray::geometry::barcode> barcodes_view;
66-
67-
/**
68-
* @brief View object to the vector of upper bounds of measurement indices
69-
* per surface
70-
*/
71-
vecmem::data::vector_view<const unsigned int> upper_bounds_view;
65+
vecmem::data::vector_view<const unsigned int> measurement_ranges_view;
7266

7367
/**
7468
* @brief View object to the link vector

device/common/include/traccc/finding/device/impl/find_tracks.ipp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,8 @@ TRACCC_HOST_DEVICE inline void find_tracks(
6161
vecmem::device_vector<candidate_link> tmp_links(payload.tmp_links_view);
6262
bound_track_parameters_collection_types::device tmp_params(
6363
payload.tmp_params_view);
64-
vecmem::device_vector<const detray::geometry::barcode> barcodes(
65-
payload.barcodes_view);
66-
vecmem::device_vector<const unsigned int> upper_bounds(
67-
payload.upper_bounds_view);
64+
vecmem::device_vector<const unsigned int> meas_ranges(
65+
payload.measurement_ranges_view);
6866
vecmem::device_vector<unsigned int> tips(payload.tips_view);
6967
vecmem::device_vector<unsigned int> tip_lengths(payload.tip_lengths_view);
7068
vecmem::device_vector<unsigned int> n_tracks_per_seed(
@@ -102,31 +100,16 @@ TRACCC_HOST_DEVICE inline void find_tracks(
102100
* Get the barcode of this thread's parameters, then find the first
103101
* measurement that matches it.
104102
*/
105-
const auto bcd = in_params.at(in_param_id).surface_link();
106-
const auto lo = thrust::lower_bound(thrust::seq, barcodes.begin(),
107-
barcodes.end(), bcd);
103+
const unsigned int sf_idx{
104+
in_params.at(in_param_id).surface_link().index()};
105+
init_meas = sf_idx == 0u ? 0u : meas_ranges[sf_idx - 1];
106+
num_meas = meas_ranges[sf_idx] - init_meas;
108107

109108
/*
110109
* If we cannot find any corresponding measurements, set the number of
111110
* measurements to zero.
112111
*/
113-
if (lo == barcodes.end() || *lo != bcd) {
114-
init_meas = 0;
115-
}
116-
/*
117-
* If measurements are found, use the previously (outside this kernel)
118-
* computed upper bound array to compute the range of measurements for
119-
* this thread.
120-
*/
121-
else {
122-
const vecmem::device_vector<const unsigned int>::size_type bcd_id =
123-
static_cast<
124-
vecmem::device_vector<const unsigned int>::size_type>(
125-
std::distance(barcodes.begin(), lo));
126-
127-
init_meas = lo == barcodes.begin() ? 0u : upper_bounds[bcd_id - 1];
128-
num_meas = upper_bounds[bcd_id] - init_meas;
129-
}
112+
init_meas = (num_meas == 0u) ? 0u : init_meas;
130113
}
131114

132115
/*

device/common/include/traccc/finding/device/impl/make_barcode_sequence.ipp

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)