Skip to content

Commit 30cae55

Browse files
committed
Implement a best-for-n-spacepoints cut on seeds
This PR implements a cut that requires a seed to be the best seed for at least $n$ of its constituent spacepoints, which mirrors the cut that is used in ACTS. This cut is a specialisation of the cut in #1082, implemented in a way that should allow it to run faster.
1 parent 4c4de9f commit 30cae55

File tree

4 files changed

+137
-11
lines changed

4 files changed

+137
-11
lines changed

core/include/traccc/seeding/detail/seeding_config.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ struct seedfilter_config {
203203
// how often do you want to increase the weight of a seed for finding a
204204
// compatible seed?
205205
size_t compatSeedLimit = 2;
206+
// Number of constituent spacepoints for which a seed must be the best
207+
// seed, where "best" is defined by the weight. Valid range is
208+
// [0, 1, 2, 3] where 0 disables the cut.
209+
unsigned int minNumTimesHighestWeight = 1;
206210
// Tool to apply experiment specific cuts on collected middle space points
207211

208212
size_t max_triplets_per_spM = 5;

device/common/include/traccc/seeding/device/impl/select_seeds.ipp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ inline void select_seeds(
6666
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
6767
const triplet_counter_collection_types::const_view& tc_view,
6868
const device_triplet_collection_types::const_view& triplet_view,
69-
triplet* data, edm::seed_collection::view seed_view) {
69+
vecmem::data::vector_view<const scalar> highest_weight_view, triplet* data,
70+
edm::seed_collection::view seed_view) {
7071

7172
// Check if anything needs to be done.
7273
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
@@ -84,13 +85,18 @@ inline void select_seeds(
8485
sp_view);
8586

8687
const device_triplet_collection_types::const_device triplets(triplet_view);
88+
89+
const vecmem::device_vector<const scalar> highest_weights(
90+
highest_weight_view);
91+
8792
edm::seed_collection::device seeds_device(seed_view);
8893

8994
// Current work item = middle spacepoint
9095
const triplet_counter_spM spM_counter = triplet_counts_spM.at(globalIndex);
9196
const sp_location spM_loc = spM_counter.spM;
97+
const unsigned int spM_idx = sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx];
9298
const edm::spacepoint_collection::const_device::const_proxy_type spM =
93-
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
99+
spacepoints.at(spM_idx);
94100

95101
// Number of triplets added for this spM
96102
unsigned int n_triplets_per_spM = 0;
@@ -106,10 +112,33 @@ inline void select_seeds(
106112
triplet_counts.at(static_cast<unsigned int>(aTriplet.counter_link))
107113
.spB;
108114
const sp_location spT_loc = aTriplet.spT;
115+
const unsigned int spB_idx =
116+
sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx];
109117
const edm::spacepoint_collection::const_device::const_proxy_type spB =
110-
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
118+
spacepoints.at(spB_idx);
119+
const unsigned int spT_idx =
120+
sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx];
111121
const edm::spacepoint_collection::const_device::const_proxy_type spT =
112-
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
122+
spacepoints.at(spT_idx);
123+
124+
if (filter_config.minNumTimesHighestWeight > 0) {
125+
assert(highest_weights.capacity() > 0);
126+
unsigned int highest_for_n_spacepoints = 0;
127+
if (highest_weights.at(spB_idx) <= aTriplet.weight) {
128+
highest_for_n_spacepoints++;
129+
}
130+
if (highest_weights.at(spM_idx) <= aTriplet.weight) {
131+
highest_for_n_spacepoints++;
132+
}
133+
if (highest_weights.at(spT_idx) <= aTriplet.weight) {
134+
highest_for_n_spacepoints++;
135+
}
136+
137+
if (highest_for_n_spacepoints <
138+
filter_config.minNumTimesHighestWeight) {
139+
continue;
140+
}
141+
}
113142

114143
// update weight of triplet
115144
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,

device/common/include/traccc/seeding/device/select_seeds.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ inline void select_seeds(
4141
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
4242
const triplet_counter_collection_types::const_view& tc_view,
4343
const device_triplet_collection_types::const_view& triplet_view,
44-
triplet* data, edm::seed_collection::view seed_view);
44+
vecmem::data::vector_view<const scalar> highest_weight_view, triplet* data,
45+
edm::seed_collection::view seed_view);
4546

