Skip to content

Commit 6d9afa3

Browse files
committed
Implement a best-for-n-spacepoints cut on seeds
This PR implements a cut that requires a seed to be the best seed for at least $n$ of its constituent spacepoints, which mirrors the cut that is used in ACTS. This cut is a specialisation of the cut in #1082, implemented in a way that should allow it to run faster.
1 parent f84a267 commit 6d9afa3

File tree

4 files changed

+132
-3
lines changed

4 files changed

+132
-3
lines changed

core/include/traccc/seeding/detail/seeding_config.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,13 @@ struct seedfilter_config {
203203
// how often do you want to increase the weight of a seed for finding a
204204
// compatible seed?
205205
size_t compatSeedLimit = 2;
206+
// Number of constituent spacepoints for which a seed must be the best
207+
// seed, where "best" is defined by the weight. Valid values are 0, 1, 2
208+
// and 3, where 0 disables the cut.
209+
unsigned int minNumTimesWeightCompatible = 1;
210+
// The minimum relative weight for a seed to be considered compatible with
211+
// one of its spacepoints.
212+
float minWeightCompatibilityFraction = 0.6f;
206213

207214
// seed weight increase
208215
float good_spB_min_radius = 150.f * unit<float>::mm;

device/common/include/traccc/seeding/device/impl/select_seeds.ipp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ inline void select_seeds(
6767
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
6868
const triplet_counter_collection_types::const_view& tc_view,
6969
const device_triplet_collection_types::const_view& triplet_view,
70+
vecmem::data::vector_view<const scalar> highest_weight_view,
7071
device_triplet* data, edm::seed_collection::view seed_view) {
7172

7273
// Check if anything needs to be done.
@@ -85,6 +86,10 @@ inline void select_seeds(
8586
sp_view);
8687

8788
const device_triplet_collection_types::const_device triplets(triplet_view);
89+
90+
const vecmem::device_vector<const scalar> highest_weights(
91+
highest_weight_view);
92+
8893
edm::seed_collection::device seeds_device(seed_view);
8994

9095
// Current work item = middle spacepoint
@@ -111,6 +116,36 @@ inline void select_seeds(
111116
const edm::spacepoint_collection::const_device::const_proxy_type spT =
112117
spacepoints.at(spT_idx);
113118

119+
if (filter_config.minNumTimesWeightCompatible > 0) {
120+
assert(highest_weights.capacity() > 0);
121+
assert(aTriplet.weight >= 0.f);
122+
assert(highest_weights.at(spB_idx) >= 0.f);
123+
assert(highest_weights.at(spM_idx) >= 0.f);
124+
assert(highest_weights.at(spT_idx) >= 0.f);
125+
126+
unsigned int highest_for_n_spacepoints = 0;
127+
if (aTriplet.weight >=
128+
filter_config.minWeightCompatibilityFraction *
129+
highest_weights.at(spB_idx)) {
130+
highest_for_n_spacepoints++;
131+
}
132+
if (aTriplet.weight >=
133+
filter_config.minWeightCompatibilityFraction *
134+
highest_weights.at(spM_idx)) {
135+
highest_for_n_spacepoints++;
136+
}
137+
if (aTriplet.weight >=
138+
filter_config.minWeightCompatibilityFraction *
139+
highest_weights.at(spT_idx)) {
140+
highest_for_n_spacepoints++;
141+
}
142+
143+
if (highest_for_n_spacepoints <
144+
filter_config.minNumTimesWeightCompatible) {
145+
continue;
146+
}
147+
}
148+
114149
// update weight of triplet
115150
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,
116151
aTriplet.weight);

device/common/include/traccc/seeding/device/select_seeds.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ inline void select_seeds(
4242
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
4343
const triplet_counter_collection_types::const_view& tc_view,
4444
const device_triplet_collection_types::const_view& triplet_view,
45-
triplet* data, edm::seed_collection::view seed_view);
45+
vecmem::data::vector_view<const scalar> highest_weight_view, triplet* data,
46+
edm::seed_collection::view seed_view);
4647

4748
} // namespace traccc::device
4849

device/cuda/src/seeding/seed_finding.cu

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ __global__ void select_seeds(
134134
device::triplet_counter_spM_collection_types::const_view spM_tc,
135135
device::triplet_counter_collection_types::const_view midBot_tc,
136136
device::device_triplet_collection_types::view triplet_view,
137+
vecmem::data::vector_view<const scalar> highest_weight_view,
137138
edm::seed_collection::view seed_view) {
138139

139140
// Array for temporary storage of triplets for comparing within seed
@@ -145,9 +146,65 @@ __global__ void select_seeds(
145146

146147
device::select_seeds(details::global_index1(), finder_config, filter_config,
147148
spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
148-
dataPos, seed_view);
149+
highest_weight_view, dataPos, seed_view);
149150
}
150151

