@@ -134,6 +134,7 @@ __global__ void select_seeds(
134134 device::triplet_counter_spM_collection_types::const_view spM_tc,
135135 device::triplet_counter_collection_types::const_view midBot_tc,
136136 device::device_triplet_collection_types::view triplet_view,
137+ vecmem::data::vector_view<const scalar> highest_weight_view,
137138 edm::seed_collection::view seed_view) {
138139
139140 // Array for temporary storage of triplets for comparing within seed
@@ -145,9 +146,65 @@ __global__ void select_seeds(
145146
146147 device::select_seeds (details::global_index1 (), finder_config, filter_config,
147148 spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
148- dataPos, seed_view);
149+ highest_weight_view, dataPos, seed_view);
149150}
150151
152+ __global__ void collect_highest_weight_per_spacepoint (
153+ vecmem::data::vector_view<scalar> highest_weight_view,
154+ device::device_triplet_collection_types::const_view triplet_view) {
155+ vecmem::device_vector<scalar> highest_weights (highest_weight_view);
156+ const device::device_triplet_collection_types::const_device triplets (
157+ triplet_view);
158+
159+ const unsigned int tid = blockIdx .x * blockDim .x + threadIdx .x ;
160+
161+ const unsigned int seed_idx = tid / 3 ;
162+ const unsigned int local_spacepoint_idx = tid % 3 ;
163+
164+ if (seed_idx >= triplets.size ()) {
165+ return ;
166+ }
167+
168+ unsigned int global_spacepoint_idx;
169+
170+ if (local_spacepoint_idx == 0 ) {
171+ global_spacepoint_idx = triplets.at (seed_idx).spB ;
172+ } else if (local_spacepoint_idx == 1 ) {
173+ global_spacepoint_idx = triplets.at (seed_idx).spM ;
174+ } else if (local_spacepoint_idx == 2 ) {
175+ global_spacepoint_idx = triplets.at (seed_idx).spT ;
176+ } else {
177+ __builtin_unreachable ();
178+ }
179+
180+ static_assert (sizeof (scalar) == 4 || sizeof (scalar) == 8 );
181+ using cas_type = std::conditional_t <sizeof (scalar) == 4 , unsigned int ,
182+ unsigned long long int >;
183+
184+ /*
185+ * The following is simply an implementation of atomic max.
186+ */
187+ scalar& weight_loc = highest_weights.at (global_spacepoint_idx);
188+ cas_type* weight_raw = reinterpret_cast <cas_type*>(&weight_loc);
189+
190+ vecmem::device_atomic_ref<cas_type> atomic (*weight_raw);
191+
192+ cas_type current_weight_raw = atomic.load ();
193+ scalar current_weight = std::bit_cast<scalar>(current_weight_raw);
194+ const scalar own_weight = triplets.at (seed_idx).weight ;
195+ const cas_type own_weight_raw = std::bit_cast<cas_type>(own_weight);
196+
197+ while (own_weight > current_weight) {
198+ const bool res =
199+ atomic.compare_exchange_strong (current_weight_raw, own_weight_raw);
200+
201+ if (res) {
202+ return ;
203+ } else {
204+ current_weight = std::bit_cast<scalar>(current_weight_raw);
205+ }
206+ }
207+ }
151208} // namespace kernels
152209
153210namespace details {
@@ -184,6 +241,9 @@ edm::seed_collection::buffer seed_finding::operator()(
184241 return {0 , m_mr.main };
185242 }
186243
244+ // Total number of spacepoints, including those not in the grid.
245+ const auto total_num_spacepoints = m_copy.get_size (spacepoints_view);
246+
187247 // Set up the doublet counter buffer.
188248 device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189249 num_spacepoints, m_mr.main , vecmem::data::buffer_type::resizable};
@@ -340,6 +400,32 @@ edm::seed_collection::buffer seed_finding::operator()(
340400 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341401 triplet_buffer);
342402 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
403+ TRACCC_CUDA_ERROR_CHECK (cudaStreamSynchronize (stream));
404+
405+ vecmem::data::vector_buffer<scalar> highest_weight_buffer (0 , m_mr.main );
406+
407+ if (m_seedfilter_config.minNumTimesWeightCompatible > 0 ) {
408+ highest_weight_buffer = {total_num_spacepoints, m_mr.main };
409+
410+ /*
411+ * NOTE: The value 0b11100000 is chosen because it, when repeated four
412+ * times to create a 32-bit floating point value, evaluates to the
413+ * number 0b11100000111000001110000011100000 which is very small, much
414+ * smaller than any seed weight should ever be. When broadcast eight
415+ * times to a 64-bit number it also produces an incredibly small value.
416+ */
417+ m_copy.memset (highest_weight_buffer, 0b11100000 )->wait ();
418+
419+ const unsigned int num_votes = 3 * globalCounter_host->m_nTriplets ;
420+ const unsigned int num_threads = 256 ;
421+ const unsigned int num_blocks =
422+ (num_votes + num_threads - 1 ) / num_threads;
423+
424+ kernels::collect_highest_weight_per_spacepoint<<<
425+ num_blocks, num_threads, 0 , stream>>> (highest_weight_buffer,
426+ triplet_buffer);
427+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
428+ }
343429
344430 // Create result object: collection of seeds
345431 edm::seed_collection::buffer seed_buffer (
@@ -362,7 +448,7 @@ edm::seed_collection::buffer seed_finding::operator()(
362448 stream>>> (
363449 m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364450 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365- triplet_buffer, seed_buffer);
451+ triplet_buffer, highest_weight_buffer, seed_buffer);
366452 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
367453
368454 return seed_buffer;
0 commit comments