4647
} // namespace traccc::device
4748

device/cuda/src/seeding/seed_finding.cu

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ __global__ void select_seeds(
135135
device::triplet_counter_spM_collection_types::const_view spM_tc,
136136
device::triplet_counter_collection_types::const_view midBot_tc,
137137
device::device_triplet_collection_types::view triplet_view,
138+
vecmem::data::vector_view<const scalar> highest_weight_view,
138139
edm::seed_collection::view seed_view) {
139140

140141
// Array for temporary storage of triplets for comparing within seed
@@ -144,10 +145,72 @@ __global__ void select_seeds(
144145
triplet* dataPos = &data2[threadIdx.x * filter_config.max_triplets_per_spM];
145146

146147
device::select_seeds(details::global_index1(), filter_config, spacepoints,
147-
sp_view, spM_tc, midBot_tc, triplet_view, dataPos,
148-
seed_view);
148+
sp_view, spM_tc, midBot_tc, triplet_view,
149+
highest_weight_view, dataPos, seed_view);
149150
}
150151

152+
// Kernel that records, for every spacepoint, the highest weight among the
// triplets it takes part in. Each thread handles one (triplet, constituent)
// pair: thread `tid` works on triplet `tid / 3` and selects the spB, spM or
// spT member of that triplet according to `tid % 3`.
//
// highest_weight_view  per-spacepoint maxima, indexed by the value stored in
//                      the grid bins; updated in place
// sp_grid_view         grid used to translate an sp_location
//                      (bin_idx, sp_idx) into that index
// triplet_view         triplets whose weights are folded into the maxima
//
// NOTE(review): entries are only ever raised here, so the buffer presumably
// has to be pre-filled with a value below every possible weight -- confirm
// at the launch site.
__global__ void collect_highest_weight_per_spacepoint(
153+
vecmem::data::vector_view<scalar> highest_weight_view,
154+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
155+
device::device_triplet_collection_types::const_view triplet_view) {
156+
vecmem::device_vector<scalar> highest_weights(highest_weight_view);
157+
const traccc::details::spacepoint_grid_types::const_device sp_grid(
158+
sp_grid_view);
159+
const device::device_triplet_collection_types::const_device triplets(
160+
triplet_view);
161+
162+
const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
163+
164+
// Three consecutive threads share one triplet, one per constituent.
const unsigned int seed_idx = tid / 3;
165+
const unsigned int local_spacepoint_idx = tid % 3;
166+
167+
// Threads past the last triplet have nothing to do.
if (seed_idx >= triplets.size()) {
168+
return;
169+
}
170+
171+
sp_location global_spacepoint_loc;
172+
173+
if (local_spacepoint_idx == 0) {
174+
global_spacepoint_loc = triplets.at(seed_idx).spB;
175+
} else if (local_spacepoint_idx == 1) {
176+
global_spacepoint_loc = triplets.at(seed_idx).spM;
177+
} else if (local_spacepoint_idx == 2) {
178+
global_spacepoint_loc = triplets.at(seed_idx).spT;
179+
} else {
180+
// tid % 3 can only be 0, 1 or 2, so this branch is never taken.
__builtin_unreachable();
181+
}
182+
183+
// Translate the (bin, position-in-bin) location into the index used to
// address highest_weights.
const unsigned int global_spacepoint_idx = sp_grid.bin(
184+
global_spacepoint_loc.bin_idx)[global_spacepoint_loc.sp_idx];
185+
186+
// Pick an unsigned integer type of the same width as scalar so that the
// weight can be updated through an integer compare-and-swap.
static_assert(sizeof(scalar) == 4 || sizeof(scalar) == 8);
187+
using cas_type = std::conditional_t<sizeof(scalar) == 4, unsigned int,
188+
unsigned long long int>;
189+
190+
/*
191+
* The following is simply an implementation of atomic max.
192+
*/
193+
scalar& weight_loc = highest_weights.at(global_spacepoint_idx);
194+
// View the stored scalar as its same-sized integer representation.
cas_type* weight_raw = reinterpret_cast<cas_type*>(&weight_loc);
195+
196+
vecmem::device_atomic_ref<cas_type> atomic(*weight_raw);
197+
198+
cas_type current_weight_raw = atomic.load();
199+
scalar current_weight = std::bit_cast<scalar>(current_weight_raw);
200+
const scalar own_weight = triplets.at(seed_idx).weight;
201+
const cas_type own_weight_raw = std::bit_cast<cas_type>(own_weight);
202+
203+
// The comparison is done on the scalar values, the exchange on the raw bit
// patterns. Loop until our weight has been stored, or until the value seen
// in the slot is no longer smaller than ours.
while (own_weight > current_weight) {
204+
const bool res =
205+
atomic.compare_exchange_strong(current_weight_raw, own_weight_raw);
206+
207+
if (res) {
208+
return;
209+
} else {
210+
// The exchange did not happen; re-derive the scalar from
// current_weight_raw (presumably refreshed by the failed
// compare-exchange, as with std::atomic -- confirm for vecmem) and
// re-check the loop condition.
current_weight = std::bit_cast<scalar>(current_weight_raw);
211+
}
212+
}
213+
}
151214
} // namespace kernels
152215

