Skip to content

Commit 8d0d1a5

Browse files
committed
Implement a best-for-n-spacepoints cut on seeds
This PR implements a cut that requires a seed to be the best seed for at least $n$ of its constituent spacepoints, which mirrors the cut that is used in ACTS. This cut is a specialisation of the cut in #1082, implemented in a way that should allow it to run faster.
1 parent a2dbda9 commit 8d0d1a5

File tree

4 files changed

+122
-3
lines changed

4 files changed

+122
-3
lines changed

core/include/traccc/seeding/detail/seeding_config.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ struct seedfilter_config {
200200
// how often do you want to increase the weight of a seed for finding a
201201
// compatible seed?
202202
size_t compatSeedLimit = 2;
203+
// Number of constituent spacepoints for which a seed must be the best
204+
// seed, where "best" is defined by the weight. Valid range is
205+
// [0, 1, 2, 3] where 0 disables the cut.
206+
unsigned int minNumTimesHighestWeight = 1;
207+
// Tool to apply experiment specific cuts on collected middle space points
203208

204209
// seed weight increase
205210
float good_spB_min_radius = 150.f * unit<float>::mm;

device/common/include/traccc/seeding/device/impl/select_seeds.ipp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ inline void select_seeds(
6767
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
6868
const triplet_counter_collection_types::const_view& tc_view,
6969
const device_triplet_collection_types::const_view& triplet_view,
70+
vecmem::data::vector_view<const scalar> highest_weight_view,
7071
device_triplet* data, edm::seed_collection::view seed_view) {
7172

7273
// Check if anything needs to be done.
@@ -85,6 +86,10 @@ inline void select_seeds(
8586
sp_view);
8687

8788
const device_triplet_collection_types::const_device triplets(triplet_view);
89+
90+
const vecmem::device_vector<const scalar> highest_weights(
91+
highest_weight_view);
92+
8893
edm::seed_collection::device seeds_device(seed_view);
8994

9095
// Current work item = middle spacepoint
@@ -111,6 +116,25 @@ inline void select_seeds(
111116
const edm::spacepoint_collection::const_device::const_proxy_type spT =
112117
spacepoints.at(spT_idx);
113118

119+
if (filter_config.minNumTimesHighestWeight > 0) {
120+
assert(highest_weights.capacity() > 0);
121+
unsigned int highest_for_n_spacepoints = 0;
122+
if (highest_weights.at(spB_idx) <= aTriplet.weight) {
123+
highest_for_n_spacepoints++;
124+
}
125+
if (highest_weights.at(spM_idx) <= aTriplet.weight) {
126+
highest_for_n_spacepoints++;
127+
}
128+
if (highest_weights.at(spT_idx) <= aTriplet.weight) {
129+
highest_for_n_spacepoints++;
130+
}
131+
132+
if (highest_for_n_spacepoints <
133+
filter_config.minNumTimesHighestWeight) {
134+
continue;
135+
}
136+
}
137+
114138
// update weight of triplet
115139
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,
116140
aTriplet.weight);

device/common/include/traccc/seeding/device/select_seeds.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ inline void select_seeds(
4242
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
4343
const triplet_counter_collection_types::const_view& tc_view,
4444
const device_triplet_collection_types::const_view& triplet_view,
45-
triplet* data, edm::seed_collection::view seed_view);
45+
vecmem::data::vector_view<const scalar> highest_weight_view, triplet* data,
46+
edm::seed_collection::view seed_view);
4647

4748
} // namespace traccc::device
4849

device/cuda/src/seeding/seed_finding.cu

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ __global__ void select_seeds(
134134
device::triplet_counter_spM_collection_types::const_view spM_tc,
135135
device::triplet_counter_collection_types::const_view midBot_tc,
136136
device::device_triplet_collection_types::view triplet_view,
137+
vecmem::data::vector_view<const scalar> highest_weight_view,
137138
edm::seed_collection::view seed_view) {
138139

139140
// Array for temporary storage of triplets for comparing within seed
@@ -145,9 +146,68 @@ __global__ void select_seeds(
145146

146147
device::select_seeds(details::global_index1(), finder_config, filter_config,
147148
spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
148-
dataPos, seed_view);
149+
highest_weight_view, dataPos, seed_view);
149150
}
150151

