2020#include " traccc/edm/device/doublet_counter.hpp"
2121#include " traccc/edm/device/seeding_global_counter.hpp"
2222#include " traccc/edm/device/triplet_counter.hpp"
23+ #include " traccc/seeding/detail/spacepoint_type.hpp"
2324#include " traccc/seeding/device/count_doublets.hpp"
2425#include " traccc/seeding/device/count_triplets.hpp"
2526#include " traccc/seeding/device/find_doublets.hpp"
@@ -65,20 +66,93 @@ __global__ void find_doublets(
6566 sp_grid, doublet_counter, mb_doublets, mt_doublets);
6667}
6768
69+ __global__ void make_mid_bot_lincircles (
70+ device::device_doublet_collection_types::const_view mb_doublet_view,
71+ device::doublet_counter_collection_types::const_view doublet_count_view,
72+ edm::spacepoint_collection::const_view spacepoint_view,
73+ traccc::details::spacepoint_grid_types::const_view sp_grid_view,
74+ vecmem::data::vector_view<lin_circle> out_view) {
75+
76+ const device::device_doublet_collection_types::const_device doublets (
77+ mb_doublet_view);
78+ const device::doublet_counter_collection_types::const_device doublet_counts (
79+ doublet_count_view);
80+ const edm::spacepoint_collection::const_device spacepoints (spacepoint_view);
81+ traccc::details::spacepoint_grid_types::const_device sp_grid (sp_grid_view);
82+ vecmem::device_vector<lin_circle> out (out_view);
83+
84+ unsigned int tid = blockIdx .x * blockDim .x + threadIdx .x ;
85+
86+ if (tid >= doublets.size ()) {
87+ return ;
88+ }
89+
90+ const device::device_doublet dub = doublets.at (tid);
91+ const unsigned int counter_link = dub.counter_link ;
92+ const device::doublet_counter count = doublet_counts.at (counter_link);
93+ const sp_location spM_loc = count.m_spM ;
94+ const edm::spacepoint_collection::const_device::const_proxy_type spM =
95+ spacepoints.at (sp_grid.bin (spM_loc.bin_idx )[spM_loc.sp_idx ]);
96+ const sp_location spB_loc = dub.sp2 ;
97+ const edm::spacepoint_collection::const_device::const_proxy_type spB =
98+ spacepoints.at (sp_grid.bin (spB_loc.bin_idx )[spB_loc.sp_idx ]);
99+
100+ out.at (tid) = doublet_finding_helper::transform_coordinates<
101+ traccc::details::spacepoint_type::bottom>(spM, spB);
102+ }
103+
104+ __global__ void make_mid_top_lincircles (
105+ device::device_doublet_collection_types::const_view mt_doublet_view,
106+ device::doublet_counter_collection_types::const_view doublet_count_view,
107+ edm::spacepoint_collection::const_view spacepoint_view,
108+ traccc::details::spacepoint_grid_types::const_view sp_grid_view,
109+ vecmem::data::vector_view<lin_circle> out_view) {
110+
111+ const device::device_doublet_collection_types::const_device doublets (
112+ mt_doublet_view);
113+ const device::doublet_counter_collection_types::const_device doublet_counts (
114+ doublet_count_view);
115+ const edm::spacepoint_collection::const_device spacepoints (spacepoint_view);
116+ traccc::details::spacepoint_grid_types::const_device sp_grid (sp_grid_view);
117+ vecmem::device_vector<lin_circle> out (out_view);
118+
119+ unsigned int tid = blockIdx .x * blockDim .x + threadIdx .x ;
120+
121+ if (tid >= doublets.size ()) {
122+ return ;
123+ }
124+
125+ const device::device_doublet dub = doublets.at (tid);
126+ const unsigned int counter_link = dub.counter_link ;
127+ const device::doublet_counter count = doublet_counts.at (counter_link);
128+ const sp_location spM_loc = count.m_spM ;
129+ const edm::spacepoint_collection::const_device::const_proxy_type spM =
130+ spacepoints.at (sp_grid.bin (spM_loc.bin_idx )[spM_loc.sp_idx ]);
131+ const sp_location spT_loc = dub.sp2 ;
132+ const edm::spacepoint_collection::const_device::const_proxy_type spT =
133+ spacepoints.at (sp_grid.bin (spT_loc.bin_idx )[spT_loc.sp_idx ]);
134+
135+ out.at (tid) = doublet_finding_helper::transform_coordinates<
136+ traccc::details::spacepoint_type::top>(spM, spT);
137+ }
138+
68139// / CUDA kernel for running @c traccc::device::count_triplets
69- __global__ void count_triplets (
140+ __global__ __launch_bounds__ ( 128 ) void count_triplets(
70141 seedfinder_config config,
71142 edm::spacepoint_collection::const_view spacepoints,
72143 traccc::details::spacepoint_grid_types::const_view sp_grid,
73144 device::doublet_counter_collection_types::const_view doublet_counter,
74145 device::device_doublet_collection_types::const_view mb_doublets,
75146 device::device_doublet_collection_types::const_view mt_doublets,
76147 device::triplet_counter_spM_collection_types::view spM_counter,
77- device::triplet_counter_collection_types::view midBot_counter) {
148+ device::triplet_counter_collection_types::view midBot_counter,
149+ vecmem::data::vector_view<const lin_circle> midBot_circles,
150+ vecmem::data::vector_view<const lin_circle> midTop_circles) {
78151
79152 device::count_triplets (details::global_index1 (), config, spacepoints,
80153 sp_grid, doublet_counter, mb_doublets, mt_doublets,
81- spM_counter, midBot_counter);
154+ spM_counter, midBot_counter, midBot_circles,
155+ midTop_circles);
82156}
83157
84158// / CUDA kernel for running @c traccc::device::reduce_triplet_counts
@@ -92,19 +166,22 @@ __global__ void reduce_triplet_counts(
92166}
93167
94168// / CUDA kernel for running @c traccc::device::find_triplets
95- __global__ void find_triplets (
169+ __global__ __launch_bounds__ ( 128 ) void find_triplets(
96170 seedfinder_config config, seedfilter_config filter_config,
97171 edm::spacepoint_collection::const_view spacepoints,
98172 traccc::details::spacepoint_grid_types::const_view sp_grid,
99173 device::doublet_counter_collection_types::const_view doublet_counter,
100174 device::device_doublet_collection_types::const_view mt_doublets,
101175 device::triplet_counter_spM_collection_types::const_view spM_tc,
102176 device::triplet_counter_collection_types::const_view midBot_tc,
177+ vecmem::data::vector_view<const lin_circle> midBot_circles,
178+ vecmem::data::vector_view<const lin_circle> midTop_circles,
103179 device::device_triplet_collection_types::view triplet_view) {
104180
105181 device::find_triplets (details::global_index1 (), config, filter_config,
106182 spacepoints, sp_grid, doublet_counter, mt_doublets,
107- spM_tc, midBot_tc, triplet_view);
183+ spM_tc, midBot_tc, midBot_circles, midTop_circles,
184+ triplet_view);
108185}
109186
110187// / CUDA kernel for running @c traccc::device::update_triplet_weights
@@ -252,6 +329,33 @@ edm::seed_collection::buffer seed_finding::operator()(
252329 doublet_counter_buffer, doublet_buffer_mb, doublet_buffer_mt);
253330 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
254331
332+ vecmem::data::vector_buffer<lin_circle> midBotLinCircles{
333+ globalCounter_host->m_nMidBot , m_mr.main };
334+ m_copy.setup (midBotLinCircles)->wait ();
335+ vecmem::data::vector_buffer<lin_circle> midTopLinCircles{
336+ globalCounter_host->m_nMidTop , m_mr.main };
337+ m_copy.setup (midBotLinCircles)->wait ();
338+
339+ {
340+ const unsigned int nThreads = 128 ;
341+ const unsigned int nMidBotBlocks =
342+ (globalCounter_host->m_nMidBot + nThreads - 1 ) / nThreads;
343+ const unsigned int nMidTopBlocks =
344+ (globalCounter_host->m_nMidTop + nThreads - 1 ) / nThreads;
345+
346+ kernels::
347+ make_mid_bot_lincircles<<<nMidBotBlocks, nThreads, 0 , stream>>> (
348+ doublet_buffer_mb, doublet_counter_buffer, spacepoints_view,
349+ g2_view, midBotLinCircles);
350+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
351+
352+ kernels::
353+ make_mid_top_lincircles<<<nMidTopBlocks, nThreads, 0 , stream>>> (
354+ doublet_buffer_mt, doublet_counter_buffer, spacepoints_view,
355+ g2_view, midTopLinCircles);
356+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
357+ }
358+
255359 // Set up the triplet counter buffers
256360 device::triplet_counter_spM_collection_types::buffer
257361 triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main };
@@ -265,7 +369,7 @@ edm::seed_collection::buffer seed_finding::operator()(
265369
266370 // Calculate the number of threads and thread blocks to run the doublet
267371 // counting kernel for.
268- const unsigned int nTripletCountThreads = m_warp_size * 2 ;
372+ const unsigned int nTripletCountThreads = 128 ;
269373 const unsigned int nTripletCountBlocks =
270374 (globalCounter_host->m_nMidBot + nTripletCountThreads - 1 ) /
271375 nTripletCountThreads;
@@ -275,7 +379,7 @@ edm::seed_collection::buffer seed_finding::operator()(
275379 stream>>> (
276380 m_seedfinder_config, spacepoints_view, g2_view, doublet_counter_buffer,
277381 doublet_buffer_mb, doublet_buffer_mt, triplet_counter_spM_buffer,
278- triplet_counter_midBot_buffer);
382+ triplet_counter_midBot_buffer, midBotLinCircles, midTopLinCircles );
279383 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
280384
281385 // Calculate the number of threads and thread blocks to run the triplet
@@ -309,7 +413,7 @@ edm::seed_collection::buffer seed_finding::operator()(
309413
310414 // Calculate the number of threads and thread blocks to run the triplet
311415 // finding kernel for.
312- const unsigned int nTripletFindThreads = m_warp_size * 2 ;
416+ const unsigned int nTripletFindThreads = 128 ;
313417 const unsigned int nTripletFindBlocks =
314418 (m_copy.get_size (triplet_counter_midBot_buffer) + nTripletFindThreads -
315419 1 ) /
@@ -321,7 +425,7 @@ edm::seed_collection::buffer seed_finding::operator()(
321425 m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
322426 doublet_counter_buffer, doublet_buffer_mt,
323427 triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
324- triplet_buffer);
428+ midBotLinCircles, midTopLinCircles, triplet_buffer);
325429 TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
326430
327431 // Calculate the number of threads and thread blocks to run the weight
0 commit comments