@@ -135,6 +135,7 @@ __global__ void select_seeds(
135135 device::triplet_counter_spM_collection_types::const_view spM_tc,
136136 device::triplet_counter_collection_types::const_view midBot_tc,
137137 device::device_triplet_collection_types::view triplet_view,
138+ vecmem::data::vector_view<const scalar> highest_weight_view,
138139 edm::seed_collection::view seed_view) {
139140
140141 // Array for temporary storage of triplets for comparing within seed
@@ -144,10 +145,72 @@ __global__ void select_seeds(
144145 triplet* dataPos = &data2[threadIdx .x * filter_config.max_triplets_per_spM ];
145146
146147 device::select_seeds (details::global_index1 (), filter_config, spacepoints,
147- sp_view, spM_tc, midBot_tc, triplet_view, dataPos,
148- seed_view);
148+ sp_view, spM_tc, midBot_tc, triplet_view,
149+ highest_weight_view, dataPos, seed_view);
149150}
150151
152+ __global__ void collect_highest_weight_per_spacepoint (
153+ vecmem::data::vector_view<scalar> highest_weight_view,
154+ traccc::details::spacepoint_grid_types::const_view sp_grid_view,
155+ device::device_triplet_collection_types::const_view triplet_view) {
156+ vecmem::device_vector<scalar> highest_weights (highest_weight_view);
157+ const traccc::details::spacepoint_grid_types::const_device sp_grid (
158+ sp_grid_view);
159+ const device::device_triplet_collection_types::const_device triplets (
160+ triplet_view);
161+
162+ const unsigned int tid = blockIdx .x * blockDim .x + threadIdx .x ;
163+
164+ const unsigned int seed_idx = tid / 3 ;
165+ const unsigned int local_spacepoint_idx = tid % 3 ;
166+
167+ if (seed_idx >= triplets.size ()) {
168+ return ;
169+ }
170+
171+ sp_location global_spacepoint_loc;
172+
173+ if (local_spacepoint_idx == 0 ) {
174+ global_spacepoint_loc = triplets.at (seed_idx).spB ;
175+ } else if (local_spacepoint_idx == 1 ) {
176+ global_spacepoint_loc = triplets.at (seed_idx).spM ;
177+ } else if (local_spacepoint_idx == 2 ) {
178+ global_spacepoint_loc = triplets.at (seed_idx).spT ;
179+ } else {
180+ __builtin_unreachable ();
181+ }
182+
183+ const unsigned int global_spacepoint_idx = sp_grid.bin (
184+ global_spacepoint_loc.bin_idx )[global_spacepoint_loc.sp_idx ];
185+
186+ static_assert (sizeof (scalar) == 4 || sizeof (scalar) == 8 );
187+ using cas_type = std::conditional_t <sizeof (scalar) == 4 , unsigned int ,
188+ unsigned long long int >;
189+
190+ /*
191+ * The following is simply an implementation of atomic max.
192+ */
193+ scalar& weight_loc = highest_weights.at (global_spacepoint_idx);
194+ cas_type* weight_raw = reinterpret_cast <cas_type*>(&weight_loc);
195+
196+ vecmem::device_atomic_ref<cas_type> atomic (*weight_raw);
197+
198+ cas_type current_weight_raw = atomic.load ();
199+ scalar current_weight = std::bit_cast<scalar>(current_weight_raw);
200+ const scalar own_weight = triplets.at (seed_idx).weight ;
201+ const cas_type own_weight_raw = std::bit_cast<cas_type>(own_weight);
202+
203+ while (own_weight > current_weight) {
204+ const bool res =
205+ atomic.compare_exchange_strong (current_weight_raw, own_weight_raw);
206+
207+ if (res) {
208+ return ;
209+ } else {
210+ current_weight = std::bit_cast<scalar>(current_weight_raw);
211+ }
212+ }
213+ }
151214} // namespace kernels
152215
153216namespace details {
@@ -184,6 +247,9 @@ edm::seed_collection::buffer seed_finding::operator()(
184247 return {0 , m_mr.main };
185248 }
186249
250+ // Total number of spacepoints, including those not in the grid.
251+ const auto total_num_spacepoints = m_copy.get_size (spacepoints_view);
252+
187253 // Set up the doublet counter buffer.
188254 device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189255 num_spacepoints, m_mr.main , vecmem::data::buffer_type::resizable};
@@ -340,6 +406,32 @@ edm::seed_collection::buffer seed_finding::operator()(
340406 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341407 triplet_buffer);
342408 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
409+ TRACCC_CUDA_ERROR_CHECK (cudaStreamSynchronize (stream));
410+
411+ vecmem::data::vector_buffer<scalar> highest_weight_buffer (0 , m_mr.main );
412+
413+ if (m_seedfilter_config.minNumTimesHighestWeight > 0 ) {
414+ highest_weight_buffer = {total_num_spacepoints, m_mr.main };
415+
416+ /*
417+ * NOTE: The value 0b11100000 is chosen because it, when repeated four
418+ * times to create a 32-bit floating point value, evaluates to the
419+ * number 0b11100000111000001110000011100000 which is very small, much
420+ * smaller than any seed weight should ever be. When broadcast eight
421+ * times to a 64-bit number it also produces an incredibly small value.
422+ */
423+ m_copy.memset (highest_weight_buffer, 0b11100000 )->wait ();
424+
425+ const unsigned int num_votes = 3 * globalCounter_host->m_nTriplets ;
426+ const unsigned int num_threads = 256 ;
427+ const unsigned int num_blocks =
428+ (num_votes + num_threads - 1 ) / num_threads;
429+
430+ kernels::collect_highest_weight_per_spacepoint<<<
431+ num_blocks, num_threads, 0 , stream>>> (highest_weight_buffer,
432+ g2_view, triplet_buffer);
433+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
434+ }
343435
344436 // Create result object: collection of seeds
345437 edm::seed_collection::buffer seed_buffer (
@@ -359,10 +451,10 @@ edm::seed_collection::buffer seed_finding::operator()(
359451 sizeof (triplet) *
360452 m_seedfilter_config.max_triplets_per_spM *
361453 nSeedSelectingThreads,
362- stream>>> (m_seedfilter_config, spacepoints_view,
363- g2_view, triplet_counter_spM_buffer ,
364- triplet_counter_midBot_buffer,
365- triplet_buffer, seed_buffer);
454+ stream>>> (
455+ m_seedfilter_config, spacepoints_view, g2_view ,
456+ triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
457+ triplet_buffer, highest_weight_buffer , seed_buffer);
366458 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
367459
368460 return seed_buffer;
0 commit comments