2323#include " traccc/seeding/device/count_doublets.hpp"
2424#include " traccc/seeding/device/count_triplets.hpp"
2525#include " traccc/seeding/device/find_doublets.hpp"
26+ #include " traccc/seeding/device/find_triplet_confirmations.hpp"
2627#include " traccc/seeding/device/find_triplets.hpp"
2728#include " traccc/seeding/device/reduce_triplet_counts.hpp"
2829#include " traccc/seeding/device/select_seeds.hpp"
29- #include " traccc/seeding/device/update_triplet_weights.hpp"
3030
3131// VecMem include(s).
3232#include < vecmem/utils/cuda/copy.hpp>
@@ -108,22 +108,23 @@ __global__ void find_triplets(
108108}
109109
110110// / CUDA kernel for running @c traccc::device::update_triplet_weights
111- __global__ void update_triplet_weights (
111+ __global__ void find_triplet_confirmations (
112112 seedfilter_config filter_config,
113113 edm::spacepoint_collection::const_view spacepoints,
114114 device::triplet_counter_spM_collection_types::const_view spM_tc,
115115 device::triplet_counter_collection_types::const_view midBot_tc,
116- device::device_triplet_collection_types::view triplet_view) {
116+ device::device_triplet_collection_types::view triplet_view,
117+ vecmem::data::vector_view<unsigned int > num_confirmations_view) {
117118
118119 // Array for temporary storage of quality parameters for comparing triplets
119120 // within weight updating kernel
120121 extern __shared__ scalar data[];
121122 // Each thread uses compatSeedLimit elements of the array
122123 scalar* dataPos = &data[threadIdx .x * filter_config.compatSeedLimit ];
123124
124- device::update_triplet_weights (details::global_index1 (), filter_config,
125- spacepoints, spM_tc, midBot_tc, dataPos,
126- triplet_view);
125+ device::find_triplet_confirmations (details::global_index1 (), filter_config,
126+ spacepoints, spM_tc, midBot_tc, dataPos,
127+ triplet_view, num_confirmations_view );
127128}
128129
129130// / CUDA kernel for running @c traccc::device::select_seeds
@@ -134,6 +135,7 @@ __global__ void select_seeds(
134135 device::triplet_counter_spM_collection_types::const_view spM_tc,
135136 device::triplet_counter_collection_types::const_view midBot_tc,
136137 device::device_triplet_collection_types::view triplet_view,
138+ const vecmem::data::vector_view<const unsigned int > num_confirmations_view,
137139 edm::seed_collection::view seed_view) {
138140
139141 // Array for temporary storage of triplets for comparing within seed
@@ -144,8 +146,8 @@ __global__ void select_seeds(
144146 &data2[threadIdx .x * filter_config.max_triplets_per_spM ];
145147
146148 device::select_seeds (details::global_index1 (), filter_config, spacepoints,
147- sp_view, spM_tc, midBot_tc, triplet_view, dataPos,
148- seed_view);
149+ sp_view, spM_tc, midBot_tc, triplet_view,
150+ num_confirmations_view, dataPos, seed_view);
149151}
150152
151153} // namespace kernels
@@ -324,6 +326,10 @@ edm::seed_collection::buffer seed_finding::operator()(
324326 triplet_buffer);
325327 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
326328
329+ vecmem::data::vector_buffer<unsigned int > num_confirmations_buffer (
330+ globalCounter_host->m_nTriplets , m_mr.main );
331+ m_copy.setup (num_confirmations_buffer)->wait ();
332+
327333 // Calculate the number of threads and thread blocks to run the weight
328334 // updating kernel for.
329335 const unsigned int nWeightUpdatingThreads = m_warp_size * 2 ;
@@ -332,13 +338,13 @@ edm::seed_collection::buffer seed_finding::operator()(
332338 nWeightUpdatingThreads;
333339
334340 // Update the weights of all spacepoint triplets.
335- kernels::update_triplet_weights <<<
341+ kernels::find_triplet_confirmations <<<
336342 nWeightUpdatingBlocks, nWeightUpdatingThreads,
337343 sizeof (scalar) * m_seedfilter_config.compatSeedLimit *
338344 nWeightUpdatingThreads,
339345 stream>>> (m_seedfilter_config, spacepoints_view,
340346 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
341- triplet_buffer);
347+ triplet_buffer, num_confirmations_buffer );
342348 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
343349
344350 // Create result object: collection of seeds
@@ -359,10 +365,10 @@ edm::seed_collection::buffer seed_finding::operator()(
359365 sizeof (device::device_triplet) *
360366 m_seedfilter_config.max_triplets_per_spM *
361367 nSeedSelectingThreads,
362- stream>>> (m_seedfilter_config, spacepoints_view,
363- g2_view, triplet_counter_spM_buffer ,
364- triplet_counter_midBot_buffer,
365- triplet_buffer, seed_buffer);
368+ stream>>> (
369+ m_seedfilter_config, spacepoints_view, g2_view ,
370+ triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
371+ triplet_buffer, num_confirmations_buffer , seed_buffer);
366372 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
367373
368374 return seed_buffer;
0 commit comments