Skip to content

Commit 5cfd02a

Browse files
committed
Embed bottom and middle spacepoint in triplet
Right now, we store only the top spacepoint in the device triplet which is technically all that is required, but this makes it very difficult and time-consuming to retrieve the bottom and middle spacepoint. This commit embeds the locations of those spacepoints in the triplet, making life easier for a lot of the planned seeding changes.
1 parent 327189d commit 5cfd02a

File tree

10 files changed

+64
-75
lines changed

10 files changed

+64
-75
lines changed

core/include/traccc/seeding/seed_selecting_helper.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,12 @@ struct seed_selecting_helper {
7070
/// @param seed current seed to possibly cut
7171
///
7272
/// @return boolean value
73+
template <typename spacepoint_type>
7374
static TRACCC_HOST_DEVICE bool cut_per_middle_sp(
74-
const seedfilter_config& filter_config,
75-
const edm::spacepoint_collection::const_device& spacepoints,
76-
const details::spacepoint_grid_types::const_device& grid,
77-
const triplet& seed) {
75+
const seedfilter_config& filter_config, const spacepoint_type& spB,
76+
const scalar weight) {
7877

79-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
80-
spacepoints.at(grid.bin(seed.sp1.bin_idx)[seed.sp1.sp_idx]);
81-
82-
return (seed.weight > filter_config.seed_min_weight ||
78+
return (weight > filter_config.seed_min_weight ||
8379
spB.radius() > filter_config.spB_min_radius);
8480
}
8581
};

core/src/seeding/seed_filtering.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,14 @@ void seed_filtering::operator()(
9191
std::min(triplets_passing_single_seed_cuts.size(),
9292
m_filter_config.max_triplets_per_spM);
9393
for (std::size_t i = 1; i < itLength; ++i) {
94+
const traccc::details::spacepoint_grid_types::const_device
95+
sp_grid_accessor(sp_grid_data);
96+
const auto& this_seed = triplets_passing_single_seed_cuts[i].get();
9497
if (seed_selecting_helper::cut_per_middle_sp(
95-
m_filter_config, spacepoints, sp_grid_data,
96-
triplets_passing_single_seed_cuts[i])) {
98+
m_filter_config,
99+
spacepoints.at(sp_grid_accessor.bin(
100+
this_seed.sp1.bin_idx)[this_seed.sp1.sp_idx]),
101+
this_seed.weight)) {
97102
triplets_passing_final_cuts.push_back(
98103
triplets_passing_single_seed_cuts[i]);
99104
}

device/alpaka/src/seeding/seed_finding.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ struct UpdateTripletWeights {
134134
ALPAKA_FN_ACC void operator()(
135135
TAcc const& acc, seedfilter_config filter_config,
136136
edm::spacepoint_collection::const_view spacepoints,
137-
traccc::details::spacepoint_grid_types::const_view sp_grid,
138137
device::triplet_counter_spM_collection_types::const_view spM_tc,
139138
device::triplet_counter_collection_types::const_view midBot_tc,
140139
device::device_triplet_collection_types::view triplet_view) const {
@@ -151,8 +150,8 @@ struct UpdateTripletWeights {
151150
scalar* dataPos = &data[localThreadIdx * filter_config.compatSeedLimit];
152151

153152
device::update_triplet_weights(globalThreadIdx, filter_config,
154-
spacepoints, sp_grid, spM_tc, midBot_tc,
155-
dataPos, triplet_view);
153+
spacepoints, spM_tc, midBot_tc, dataPos,
154+
triplet_view);
156155
}
157156
};
158157

@@ -174,10 +173,11 @@ struct SelectSeeds {
174173

175174
// Array for temporary storage of quality parameters for comparing
176175
// triplets within weight updating kernel
177-
triplet* const data = ::alpaka::getDynSharedMem<triplet>(acc);
176+
device::device_triplet* const data =
177+
::alpaka::getDynSharedMem<device::device_triplet>(acc);
178178

179179
// Each thread uses max_triplets_per_spM elements of the array
180-
triplet* dataPos =
180+
device::device_triplet* dataPos =
181181
&data[localThreadIdx * filter_config.max_triplets_per_spM];
182182

183183
device::select_seeds(globalThreadIdx, filter_config, spacepoints,
@@ -360,7 +360,7 @@ edm::seed_collection::buffer seed_finding::operator()(
360360

361361
// Update the weights of all spacepoint triplets.
362362
::alpaka::exec<Acc>(queue, workDiv, kernels::UpdateTripletWeights{},
363-
m_seedfilter_config, spacepoints_view, g2_view,
363+
m_seedfilter_config, spacepoints_view,
364364
vecmem::get_data(triplet_counter_spM_buffer),
365365
vecmem::get_data(triplet_counter_midBot_buffer),
366366
vecmem::get_data(triplet_buffer));
@@ -418,7 +418,7 @@ struct BlockSharedMemDynSizeBytes<traccc::alpaka::kernels::SelectSeeds, TAcc> {
418418
) -> std::size_t {
419419
return static_cast<std::size_t>(filter_config.max_triplets_per_spM *
420420
blockThreadExtent.prod()) *
421-
sizeof(traccc::triplet);
421+
sizeof(traccc::device::device_triplet);
422422
}
423423
};
424424

device/common/include/traccc/edm/device/device_triplet.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace traccc::device {
1717
/// Triplets of bottom, middle and top spacepoints
1818
struct device_triplet {
1919
// top spacepoint location in internal spacepoint container
20-
sp_location spT;
20+
unsigned int spB, spM, spT;
2121

2222
using link_type = device::triplet_counter_collection_types::host::size_type;
2323
/// Link to triplet counter where the middle and bottom spacepoints are

device/common/include/traccc/seeding/device/impl/find_triplets.ipp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ inline void find_triplets(
5656
const sp_location spB_loc = mid_bot_counter.spB;
5757

5858
// middle spacepoint
59+
const unsigned int spM_idx = sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx];
5960
const edm::spacepoint_collection::const_device::const_proxy_type spM =
60-
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
61+
spacepoints.at(spM_idx);
6162

6263
// bottom spacepoint
64+
const unsigned int spB_idx = sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx];
6365
const edm::spacepoint_collection::const_device::const_proxy_type spB =
64-
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
66+
spacepoints.at(spB_idx);
6567

6668
// Set up the device result collection
6769
device_triplet_collection_types::device triplets(triplet_view);
@@ -95,8 +97,10 @@ inline void find_triplets(
9597
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
9698
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
9799

100+
const unsigned int spT_idx =
101+
sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx];
98102
const edm::spacepoint_collection::const_device::const_proxy_type spT =
99-
spacepoints.at(sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
103+
spacepoints.at(spT_idx);
100104

101105
// Apply the conformal transformation to middle-top doublet
102106
const traccc::lin_circle lt =
@@ -110,7 +114,7 @@ inline void find_triplets(
110114

111115
// Add triplet to jagged vector
112116
triplets.at(posTriplets++) = device_triplet(
113-
{spT_loc, globalIndex, curvature,
117+
{spB_idx, spM_idx, spT_idx, globalIndex, curvature,
114118
-impact_parameter * filter_config.impactWeightFactor,
115119
lb.Zo()});
116120
}

device/common/include/traccc/seeding/device/impl/select_seeds.ipp

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace traccc::device {
2020
namespace details {
2121
// Finding minimum element algorithm
2222
template <typename Comparator>
23-
TRACCC_HOST_DEVICE std::size_t min_elem(const triplet* arr,
23+
TRACCC_HOST_DEVICE std::size_t min_elem(const device_triplet* arr,
2424
const std::size_t begin_idx,
2525
const std::size_t end_idx,
2626
Comparator comp) {
@@ -38,11 +38,11 @@ TRACCC_HOST_DEVICE std::size_t min_elem(const triplet* arr,
3838

3939
// Sorting algorithm for sorting seeds in the local memory
4040
template <typename Comparator>
41-
TRACCC_HOST_DEVICE void insertionSort(triplet* arr,
41+
TRACCC_HOST_DEVICE void insertionSort(device_triplet* arr,
4242
const unsigned int begin_idx,
4343
const unsigned int n, Comparator comp) {
4444
int j = 0;
45-
triplet key = arr[begin_idx];
45+
device_triplet key = arr[begin_idx];
4646
for (unsigned int i = 0; i < n; ++i) {
4747
key = arr[begin_idx + i];
4848
j = static_cast<int>(i) - 1;
@@ -66,7 +66,7 @@ inline void select_seeds(
6666
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
6767
const triplet_counter_collection_types::const_view& tc_view,
6868
const device_triplet_collection_types::const_view& triplet_view,
69-
triplet* data, edm::seed_collection::view seed_view) {
69+
device_triplet* data, edm::seed_collection::view seed_view) {
7070

7171
// Check if anything needs to be done.
7272
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
@@ -89,8 +89,9 @@ inline void select_seeds(
8989
// Current work item = middle spacepoint
9090
const triplet_counter_spM spM_counter = triplet_counts_spM.at(globalIndex);
9191
const sp_location spM_loc = spM_counter.spM;
92+
const unsigned int spM_idx = sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx];
9293
const edm::spacepoint_collection::const_device::const_proxy_type spM =
93-
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
94+
spacepoints.at(spM_idx);
9495

9596
// Number of triplets added for this spM
9697
unsigned int n_triplets_per_spM = 0;
@@ -102,14 +103,12 @@ inline void select_seeds(
102103
device_triplet aTriplet = triplets[i];
103104

104105
// spacepoints bottom and top for this triplet
105-
const sp_location spB_loc =
106-
triplet_counts.at(static_cast<unsigned int>(aTriplet.counter_link))
107-
.spB;
108-
const sp_location spT_loc = aTriplet.spT;
106+
const unsigned int spB_idx = aTriplet.spB;
109107
const edm::spacepoint_collection::const_device::const_proxy_type spB =
110-
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
108+
spacepoints.at(spB_idx);
109+
const unsigned int spT_idx = aTriplet.spT;
111110
const edm::spacepoint_collection::const_device::const_proxy_type spT =
112-
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
111+
spacepoints.at(spT_idx);
113112

114113
// update weight of triplet
115114
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,
@@ -125,61 +124,54 @@ inline void select_seeds(
125124
// the triplet with the lowest weight is removed
126125
if (n_triplets_per_spM >= filter_config.max_triplets_per_spM) {
127126

128-
const std::size_t min_index =
129-
details::min_elem(data, 0, filter_config.max_triplets_per_spM,
130-
[](const triplet lhs, const triplet rhs) {
131-
return lhs.weight > rhs.weight;
132-
});
127+
const std::size_t min_index = details::min_elem(
128+
data, 0, filter_config.max_triplets_per_spM,
129+
[](const device_triplet& lhs, const device_triplet& rhs) {
130+
return lhs.weight > rhs.weight;
131+
});
133132

134133
const scalar& min_weight = data[min_index].weight;
135134

136135
if (aTriplet.weight > min_weight) {
137-
data[min_index] = {spB_loc, spM_loc,
138-
spT_loc, aTriplet.curvature,
139-
aTriplet.weight, aTriplet.z_vertex};
136+
data[min_index] = aTriplet;
140137
}
141138
}
142139

143140
// if the number of good triplets is below the threshold, add
144141
// the current triplet to the array
145142
else if (n_triplets_per_spM < filter_config.max_triplets_per_spM) {
146-
data[n_triplets_per_spM] = {spB_loc, spM_loc,
147-
spT_loc, aTriplet.curvature,
148-
aTriplet.weight, aTriplet.z_vertex};
143+
data[n_triplets_per_spM] = aTriplet;
149144
n_triplets_per_spM++;
150145
}
151146
}
152147

153148
// sort the triplets per spM
154149
details::insertionSort(
155150
data, 0, n_triplets_per_spM,
156-
traccc::details::triplet_sorter{spacepoints, sp_device});
151+
[](const device_triplet& lhs, const device_triplet& rhs) {
152+
return lhs.weight > rhs.weight;
153+
});
157154

158155
// the number of good seed per compatible middle spacepoint
159156
unsigned int n_seeds_per_spM = 0;
160157

161158
// iterate over the good triplets for final selection of seeds
162159
for (unsigned int i = 0; i < n_triplets_per_spM; ++i) {
163-
const triplet& aTriplet = data[i];
164-
const sp_location& spB_loc = aTriplet.sp1;
165-
const sp_location& spT_loc = aTriplet.sp3;
160+
const device_triplet& aTriplet = data[i];
166161

167162
// if the number of seeds reaches the threshold, break
168163
if (n_seeds_per_spM >= filter_config.maxSeedsPerSpM + 1) {
169164
break;
170165
}
171166

172167
// check if it is a good triplet
173-
if (seed_selecting_helper::cut_per_middle_sp(filter_config, spacepoints,
174-
sp_device, aTriplet) ||
168+
if (seed_selecting_helper::cut_per_middle_sp(
169+
filter_config, spacepoints.at(aTriplet.spB), aTriplet.weight) ||
175170
n_seeds_per_spM == 0) {
176171

177172
n_seeds_per_spM++;
178173

179-
seeds_device.push_back(
180-
{sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx],
181-
sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx],
182-
sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]});
174+
seeds_device.push_back({aTriplet.spB, aTriplet.spM, aTriplet.spT});
183175
}
184176
}
185177
}

device/common/include/traccc/seeding/device/impl/update_triplet_weights.ipp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ TRACCC_HOST_DEVICE
1919
inline void update_triplet_weights(
2020
const global_index_t globalIndex, const seedfilter_config& filter_config,
2121
const edm::spacepoint_collection::const_view& spacepoints_view,
22-
const traccc::details::spacepoint_grid_types::const_view& sp_view,
2322
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
2423
const triplet_counter_collection_types::const_view& tc_view, scalar* data,
2524
device_triplet_collection_types::view triplet_view) {
@@ -33,7 +32,6 @@ inline void update_triplet_weights(
3332
// Set up the device containers
3433
const edm::spacepoint_collection::const_device spacepoints{
3534
spacepoints_view};
36-
const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view);
3735
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
3836
spM_tc_view);
3937
const triplet_counter_collection_types::const_device triplet_counts(
@@ -42,11 +40,8 @@ inline void update_triplet_weights(
4240
// Current work item
4341
device_triplet this_triplet = triplets.at(globalIndex);
4442

45-
const sp_location& spT_idx = this_triplet.spT;
46-
4743
const edm::spacepoint_collection::const_device::const_proxy_type
48-
current_spT =
49-
spacepoints.at(sp_grid.bin(spT_idx.bin_idx)[spT_idx.sp_idx]);
44+
current_spT = spacepoints.at(this_triplet.spT);
5045

5146
const scalar currentTop_r = current_spT.radius();
5247

@@ -82,10 +77,8 @@ inline void update_triplet_weights(
8277
}
8378

8479
const device_triplet other_triplet = triplets[i];
85-
const sp_location other_spT_idx = other_triplet.spT;
8680
const edm::spacepoint_collection::const_device::const_proxy_type
87-
other_spT = spacepoints.at(
88-
sp_grid.bin(other_spT_idx.bin_idx)[other_spT_idx.sp_idx]);
81+
other_spT = spacepoints.at(other_triplet.spT);
8982

9083
// compared top SP should have at least deltaRMin distance
9184
const scalar otherTop_r = other_spT.radius();

device/common/include/traccc/seeding/device/update_triplet_weights.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ TRACCC_HOST_DEVICE
3737
inline void update_triplet_weights(
3838
global_index_t globalIndex, const seedfilter_config& filter_config,
3939
const edm::spacepoint_collection::const_view& spacepoints,
40-
const traccc::details::spacepoint_grid_types::const_view& sp_view,
4140
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
4241
const triplet_counter_collection_types::const_view& tc_view, scalar* data,
4342
device_triplet_collection_types::view triplet_view);

device/cuda/src/seeding/seed_finding.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ __global__ void find_triplets(
111111
__global__ void update_triplet_weights(
112112
seedfilter_config filter_config,
113113
edm::spacepoint_collection::const_view spacepoints,
114-
traccc::details::spacepoint_grid_types::const_view sp_grid,
115114
device::triplet_counter_spM_collection_types::const_view spM_tc,
116115
device::triplet_counter_collection_types::const_view midBot_tc,
117116
device::device_triplet_collection_types::view triplet_view) {
@@ -123,8 +122,8 @@ __global__ void update_triplet_weights(
123122
scalar* dataPos = &data[threadIdx.x * filter_config.compatSeedLimit];
124123

125124
device::update_triplet_weights(details::global_index1(), filter_config,
126-
spacepoints, sp_grid, spM_tc, midBot_tc,
127-
dataPos, triplet_view);
125+
spacepoints, spM_tc, midBot_tc, dataPos,
126+
triplet_view);
128127
}
129128

130129
/// CUDA kernel for running @c traccc::device::select_seeds
@@ -139,9 +138,10 @@ __global__ void select_seeds(
139138

140139
// Array for temporary storage of triplets for comparing within seed
141140
// selecting kernel
142-
extern __shared__ triplet data2[];
141+
extern __shared__ device::device_triplet data2[];
143142
// Each thread uses max_triplets_per_spM elements of the array
144-
triplet* dataPos = &data2[threadIdx.x * filter_config.max_triplets_per_spM];
143+
device::device_triplet* dataPos =
144+
&data2[threadIdx.x * filter_config.max_triplets_per_spM];
145145

146146
device::select_seeds(details::global_index1(), filter_config, spacepoints,
147147
sp_view, spM_tc, midBot_tc, triplet_view, dataPos,
@@ -336,7 +336,7 @@ edm::seed_collection::buffer seed_finding::operator()(
336336
nWeightUpdatingBlocks, nWeightUpdatingThreads,
337337
sizeof(scalar) * m_seedfilter_config.compatSeedLimit *
338338
nWeightUpdatingThreads,
339-
stream>>>(m_seedfilter_config, spacepoints_view, g2_view,
339+
stream>>>(m_seedfilter_config, spacepoints_view,
340340
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341341
triplet_buffer);
342342
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
@@ -356,7 +356,7 @@ edm::seed_collection::buffer seed_finding::operator()(
356356

357357
// Create seeds out of selected triplets
358358
kernels::select_seeds<<<nSeedSelectingBlocks, nSeedSelectingThreads,
359-
sizeof(triplet) *
359+
sizeof(device::device_triplet) *
360360
m_seedfilter_config.max_triplets_per_spM *
361361
nSeedSelectingThreads,
362362
stream>>>(m_seedfilter_config, spacepoints_view,

0 commit comments

Comments
 (0)