Skip to content

Commit 9b5e684

Browse files
authored
Merge pull request #1097 from stephenswat/feat/spbm_in_triplet
Embed bottom and middle spacepoint in triplet
2 parents ed4e53d + 0f6e85c commit 9b5e684

File tree

10 files changed

+76
-85
lines changed

10 files changed

+76
-85
lines changed

core/include/traccc/seeding/seed_selecting_helper.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,12 @@ struct seed_selecting_helper {
7070
/// @param seed current seed to possibly cut
7171
///
7272
/// @return boolean value
73+
template <typename spacepoint_type>
7374
static TRACCC_HOST_DEVICE bool cut_per_middle_sp(
74-
const seedfilter_config& filter_config,
75-
const edm::spacepoint_collection::const_device& spacepoints,
76-
const details::spacepoint_grid_types::const_device& grid,
77-
const triplet& seed) {
75+
const seedfilter_config& filter_config, const spacepoint_type& spB,
76+
const scalar weight) {
7877

79-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
80-
spacepoints.at(grid.bin(seed.sp1.bin_idx)[seed.sp1.sp_idx]);
81-
82-
return (seed.weight > filter_config.seed_min_weight ||
78+
return (weight > filter_config.seed_min_weight ||
8379
spB.radius() > filter_config.spB_min_radius);
8480
}
8581
};

core/src/seeding/seed_filtering.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,14 @@ void seed_filtering::operator()(
9595
std::min(triplets_passing_single_seed_cuts.size(),
9696
static_cast<std::size_t>(m_finder_config.maxSeedsPerSpM));
9797
for (std::size_t i = 1; i < itLength; ++i) {
98+
const traccc::details::spacepoint_grid_types::const_device
99+
sp_grid_accessor(sp_grid_data);
100+
const auto& this_seed = triplets_passing_single_seed_cuts[i].get();
98101
if (seed_selecting_helper::cut_per_middle_sp(
99-
m_filter_config, spacepoints, sp_grid_data,
100-
triplets_passing_single_seed_cuts[i])) {
102+
m_filter_config,
103+
spacepoints.at(sp_grid_accessor.bin(
104+
this_seed.sp1.bin_idx)[this_seed.sp1.sp_idx]),
105+
this_seed.weight)) {
101106
triplets_passing_final_cuts.push_back(
102107
triplets_passing_single_seed_cuts[i]);
103108
}

device/alpaka/src/seeding/seed_finding.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ struct UpdateTripletWeights {
134134
ALPAKA_FN_ACC void operator()(
135135
TAcc const& acc, seedfilter_config filter_config,
136136
edm::spacepoint_collection::const_view spacepoints,
137-
traccc::details::spacepoint_grid_types::const_view sp_grid,
138137
device::triplet_counter_spM_collection_types::const_view spM_tc,
139138
device::triplet_counter_collection_types::const_view midBot_tc,
140139
device::device_triplet_collection_types::view triplet_view) const {
@@ -151,8 +150,8 @@ struct UpdateTripletWeights {
151150
scalar* dataPos = &data[localThreadIdx * filter_config.compatSeedLimit];
152151

153152
device::update_triplet_weights(globalThreadIdx, filter_config,
154-
spacepoints, sp_grid, spM_tc, midBot_tc,
155-
dataPos, triplet_view);
153+
spacepoints, spM_tc, midBot_tc, dataPos,
154+
triplet_view);
156155
}
157156
};
158157

@@ -175,10 +174,12 @@ struct SelectSeeds {
175174

176175
// Array for temporary storage of quality parameters for comparing
177176
// triplets within weight updating kernel
178-
triplet* const data = ::alpaka::getDynSharedMem<triplet>(acc);
177+
device::device_triplet* const data =
178+
::alpaka::getDynSharedMem<device::device_triplet>(acc);
179179

180-
// Each thread uses maxSeedsPerSpM elements of the array
181-
triplet* dataPos = &data[localThreadIdx * finder_config.maxSeedsPerSpM];
180+
// Each thread uses max_triplets_per_spM elements of the array
181+
device::device_triplet* dataPos =
182+
&data[localThreadIdx * finder_config.maxSeedsPerSpM];
182183

183184
device::select_seeds(globalThreadIdx, finder_config, filter_config,
184185
spacepoints, sp_view, spM_tc, midBot_tc,
@@ -360,7 +361,7 @@ edm::seed_collection::buffer seed_finding::operator()(
360361

361362
// Update the weights of all spacepoint triplets.
362363
::alpaka::exec<Acc>(queue, workDiv, kernels::UpdateTripletWeights{},
363-
m_seedfilter_config, spacepoints_view, g2_view,
364+
m_seedfilter_config, spacepoints_view,
364365
vecmem::get_data(triplet_counter_spM_buffer),
365366
vecmem::get_data(triplet_counter_midBot_buffer),
366367
vecmem::get_data(triplet_buffer));
@@ -419,7 +420,7 @@ struct BlockSharedMemDynSizeBytes<traccc::alpaka::kernels::SelectSeeds, TAcc> {
419420
) -> std::size_t {
420421
return static_cast<std::size_t>(finder_config.maxSeedsPerSpM *
421422
blockThreadExtent.prod()) *
422-
sizeof(traccc::triplet);
423+
sizeof(traccc::device::device_triplet);
423424
}
424425
};
425426

device/common/include/traccc/edm/device/device_triplet.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace traccc::device {
1717
/// Triplets of bottom, middle and top spacepoints
1818
struct device_triplet {
1919
// top spacepoint location in internal spacepoint container
20-
sp_location spT;
20+
unsigned int spB, spM, spT;
2121

2222
using link_type = device::triplet_counter_collection_types::host::size_type;
2323
/// Link to triplet counter where the middle and bottom spacepoints are

device/common/include/traccc/seeding/device/impl/find_triplets.ipp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ inline void find_triplets(
5656
const sp_location spB_loc = mid_bot_counter.spB;
5757

5858
// middle spacepoint
59+
const unsigned int spM_idx = sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx];
5960
const edm::spacepoint_collection::const_device::const_proxy_type spM =
60-
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
61+
spacepoints.at(spM_idx);
6162

6263
// bottom spacepoint
64+
const unsigned int spB_idx = sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx];
6365
const edm::spacepoint_collection::const_device::const_proxy_type spB =
64-
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
66+
spacepoints.at(spB_idx);
6567

6668
// Set up the device result collection
6769
device_triplet_collection_types::device triplets(triplet_view);
@@ -95,8 +97,10 @@ inline void find_triplets(
9597
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
9698
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
9799

100+
const unsigned int spT_idx =
101+
sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx];
98102
const edm::spacepoint_collection::const_device::const_proxy_type spT =
99-
spacepoints.at(sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
103+
spacepoints.at(spT_idx);
100104

101105
// Apply the conformal transformation to middle-top doublet
102106
const traccc::lin_circle lt =
@@ -110,7 +114,7 @@ inline void find_triplets(
110114

111115
// Add triplet to jagged vector
112116
triplets.at(posTriplets++) = device_triplet(
113-
{spT_loc, globalIndex, curvature,
117+
{spB_idx, spM_idx, spT_idx, globalIndex, curvature,
114118
-impact_parameter * filter_config.impactWeightFactor,
115119
lb.Zo()});
116120
}

device/common/include/traccc/seeding/device/impl/select_seeds.ipp

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace traccc::device {
2020
namespace details {
2121
// Finding minimum element algorithm
2222
template <typename Comparator>
23-
TRACCC_HOST_DEVICE std::size_t min_elem(const triplet* arr,
23+
TRACCC_HOST_DEVICE std::size_t min_elem(const device_triplet* arr,
2424
const std::size_t begin_idx,
2525
const std::size_t end_idx,
2626
Comparator comp) {
@@ -38,11 +38,11 @@ TRACCC_HOST_DEVICE std::size_t min_elem(const triplet* arr,
3838

3939
// Sorting algorithm for sorting seeds in the local memory
4040
template <typename Comparator>
41-
TRACCC_HOST_DEVICE void insertionSort(triplet* arr,
41+
TRACCC_HOST_DEVICE void insertionSort(device_triplet* arr,
4242
const unsigned int begin_idx,
4343
const unsigned int n, Comparator comp) {
4444
int j = 0;
45-
triplet key = arr[begin_idx];
45+
device_triplet key = arr[begin_idx];
4646
for (unsigned int i = 0; i < n; ++i) {
4747
key = arr[begin_idx + i];
4848
j = static_cast<int>(i) - 1;
@@ -67,7 +67,7 @@ inline void select_seeds(
6767
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
6868
const triplet_counter_collection_types::const_view& tc_view,
6969
const device_triplet_collection_types::const_view& triplet_view,
70-
triplet* data, edm::seed_collection::view seed_view) {
70+
device_triplet* data, edm::seed_collection::view seed_view) {
7171

7272
// Check if anything needs to be done.
7373
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
@@ -90,8 +90,9 @@ inline void select_seeds(
9090
// Current work item = middle spacepoint
9191
const triplet_counter_spM spM_counter = triplet_counts_spM.at(globalIndex);
9292
const sp_location spM_loc = spM_counter.spM;
93+
const unsigned int spM_idx = sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx];
9394
const edm::spacepoint_collection::const_device::const_proxy_type spM =
94-
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
95+
spacepoints.at(spM_idx);
9596

9697
// Number of triplets added for this spM
9798
unsigned int n_triplets_per_spM = 0;
@@ -103,14 +104,12 @@ inline void select_seeds(
103104
device_triplet aTriplet = triplets[i];
104105

105106
// spacepoints bottom and top for this triplet
106-
const sp_location spB_loc =
107-
triplet_counts.at(static_cast<unsigned int>(aTriplet.counter_link))
108-
.spB;
109-
const sp_location spT_loc = aTriplet.spT;
107+
const unsigned int spB_idx = aTriplet.spB;
110108
const edm::spacepoint_collection::const_device::const_proxy_type spB =
111-
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
109+
spacepoints.at(spB_idx);
110+
const unsigned int spT_idx = aTriplet.spT;
112111
const edm::spacepoint_collection::const_device::const_proxy_type spT =
113-
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
112+
spacepoints.at(spT_idx);
114113

115114
// update weight of triplet
116115
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,
@@ -126,61 +125,54 @@ inline void select_seeds(
126125
// the triplet with the lowest weight is removed
127126
if (n_triplets_per_spM >= finder_config.maxSeedsPerSpM) {
128127

129-
const std::size_t min_index =
130-
details::min_elem(data, 0, finder_config.maxSeedsPerSpM,
131-
[](const triplet lhs, const triplet rhs) {
132-
return lhs.weight > rhs.weight;
133-
});
128+
const std::size_t min_index = details::min_elem(
129+
data, 0, finder_config.maxSeedsPerSpM,
130+
[](const device_triplet& lhs, const device_triplet& rhs) {
131+
return lhs.weight > rhs.weight;
132+
});
134133

135134
const scalar& min_weight = data[min_index].weight;
136135

137136
if (aTriplet.weight > min_weight) {
138-
data[min_index] = {spB_loc, spM_loc,
139-
spT_loc, aTriplet.curvature,
140-
aTriplet.weight, aTriplet.z_vertex};
137+
data[min_index] = aTriplet;
141138
}
142139
}
143140

144141
// if the number of good triplets is below the threshold, add
145142
// the current triplet to the array
146143
else if (n_triplets_per_spM < finder_config.maxSeedsPerSpM) {
147-
data[n_triplets_per_spM] = {spB_loc, spM_loc,
148-
spT_loc, aTriplet.curvature,
149-
aTriplet.weight, aTriplet.z_vertex};
144+
data[n_triplets_per_spM] = aTriplet;
150145
n_triplets_per_spM++;
151146
}
152147
}
153148

154149
// sort the triplets per spM
155150
details::insertionSort(
156151
data, 0, n_triplets_per_spM,
157-
traccc::details::triplet_sorter{spacepoints, sp_device});
152+
[](const device_triplet& lhs, const device_triplet& rhs) {
153+
return lhs.weight > rhs.weight;
154+
});
158155

159156
// the number of good seed per compatible middle spacepoint
160157
unsigned int n_seeds_per_spM = 0;
161158

162159
// iterate over the good triplets for final selection of seeds
163160
for (unsigned int i = 0; i < n_triplets_per_spM; ++i) {
164-
const triplet& aTriplet = data[i];
165-
const sp_location& spB_loc = aTriplet.sp1;
166-
const sp_location& spT_loc = aTriplet.sp3;
161+
const device_triplet& aTriplet = data[i];
167162

168163
// if the number of seeds reaches the threshold, break
169164
if (n_seeds_per_spM >= finder_config.maxSeedsPerSpM + 1) {
170165
break;
171166
}
172167

173168
// check if it is a good triplet
174-
if (seed_selecting_helper::cut_per_middle_sp(filter_config, spacepoints,
175-
sp_device, aTriplet) ||
169+
if (seed_selecting_helper::cut_per_middle_sp(
170+
filter_config, spacepoints.at(aTriplet.spB), aTriplet.weight) ||
176171
n_seeds_per_spM == 0) {
177172

178173
n_seeds_per_spM++;
179174

180-
seeds_device.push_back(
181-
{sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx],
182-
sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx],
183-
sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]});
175+
seeds_device.push_back({aTriplet.spB, aTriplet.spM, aTriplet.spT});
184176
}
185177
}
186178
}

device/common/include/traccc/seeding/device/impl/update_triplet_weights.ipp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ TRACCC_HOST_DEVICE
1919
inline void update_triplet_weights(
2020
const global_index_t globalIndex, const seedfilter_config& filter_config,
2121
const edm::spacepoint_collection::const_view& spacepoints_view,
22-
const traccc::details::spacepoint_grid_types::const_view& sp_view,
2322
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
2423
const triplet_counter_collection_types::const_view& tc_view, scalar* data,
2524
device_triplet_collection_types::view triplet_view) {
@@ -33,7 +32,6 @@ inline void update_triplet_weights(
3332
// Set up the device containers
3433
const edm::spacepoint_collection::const_device spacepoints{
3534
spacepoints_view};
36-
const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view);
3735
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
3836
spM_tc_view);
3937
const triplet_counter_collection_types::const_device triplet_counts(
@@ -42,11 +40,8 @@ inline void update_triplet_weights(
4240
// Current work item
4341
device_triplet this_triplet = triplets.at(globalIndex);
4442

45-
const sp_location& spT_idx = this_triplet.spT;
46-
4743
const edm::spacepoint_collection::const_device::const_proxy_type
48-
current_spT =
49-
spacepoints.at(sp_grid.bin(spT_idx.bin_idx)[spT_idx.sp_idx]);
44+
current_spT = spacepoints.at(this_triplet.spT);
5045

5146
const scalar currentTop_r = current_spT.radius();
5247

@@ -82,10 +77,8 @@ inline void update_triplet_weights(
8277
}
8378

8479
const device_triplet other_triplet = triplets[i];
85-
const sp_location other_spT_idx = other_triplet.spT;
8680
const edm::spacepoint_collection::const_device::const_proxy_type
87-
other_spT = spacepoints.at(
88-
sp_grid.bin(other_spT_idx.bin_idx)[other_spT_idx.sp_idx]);
81+
other_spT = spacepoints.at(other_triplet.spT);
8982

9083
// compared top SP should have at least deltaRMin distance
9184
const scalar otherTop_r = other_spT.radius();

device/common/include/traccc/seeding/device/update_triplet_weights.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ TRACCC_HOST_DEVICE
3737
inline void update_triplet_weights(
3838
global_index_t globalIndex, const seedfilter_config& filter_config,
3939
const edm::spacepoint_collection::const_view& spacepoints,
40-
const traccc::details::spacepoint_grid_types::const_view& sp_view,
4140
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
4241
const triplet_counter_collection_types::const_view& tc_view, scalar* data,
4342
device_triplet_collection_types::view triplet_view);

device/cuda/src/seeding/seed_finding.cu

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ __global__ void find_triplets(
111111
__global__ void update_triplet_weights(
112112
seedfilter_config filter_config,
113113
edm::spacepoint_collection::const_view spacepoints,
114-
traccc::details::spacepoint_grid_types::const_view sp_grid,
115114
device::triplet_counter_spM_collection_types::const_view spM_tc,
116115
device::triplet_counter_collection_types::const_view midBot_tc,
117116
device::device_triplet_collection_types::view triplet_view) {
@@ -123,8 +122,8 @@ __global__ void update_triplet_weights(
123122
scalar* dataPos = &data[threadIdx.x * filter_config.compatSeedLimit];
124123

125124
device::update_triplet_weights(details::global_index1(), filter_config,
126-
spacepoints, sp_grid, spM_tc, midBot_tc,
127-
dataPos, triplet_view);
125+
spacepoints, spM_tc, midBot_tc, dataPos,
126+
triplet_view);
128127
}
129128

130129
/// CUDA kernel for running @c traccc::device::select_seeds
@@ -139,9 +138,10 @@ __global__ void select_seeds(
139138

140139
// Array for temporary storage of triplets for comparing within seed
141140
// selecting kernel
142-
extern __shared__ triplet data2[];
143-
// Each thread uses maxSeedsPerSpM elements of the array
144-
triplet* dataPos = &data2[threadIdx.x * finder_config.maxSeedsPerSpM];
141+
extern __shared__ device::device_triplet data2[];
142+
// Each thread uses max_triplets_per_spM elements of the array
143+
device::device_triplet* dataPos =
144+
&data2[threadIdx.x * finder_config.maxSeedsPerSpM];
145145

146146
device::select_seeds(details::global_index1(), finder_config, filter_config,
147147
spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
@@ -336,7 +336,7 @@ edm::seed_collection::buffer seed_finding::operator()(
336336
nWeightUpdatingBlocks, nWeightUpdatingThreads,
337337
sizeof(scalar) * m_seedfilter_config.compatSeedLimit *
338338
nWeightUpdatingThreads,
339-
stream>>>(m_seedfilter_config, spacepoints_view, g2_view,
339+
stream>>>(m_seedfilter_config, spacepoints_view,
340340
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341341
triplet_buffer);
342342
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
@@ -355,14 +355,14 @@ edm::seed_collection::buffer seed_finding::operator()(
355355
nSeedSelectingThreads;
356356

357357
// Create seeds out of selected triplets
358-
kernels::
359-
select_seeds<<<nSeedSelectingBlocks, nSeedSelectingThreads,
360-
sizeof(triplet) * m_seedfinder_config.maxSeedsPerSpM *
361-
nSeedSelectingThreads,
362-
stream>>>(
363-
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364-
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365-
triplet_buffer, seed_buffer);
358+
kernels::select_seeds<<<nSeedSelectingBlocks, nSeedSelectingThreads,
359+
sizeof(device::device_triplet) *
360+
m_seedfinder_config.maxSeedsPerSpM *
361+
nSeedSelectingThreads,
362+
stream>>>(
363+
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364+
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365+
triplet_buffer, seed_buffer);
366366
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
367367

368368
return seed_buffer;

0 commit comments

Comments
 (0)