2020#include " traccc/edm/device/doublet_counter.hpp"
2121#include " traccc/edm/device/seeding_global_counter.hpp"
2222#include " traccc/edm/device/triplet_counter.hpp"
23+ #include " traccc/seeding/detail/singlet.hpp"
2324#include " traccc/seeding/device/count_doublets.hpp"
2425#include " traccc/seeding/device/count_triplets.hpp"
2526#include " traccc/seeding/device/find_doublets.hpp"
3738
3839namespace traccc ::cuda {
3940namespace kernels {
41+ namespace detail {
42+ /* *
43+ * @brief Encode the state of our parameter insertion mutex.
44+ */
45+ TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex (const bool locked,
46+ const uint32_t size,
47+ const float max) {
48+
49+ // Assert that the MSB of the size is zero
50+ assert (size <= 0x7FFFFFFF );
51+ const uint32_t hi = size | (locked ? 0x80000000 : 0x0 );
52+ const uint32_t lo = std::bit_cast<uint32_t >(max);
53+ return (static_cast <uint64_t >(hi) << 32 ) | lo;
54+ }
55+
56+ /* *
57+ * @brief Decode the state of our parameter insertion mutex.
58+ */
59+ TRACCC_HOST_DEVICE inline std::tuple<bool , uint32_t , float >
60+ decode_insertion_mutex (const uint64_t val) {
61+ const uint32_t hi = static_cast <uint32_t >(val >> 32 );
62+ const uint32_t lo = val & 0xFFFFFFFF ;
63+ return {static_cast <bool >(hi & 0x80000000 ), (hi & 0x7FFFFFFF ),
64+ std::bit_cast<float >(lo)};
65+ }
66+ } // namespace detail
67+
68+ __global__ inline void gather_best_triplets_per_spacepoint (
69+ const device::device_triplet_collection_types::const_view triplet_view,
70+ const traccc::details::spacepoint_grid_types::const_view grid_view,
71+ vecmem::data::vector_view<unsigned long long int > insertion_mutex_view,
72+ vecmem::data::vector_view<unsigned int > triplet_index_view,
73+ vecmem::data::vector_view<float > triplet_weight_view,
74+ const unsigned int max_num_triplets_per_spacepoint) {
75+
76+ const device::device_triplet_collection_types::const_device triplets (
77+ triplet_view);
78+ const traccc::details::spacepoint_grid_types::const_device grid (grid_view);
79+
80+ vecmem::device_vector<unsigned long long int > insertion_mutex (
81+ insertion_mutex_view);
82+ vecmem::device_vector<unsigned int > triplet_index (triplet_index_view);
83+ vecmem::device_vector<float > triplet_weight (triplet_weight_view);
84+
85+ unsigned int triplet_idx = blockIdx .x * blockDim .x + threadIdx .x ;
86+
87+ scalar weight;
88+ bool need_to_write = true ;
89+ unsigned int current_state = 0 ;
90+ unsigned int spacepoint_idx = 0 ;
91+
92+ auto get_idx_for_state = [&triplets, &grid,
93+ triplet_idx](unsigned int state) {
94+ const device::device_triplet& this_triplet = triplets.at (triplet_idx);
95+ if (state == 0 ) {
96+ return this_triplet.spB ;
97+ } else if (state == 1 ) {
98+ return this_triplet.spM ;
99+ } else {
100+ return this_triplet.spT ;
101+ }
102+ };
103+
104+ if (triplet_idx < triplets.size ()) {
105+ weight = triplets.at (triplet_idx).weight ;
106+ spacepoint_idx = get_idx_for_state (0 );
107+ } else {
108+ need_to_write = false ;
109+ current_state = 3 ;
110+ }
111+
112+ while (__syncthreads_or (current_state < 3 || need_to_write)) {
113+ if (current_state < 3 || need_to_write) {
114+ if (need_to_write) {
115+ vecmem::device_atomic_ref<unsigned long long int > mutex (
116+ insertion_mutex.at (spacepoint_idx));
117+
118+ unsigned long long int assumed = mutex.load ();
119+ unsigned long long int desired_set;
120+
121+ auto [locked, size, worst] =
122+ detail::decode_insertion_mutex (assumed);
123+
124+ if (need_to_write && size >= max_num_triplets_per_spacepoint &&
125+ weight <= worst) {
126+ need_to_write = false ;
127+ }
128+
129+ bool holds_lock = false ;
130+
131+ if (need_to_write && !locked) {
132+ desired_set =
133+ detail::encode_insertion_mutex (true , size, worst);
134+ if (mutex.compare_exchange_strong (assumed, desired_set)) {
135+ holds_lock = true ;
136+ }
137+ }
138+
139+ if (holds_lock) {
140+ unsigned int new_size;
141+ unsigned int offset =
142+ spacepoint_idx * max_num_triplets_per_spacepoint;
143+ unsigned int out_idx;
144+
145+ if (size == max_num_triplets_per_spacepoint) {
146+ new_size = size;
147+ scalar worst_weight =
148+ std::numeric_limits<scalar>::max ();
149+ for (unsigned int i = 0 ; i < size; ++i) {
150+ if (triplet_weight.at (offset + i) < worst_weight) {
151+ worst_weight = triplet_weight.at (offset + i);
152+ out_idx = i;
153+ }
154+ }
155+ } else {
156+ new_size = size + 1 ;
157+ out_idx = size;
158+ }
159+
160+ triplet_index.at (offset + out_idx) = triplet_idx;
161+ triplet_weight.at (offset + out_idx) = weight;
162+
163+ scalar new_worst = std::numeric_limits<scalar>::max ();
164+
165+ for (unsigned int i = 0 ; i < new_size; ++i) {
166+ new_worst =
167+ std::min (new_worst, triplet_weight.at (offset + i));
168+ }
169+
170+ [[maybe_unused]] bool cas_result =
171+ mutex.compare_exchange_strong (
172+ desired_set, detail::encode_insertion_mutex (
173+ false , new_size, new_worst));
174+ assert (cas_result);
175+
176+ need_to_write = false ;
177+ }
178+ }
179+
180+ if (!need_to_write && current_state < 3 ) {
181+ if (current_state < 2 ) {
182+ spacepoint_idx = get_idx_for_state (++current_state);
183+ need_to_write = true ;
184+ } else {
185+ ++current_state;
186+ }
187+ }
188+ }
189+ }
190+ }
191+
192+ __global__ inline void gather_spacepoint_votes (
193+ const vecmem::data::vector_view<const unsigned long long int >
194+ insertion_mutex_view,
195+ const vecmem::data::vector_view<const unsigned int > triplet_index_view,
196+ vecmem::data::vector_view<unsigned int > votes_per_triplet_view,
197+ const unsigned int max_num_triplets_per_spacepoint) {
198+
199+ unsigned int thread_idx = blockIdx .x * blockDim .x + threadIdx .x ;
200+ unsigned int spacepoint_idx = thread_idx / max_num_triplets_per_spacepoint;
201+ unsigned int triplet_idx = thread_idx % max_num_triplets_per_spacepoint;
202+
203+ const vecmem::device_vector<const unsigned long long int > insertion_mutex (
204+ insertion_mutex_view);
205+ const vecmem::device_vector<const unsigned int > triplet_index (
206+ triplet_index_view);
207+ vecmem::device_vector<unsigned int > votes_per_triplet (
208+ votes_per_triplet_view);
209+
210+ if (spacepoint_idx >= insertion_mutex.size ()) {
211+ return ;
212+ }
213+
214+ auto [locked, size, worst] =
215+ detail::decode_insertion_mutex (insertion_mutex.at (spacepoint_idx));
216+
217+ if (size == 0 ) {
218+ return ;
219+ }
220+
221+ unsigned int num_votes =
222+ (size + max_num_triplets_per_spacepoint - 1 ) / size;
223+
224+ if (triplet_idx < size) {
225+ vecmem::device_atomic_ref<unsigned int >(
226+ votes_per_triplet.at (triplet_index.at (thread_idx)))
227+ .fetch_add (num_votes);
228+ }
229+ }
40230
41231// / CUDA kernel for running @c traccc::device::count_doublets
42232__global__ void count_doublets (
@@ -134,6 +324,7 @@ __global__ void select_seeds(
134324 device::triplet_counter_spM_collection_types::const_view spM_tc,
135325 device::triplet_counter_collection_types::const_view midBot_tc,
136326 device::device_triplet_collection_types::view triplet_view,
327+ const vecmem::data::vector_view<const unsigned int > votes_per_triplet_view,
137328 edm::seed_collection::view seed_view) {
138329
139330 // Array for temporary storage of triplets for comparing within seed
@@ -145,7 +336,7 @@ __global__ void select_seeds(
145336
146337 device::select_seeds (details::global_index1 (), finder_config, filter_config,
147338 spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
148- dataPos, seed_view);
339+ dataPos, votes_per_triplet_view, seed_view);
149340}
150341
151342} // namespace kernels
@@ -184,6 +375,10 @@ edm::seed_collection::buffer seed_finding::operator()(
184375 return {0 , m_mr.main };
185376 }
186377
378+ // Number of spacepoints including those not in the grid.
379+ const unsigned int num_spacepoints_total =
380+ m_copy.get_size (spacepoints_view);
381+
187382 // Set up the doublet counter buffer.
188383 device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189384 num_spacepoints, m_mr.main , vecmem::data::buffer_type::resizable};
@@ -341,6 +536,64 @@ edm::seed_collection::buffer seed_finding::operator()(
341536 triplet_buffer);
342537 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
343538
539+ vecmem::data::vector_buffer<unsigned int > votes_per_triplet_buffer (
540+ globalCounter_host->m_nTriplets , m_mr.main );
541+ m_copy.setup (votes_per_triplet_buffer)->wait ();
542+ m_copy.memset (votes_per_triplet_buffer, 0 )->wait ();
543+
544+ {
545+ unsigned int num_votes_per_sp = 3 ;
546+
547+ vecmem::data::vector_buffer<unsigned int >
548+ best_triplets_per_spacepoint_index_buffer (
549+ num_votes_per_sp * num_spacepoints_total, m_mr.main );
550+ m_copy.setup (best_triplets_per_spacepoint_index_buffer)->wait ();
551+
552+ vecmem::data::vector_buffer<unsigned long long int >
553+ best_triplets_per_spacepoint_insertion_mutex_buffer (
554+ num_spacepoints_total, m_mr.main );
555+ m_copy.setup (best_triplets_per_spacepoint_insertion_mutex_buffer)
556+ ->wait ();
557+ m_copy.memset (best_triplets_per_spacepoint_insertion_mutex_buffer, 0 )
558+ ->wait ();
559+
560+ {
561+ vecmem::data::vector_buffer<float >
562+ best_triplets_per_spacepoint_weight_buffer (
563+ num_votes_per_sp * num_spacepoints_total, m_mr.main );
564+ m_copy.setup (best_triplets_per_spacepoint_weight_buffer)->wait ();
565+
566+ unsigned int num_threads = 64 ;
567+ unsigned int num_blocks =
568+ (globalCounter_host->m_nTriplets + num_threads - 1 ) /
569+ num_threads;
570+
571+ kernels::gather_best_triplets_per_spacepoint<<<
572+ num_blocks, num_threads, 0 , stream>>> (
573+ triplet_buffer, g2_view,
574+ best_triplets_per_spacepoint_insertion_mutex_buffer,
575+ best_triplets_per_spacepoint_index_buffer,
576+ best_triplets_per_spacepoint_weight_buffer, num_votes_per_sp);
577+
578+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
579+ }
580+
581+ {
582+ unsigned int num_threads = 512 ;
583+ unsigned int num_blocks =
584+ (num_votes_per_sp * num_spacepoints_total + num_threads - 1 ) /
585+ num_threads;
586+
587+ kernels::
588+ gather_spacepoint_votes<<<num_blocks, num_threads, 0 , stream>>> (
589+ best_triplets_per_spacepoint_insertion_mutex_buffer,
590+ best_triplets_per_spacepoint_index_buffer,
591+ votes_per_triplet_buffer, num_votes_per_sp);
592+
593+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
594+ }
595+ }
596+
344597 // Create result object: collection of seeds
345598 edm::seed_collection::buffer seed_buffer (
346599 globalCounter_host->m_nTriplets , m_mr.main ,
@@ -362,7 +615,7 @@ edm::seed_collection::buffer seed_finding::operator()(
362615 stream>>> (
363616 m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364617 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365- triplet_buffer, seed_buffer);
618+ triplet_buffer, votes_per_triplet_buffer, seed_buffer);
366619 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
367620
368621 return seed_buffer;
0 commit comments