153216
namespace details {
@@ -184,6 +247,9 @@ edm::seed_collection::buffer seed_finding::operator()(
184247
return {0, m_mr.main};
185248
}
186249

250+
// Total number of spacepoints, including those not in the grid.
251+
const auto total_num_spacepoints = m_copy.get_size(spacepoints_view);
252+
187253
// Set up the doublet counter buffer.
188254
device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189255
num_spacepoints, m_mr.main, vecmem::data::buffer_type::resizable};
@@ -340,6 +406,32 @@ edm::seed_collection::buffer seed_finding::operator()(
340406
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341407
triplet_buffer);
342408
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
409+
TRACCC_CUDA_ERROR_CHECK(cudaStreamSynchronize(stream));
410+
411+
vecmem::data::vector_buffer<scalar> highest_weight_buffer(0, m_mr.main);
412+
413+
if (m_seedfilter_config.minNumTimesHighestWeight > 0) {
414+
highest_weight_buffer = {total_num_spacepoints, m_mr.main};
415+
416+
/*
417+
* NOTE: The value 0b11100000 is chosen because it, when repeated four
418+
* times to create a 32-bit floating point value, evaluates to the
419+
* number 0b11100000111000001110000011100000 (about -1.3e20), which is far
420+
* more negative than any seed weight should ever be. When broadcast eight
421+
* times to a 64-bit number it also produces a hugely negative value.
422+
*/
423+
m_copy.memset(highest_weight_buffer, 0b11100000)->wait();
424+
425+
const unsigned int num_votes = 3 * globalCounter_host->m_nTriplets;
426+
const unsigned int num_threads = 256;
427+
const unsigned int num_blocks =
428+
(num_votes + num_threads - 1) / num_threads;
429+
430+
kernels::collect_highest_weight_per_spacepoint<<<
431+
num_blocks, num_threads, 0, stream>>>(highest_weight_buffer,
432+
g2_view, triplet_buffer);
433+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
434+
}
343435

344436
// Create result object: collection of seeds
345437
edm::seed_collection::buffer seed_buffer(
@@ -359,10 +451,10 @@ edm::seed_collection::buffer seed_finding::operator()(
359451
sizeof(triplet) *
360452
m_seedfilter_config.max_triplets_per_spM *
361453
nSeedSelectingThreads,
362-
stream>>>(m_seedfilter_config, spacepoints_view,
363-
g2_view, triplet_counter_spM_buffer,
364-
triplet_counter_midBot_buffer,
365-
triplet_buffer, seed_buffer);
454+
stream>>>(
455+
m_seedfilter_config, spacepoints_view, g2_view,
456+
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
457+
triplet_buffer, highest_weight_buffer, seed_buffer);
366458
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
367459

368460
return seed_buffer;

0 commit comments

Comments
 (0)