@@ -134,6 +134,7 @@ __global__ void select_seeds(
134134 device::triplet_counter_spM_collection_types::const_view spM_tc,
135135 device::triplet_counter_collection_types::const_view midBot_tc,
136136 device::device_triplet_collection_types::view triplet_view,
137+ vecmem::data::vector_view<const scalar> highest_weight_view,
137138 edm::seed_collection::view seed_view) {
138139
139140 // Array for temporary storage of triplets for comparing within seed
@@ -145,9 +146,68 @@ __global__ void select_seeds(
145146
146147 device::select_seeds (details::global_index1 (), finder_config, filter_config,
147148 spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
148- dataPos, seed_view);
149+ highest_weight_view, dataPos, seed_view);
149150}
150151
/// Kernel recording, for every spacepoint, the highest weight found among
/// the triplets that reference it (an atomic "max" per spacepoint).
///
/// Launch layout: one-dimensional grid and blocks. Three consecutive global
/// thread indices are assigned to each triplet, one for each of its bottom,
/// middle and top spacepoints, so at least 3 * (number of triplets) threads
/// are needed. Threads whose triplet index lies beyond the triplet
/// collection exit immediately. No shared memory is used.
///
/// @param highest_weight_view Per-spacepoint maxima, indexed by the `spB`,
///        `spM` and `spT` values stored in the triplets. Entries are only
///        ever raised by this kernel, so the caller must pre-fill the vector
///        with a value lower than any triplet weight.
/// @param sp_grid_view Spacepoint grid; not read by this kernel, kept so the
///        launch signature stays unchanged.
/// @param triplet_view Triplets whose weights are collected.
__global__ void collect_highest_weight_per_spacepoint(
    vecmem::data::vector_view<scalar> highest_weight_view,
    [[maybe_unused]] traccc::details::spacepoint_grid_types::const_view
        sp_grid_view,
    device::device_triplet_collection_types::const_view triplet_view) {

    vecmem::device_vector<scalar> highest_weights(highest_weight_view);
    const device::device_triplet_collection_types::const_device triplets(
        triplet_view);

    const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // Three threads cooperate on each triplet: one per spacepoint.
    const unsigned int seed_idx = tid / 3;
    const unsigned int local_spacepoint_idx = tid % 3;

    if (seed_idx >= triplets.size()) {
        return;
    }

    // Fetch the triplet once; it provides both the index and the weight.
    const auto& triplet = triplets.at(seed_idx);

    // local_spacepoint_idx is always 0, 1 or 2 (it is a remainder modulo 3).
    const unsigned int global_spacepoint_idx =
        (local_spacepoint_idx == 0)
            ? triplet.spB
            : ((local_spacepoint_idx == 1) ? triplet.spM : triplet.spT);

    // The compare-and-swap below works on an unsigned integer that has the
    // same width as the floating point weight.
    static_assert(sizeof(scalar) == 4 || sizeof(scalar) == 8);
    using cas_type = std::conditional_t<sizeof(scalar) == 4, unsigned int,
                                        unsigned long long int>;

    /*
     * The following is simply an implementation of atomic max: keep trying
     * to install our weight for as long as it is larger than the value most
     * recently observed in memory.
     */
    scalar& weight_loc = highest_weights.at(global_spacepoint_idx);
    cas_type* weight_raw = reinterpret_cast<cas_type*>(&weight_loc);

    vecmem::device_atomic_ref<cas_type> atomic(*weight_raw);

    const scalar own_weight = triplet.weight;
    const cas_type own_weight_raw = std::bit_cast<cas_type>(own_weight);

    cas_type observed_raw = atomic.load();

    // The comparison is done on the floating point values; only the exchange
    // itself operates on the raw bit patterns.
    while (own_weight > std::bit_cast<scalar>(observed_raw)) {
        // On failure the compare-exchange refreshes `observed_raw` with the
        // value currently stored, which the loop condition then re-examines.
        if (atomic.compare_exchange_strong(observed_raw, own_weight_raw)) {
            return;
        }
    }
}
151211} // namespace kernels
152212
153213namespace details {
@@ -184,6 +244,9 @@ edm::seed_collection::buffer seed_finding::operator()(
184244 return {0 , m_mr.main };
185245 }
186246
247+ // Total number of spacepoints, including those not in the grid.
248+ const auto total_num_spacepoints = m_copy.get_size (spacepoints_view);
249+
187250 // Set up the doublet counter buffer.
188251 device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189252 num_spacepoints, m_mr.main , vecmem::data::buffer_type::resizable};
@@ -340,6 +403,32 @@ edm::seed_collection::buffer seed_finding::operator()(
340403 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341404 triplet_buffer);
342405 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
406+ TRACCC_CUDA_ERROR_CHECK (cudaStreamSynchronize (stream));
407+
408+ vecmem::data::vector_buffer<scalar> highest_weight_buffer (0 , m_mr.main );
409+
410+ if (m_seedfilter_config.minNumTimesHighestWeight > 0 ) {
411+ highest_weight_buffer = {total_num_spacepoints, m_mr.main };
412+
413+ /*
414+ * NOTE: The byte 0b11100000 is chosen because, when repeated four times
415+ * to fill a 32-bit floating point value, it gives the bit pattern
416+ * 0b11100000111000001110000011100000, a large negative number (on the
417+ * order of -1e20) far below any seed weight that should ever occur. When
418+ * repeated eight times to fill a 64-bit value it is also hugely negative.
419+ */
420+ m_copy.memset (highest_weight_buffer, 0b11100000 )->wait ();
421+
422+ const unsigned int num_votes = 3 * globalCounter_host->m_nTriplets ;
423+ const unsigned int num_threads = 256 ;
424+ const unsigned int num_blocks =
425+ (num_votes + num_threads - 1 ) / num_threads;
426+
427+ kernels::collect_highest_weight_per_spacepoint<<<
428+ num_blocks, num_threads, 0 , stream>>> (highest_weight_buffer,
429+ g2_view, triplet_buffer);
430+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
431+ }
343432
344433 // Create result object: collection of seeds
345434 edm::seed_collection::buffer seed_buffer (
@@ -362,7 +451,7 @@ edm::seed_collection::buffer seed_finding::operator()(
362451 stream>>> (
363452 m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364453 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365- triplet_buffer, seed_buffer);
454+ triplet_buffer, highest_weight_buffer, seed_buffer);
366455 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
367456
368457 return seed_buffer;
0 commit comments