152+
// Kernel that records, for every spacepoint, the largest weight found among
// the triplets that reference it.
//
// Each thread handles one (triplet, spacepoint slot) pair: thread `tid` works
// on triplet `tid / 3` and on slot `tid % 3` of that triplet (0 = spB,
// 1 = spM, 2 = spT). It raises the entry of `highest_weight_view` at that
// spacepoint's index to the triplet's weight if the stored value is lower.
//
// @param highest_weight_view Per-spacepoint running maximum, indexed by the
//        spB/spM/spT values stored in the triplets. Entries are only ever
//        raised here, so the caller has to pre-fill them with a value below
//        every possible weight.
// @param triplet_view        Triplets whose weights are folded into the
//        maximum.
__global__ void collect_highest_weight_per_spacepoint(
153+
vecmem::data::vector_view<scalar> highest_weight_view,
154+
device::device_triplet_collection_types::const_view triplet_view) {
155+
vecmem::device_vector<scalar> highest_weights(highest_weight_view);
156+
const device::device_triplet_collection_types::const_device triplets(
157+
triplet_view);
158+
159+
// Flat one-dimensional thread index.
const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
160+
161+
// Three consecutive threads share one triplet, one per constituent
// spacepoint.
const unsigned int seed_idx = tid / 3;
162+
const unsigned int local_spacepoint_idx = tid % 3;
163+
164+
// Threads past the last triplet have nothing to do.
if (seed_idx >= triplets.size()) {
165+
return;
166+
}
167+
168+
unsigned int global_spacepoint_idx;
169+
170+
// Translate the slot within the triplet into the spacepoint index stored
// in the triplet.
if (local_spacepoint_idx == 0) {
171+
global_spacepoint_idx = triplets.at(seed_idx).spB;
172+
} else if (local_spacepoint_idx == 1) {
173+
global_spacepoint_idx = triplets.at(seed_idx).spM;
174+
} else if (local_spacepoint_idx == 2) {
175+
global_spacepoint_idx = triplets.at(seed_idx).spT;
176+
} else {
177+
// tid % 3 is always 0, 1 or 2, so this branch cannot be taken.
__builtin_unreachable();
178+
}
179+
180+
// Unsigned integer type of the same size as scalar, so that the
// compare-and-swap below can operate on the raw bits of the weight.
static_assert(sizeof(scalar) == 4 || sizeof(scalar) == 8);
181+
using cas_type = std::conditional_t<sizeof(scalar) == 4, unsigned int,
182+
unsigned long long int>;
183+
184+
/*
185+
* The following is simply an implementation of atomic max.
186+
*/
187+
// View the stored weight as raw bits for the atomic operations.
scalar& weight_loc = highest_weights.at(global_spacepoint_idx);
188+
cas_type* weight_raw = reinterpret_cast<cas_type*>(&weight_loc);
189+
190+
vecmem::device_atomic_ref<cas_type> atomic(*weight_raw);
191+
192+
// Snapshot of the stored maximum, as raw bits and decoded as a scalar,
// together with this triplet's weight in both forms.
cas_type current_weight_raw = atomic.load();
193+
scalar current_weight = std::bit_cast<scalar>(current_weight_raw);
194+
const scalar own_weight = triplets.at(seed_idx).weight;
195+
const cas_type own_weight_raw = std::bit_cast<cas_type>(own_weight);
196+
197+
// The comparison is done on the floating point values; the exchange is
// done on the bit patterns. Retry until the swap succeeds or the stored
// weight is no longer below this triplet's weight.
while (own_weight > current_weight) {
198+
const bool res =
199+
atomic.compare_exchange_strong(current_weight_raw, own_weight_raw);
200+
201+
if (res) {
202+
return;
203+
} else {
204+
// NOTE(review): this relies on compare_exchange_strong writing the
// value it found back into current_weight_raw on failure, as
// std::atomic_ref does - confirm for vecmem::device_atomic_ref.
current_weight = std::bit_cast<scalar>(current_weight_raw);
205+
}
206+
}
207+
}
151208
} // namespace kernels
152209

153210
namespace details {
@@ -184,6 +241,9 @@ edm::seed_collection::buffer seed_finding::operator()(
184241
return {0, m_mr.main};
185242
}
186243

244+
// Total number of spacepoints, including those not in the grid.
245+
const auto total_num_spacepoints = m_copy.get_size(spacepoints_view);
246+
187247
// Set up the doublet counter buffer.
188248
device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189249
num_spacepoints, m_mr.main, vecmem::data::buffer_type::resizable};
@@ -340,6 +400,32 @@ edm::seed_collection::buffer seed_finding::operator()(
340400
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341401
triplet_buffer);
342402
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
403+
TRACCC_CUDA_ERROR_CHECK(cudaStreamSynchronize(stream));
404+
405+
vecmem::data::vector_buffer<scalar> highest_weight_buffer(0, m_mr.main);
406+
407+
if (m_seedfilter_config.minNumTimesWeightCompatible > 0) {
408+
highest_weight_buffer = {total_num_spacepoints, m_mr.main};
409+
410+
/*
411+
* NOTE: The value 0b11100000 is chosen because, when repeated four
412+
* times to fill a 32-bit floating point value, it yields the bit
413+
* pattern 0b11100000111000001110000011100000 (about -1.3e20), which is
414+
* far below any seed weight that should ever occur. When repeated eight
415+
* times to fill a 64-bit value it likewise yields a hugely negative number.
416+
*/
417+
m_copy.memset(highest_weight_buffer, 0b11100000)->wait();
418+
419+
const unsigned int num_votes = 3 * globalCounter_host->m_nTriplets;
420+
const unsigned int num_threads = 256;
421+
const unsigned int num_blocks =
422+
(num_votes + num_threads - 1) / num_threads;
423+
424+
kernels::collect_highest_weight_per_spacepoint<<<
425+
num_blocks, num_threads, 0, stream>>>(highest_weight_buffer,
426+
triplet_buffer);
427+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
428+
}
343429

344430
// Create result object: collection of seeds
345431
edm::seed_collection::buffer seed_buffer(
@@ -362,7 +448,7 @@ edm::seed_collection::buffer seed_finding::operator()(
362448
stream>>>(
363449
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364450
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365-
triplet_buffer, seed_buffer);
451+
triplet_buffer, highest_weight_buffer, seed_buffer);
366452
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
367453

368454
return seed_buffer;

0 commit comments

Comments
 (0)