152+
/// Kernel recording, for every spacepoint, the highest weight of any triplet
/// that references the spacepoint as one of its constituents.
///
/// Each thread handles one (triplet, constituent) pair: the flat 1D thread
/// index divided by 3 selects the triplet, and the remainder selects which of
/// its spB (0), spM (1) or spT (2) entries is processed. Threads whose
/// triplet index is not below triplets.size() exit without doing anything.
///
/// @param highest_weight_view Per-spacepoint weights, indexed by the spB /
///        spM / spT values stored in the triplets. Updated in place with an
///        atomic maximum; entries are only ever raised, so the result is
///        only meaningful if every entry starts below all triplet weights.
/// @param sp_grid_view Spacepoint grid; wrapped in a device object below but
///        not otherwise read by this kernel.
/// @param triplet_view Triplets whose weights are folded into the maxima.
__global__ void collect_highest_weight_per_spacepoint(
153+
vecmem::data::vector_view<scalar> highest_weight_view,
154+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
155+
device::device_triplet_collection_types::const_view triplet_view) {
156+
vecmem::device_vector<scalar> highest_weights(highest_weight_view);
157+
// NOTE(review): sp_grid is constructed here but never used afterwards.
const traccc::details::spacepoint_grid_types::const_device sp_grid(
158+
sp_grid_view);
159+
const device::device_triplet_collection_types::const_device triplets(
160+
triplet_view);
161+
162+
// Flat thread index of a one-dimensional launch.
const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
163+
164+
// Three consecutive threads share a triplet, one per constituent.
const unsigned int seed_idx = tid / 3;
165+
const unsigned int local_spacepoint_idx = tid % 3;
166+
167+
// Surplus threads in the tail of the grid have no triplet to process.
if (seed_idx >= triplets.size()) {
168+
return;
169+
}
170+
171+
unsigned int global_spacepoint_idx;
172+
173+
// Translate the constituent slot (0, 1, 2) into the spacepoint index.
if (local_spacepoint_idx == 0) {
174+
global_spacepoint_idx = triplets.at(seed_idx).spB;
175+
} else if (local_spacepoint_idx == 1) {
176+
global_spacepoint_idx = triplets.at(seed_idx).spM;
177+
} else if (local_spacepoint_idx == 2) {
178+
global_spacepoint_idx = triplets.at(seed_idx).spT;
179+
} else {
180+
// tid % 3 can only be 0, 1 or 2, so this branch is never taken.
__builtin_unreachable();
181+
}
182+
183+
// The weight is compare-and-swapped through an unsigned integer of the
// same width as scalar, so only 4- and 8-byte scalars are supported.
static_assert(sizeof(scalar) == 4 || sizeof(scalar) == 8);
184+
using cas_type = std::conditional_t<sizeof(scalar) == 4, unsigned int,
185+
unsigned long long int>;
186+
187+
/*
188+
* Atomic maximum built from a compare-and-swap loop: read the stored
* weight, and while this triplet's weight is strictly larger, try to
* swap it in. The comparison is done on the decoded scalar values; the
* exchange itself operates on the raw bit patterns.
189+
*/
190+
scalar& weight_loc = highest_weights.at(global_spacepoint_idx);
191+
// View the scalar's storage as an equally sized integer for the CAS.
cas_type* weight_raw = reinterpret_cast<cas_type*>(&weight_loc);
192+
193+
vecmem::device_atomic_ref<cas_type> atomic(*weight_raw);
194+
195+
cas_type current_weight_raw = atomic.load();
196+
scalar current_weight = std::bit_cast<scalar>(current_weight_raw);
197+
const scalar own_weight = triplets.at(seed_idx).weight;
198+
const cas_type own_weight_raw = std::bit_cast<cas_type>(own_weight);
199+
200+
// Stops as soon as the stored weight is at least as large as our own.
while (own_weight > current_weight) {
201+
const bool res =
202+
atomic.compare_exchange_strong(current_weight_raw, own_weight_raw);
203+
204+
if (res) {
205+
// Our weight is now the stored maximum.
return;
206+
} else {
207+
// The exchange lost a race; re-decode current_weight_raw, which is
// presumably refreshed with the observed value on failure (standard
// compare-exchange semantics -- confirm for vecmem), and re-test.
current_weight = std::bit_cast<scalar>(current_weight_raw);
208+
}
209+
}
210+
}
151211
} // namespace kernels
152212

153213
namespace details {
@@ -184,6 +244,9 @@ edm::seed_collection::buffer seed_finding::operator()(
184244
return {0, m_mr.main};
185245
}
186246

247+
// Total number of spacepoints, including those not in the grid.
248+
const auto total_num_spacepoints = m_copy.get_size(spacepoints_view);
249+
187250
// Set up the doublet counter buffer.
188251
device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189252
num_spacepoints, m_mr.main, vecmem::data::buffer_type::resizable};
@@ -340,6 +403,32 @@ edm::seed_collection::buffer seed_finding::operator()(
340403
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341404
triplet_buffer);
342405
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
406+
TRACCC_CUDA_ERROR_CHECK(cudaStreamSynchronize(stream));
407+
408+
vecmem::data::vector_buffer<scalar> highest_weight_buffer(0, m_mr.main);
409+
410+
if (m_seedfilter_config.minNumTimesHighestWeight > 0) {
411+
highest_weight_buffer = {total_num_spacepoints, m_mr.main};
412+
413+
/*
414+
* NOTE: The byte 0b11100000 is chosen because, when repeated four
415+
* times to fill a 32-bit floating point value, it yields the bit pattern
416+
* 0b11100000111000001110000011100000, a hugely negative number (about
417+
* -1.3e20), far below any seed weight should ever be. When repeated eight
418+
* times to fill a 64-bit value it likewise yields a hugely negative number.
419+
*/
420+
m_copy.memset(highest_weight_buffer, 0b11100000)->wait();
421+
422+
const unsigned int num_votes = 3 * globalCounter_host->m_nTriplets;
423+
const unsigned int num_threads = 256;
424+
const unsigned int num_blocks =
425+
(num_votes + num_threads - 1) / num_threads;
426+
427+
kernels::collect_highest_weight_per_spacepoint<<<
428+
num_blocks, num_threads, 0, stream>>>(highest_weight_buffer,
429+
g2_view, triplet_buffer);
430+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
431+
}
343432

344433
// Create result object: collection of seeds
345434
edm::seed_collection::buffer seed_buffer(
@@ -362,7 +451,7 @@ edm::seed_collection::buffer seed_finding::operator()(
362451
stream>>>(
363452
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364453
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365-
triplet_buffer, seed_buffer);
454+
triplet_buffer, highest_weight_buffer, seed_buffer);
366455
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
367456

368457
return seed_buffer;

0 commit comments

Comments
 (0)