2424#include " traccc/seeding/device/count_triplets.hpp"
2525#include " traccc/seeding/device/find_doublets.hpp"
2626#include " traccc/seeding/device/find_triplets.hpp"
27+ #include " traccc/seeding/device/make_mid_bot_lincircles.hpp"
28+ #include " traccc/seeding/device/make_mid_top_lincircles.hpp"
2729#include " traccc/seeding/device/reduce_triplet_counts.hpp"
2830#include " traccc/seeding/device/select_seeds.hpp"
2931#include " traccc/seeding/device/update_triplet_weights.hpp"
@@ -71,6 +73,42 @@ struct FindDoublets {
7173 }
7274};
7375
76+ // Kernel for running @c traccc::device::make_mid_bot_lincircles
77+ struct MakeMidBotLinCircles {
78+ template <typename TAcc>
79+ ALPAKA_FN_ACC void operator ()(
80+ TAcc const & acc,
81+ device::device_doublet_collection_types::const_view mb_doublet_view,
82+ device::doublet_counter_collection_types::const_view doublet_count_view,
83+ edm::spacepoint_collection::const_view spacepoint_view,
84+ traccc::details::spacepoint_grid_types::const_view sp_grid_view,
85+ vecmem::data::vector_view<lin_circle> out_view) const {
86+ auto const globalThreadIdx =
87+ ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u ];
88+ device::make_mid_bot_lincircles (globalThreadIdx, mb_doublet_view,
89+ doublet_count_view, spacepoint_view,
90+ sp_grid_view, out_view);
91+ }
92+ };
93+
94+ // Kernel for running @c traccc::device::make_mid_top_lincircles
95+ struct MakeMidTopLinCircles {
96+ template <typename TAcc>
97+ ALPAKA_FN_ACC void operator ()(
98+ TAcc const & acc,
99+ device::device_doublet_collection_types::const_view mt_doublet_view,
100+ device::doublet_counter_collection_types::const_view doublet_count_view,
101+ edm::spacepoint_collection::const_view spacepoint_view,
102+ traccc::details::spacepoint_grid_types::const_view sp_grid_view,
103+ vecmem::data::vector_view<lin_circle> out_view) const {
104+ auto const globalThreadIdx =
105+ ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u ];
106+ device::make_mid_top_lincircles (globalThreadIdx, mt_doublet_view,
107+ doublet_count_view, spacepoint_view,
108+ sp_grid_view, out_view);
109+ }
110+ };
111+
74112// Kernel for running @c traccc::device::count_triplets
75113struct CountTriplets {
76114 template <typename TAcc>
@@ -82,12 +120,15 @@ struct CountTriplets {
82120 device::device_doublet_collection_types::const_view mb_doublets,
83121 device::device_doublet_collection_types::const_view mt_doublets,
84122 device::triplet_counter_spM_collection_types::view spM_counter,
85- device::triplet_counter_collection_types::view midBot_counter) const {
123+ device::triplet_counter_collection_types::view midBot_counter,
124+ vecmem::data::vector_view<lin_circle> mb_circles,
125+ vecmem::data::vector_view<lin_circle> mt_circles) const {
86126 auto const globalThreadIdx =
87127 ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u ];
88128 device::count_triplets (globalThreadIdx, config, spacepoints, sp_grid,
89129 doublet_counter, mb_doublets, mt_doublets,
90- spM_counter, midBot_counter);
130+ spM_counter, midBot_counter, mb_circles,
131+ mt_circles);
91132 }
92133};
93134
@@ -118,13 +159,16 @@ struct FindTriplets {
118159 device::device_doublet_collection_types::const_view mt_doublets,
119160 device::triplet_counter_spM_collection_types::const_view spM_tc,
120161 device::triplet_counter_collection_types::const_view midBot_tc,
162+ vecmem::data::vector_view<lin_circle> mb_circles,
163+ vecmem::data::vector_view<lin_circle> mt_circles,
121164 device::device_triplet_collection_types::view triplet_view) const {
122165
123166 auto const globalThreadIdx =
124167 ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u ];
125168 device::find_triplets (globalThreadIdx, config, filter_config,
126169 spacepoints, sp_grid, doublet_counter,
127- mt_doublets, spM_tc, midBot_tc, triplet_view);
170+ mt_doublets, spM_tc, midBot_tc, mb_circles,
171+ mt_circles, triplet_view);
128172 }
129173};
130174
@@ -288,6 +332,37 @@ edm::seed_collection::buffer seed_finding::operator()(
288332 vecmem::get_data (doublet_buffer_mb),
289333 vecmem::get_data (doublet_buffer_mt));
290334
335+ vecmem::data::vector_buffer<lin_circle> mid_bot_lin_circles{
336+ pBufHost_counter->m_nMidBot , m_mr.main };
337+ m_copy.setup (mid_bot_lin_circles)->wait ();
338+ vecmem::data::vector_buffer<lin_circle> mid_top_lin_circles{
339+ pBufHost_counter->m_nMidTop , m_mr.main };
340+ m_copy.setup (mid_top_lin_circles)->wait ();
341+
342+ {
343+ const unsigned int n_threads = 128 ;
344+ const unsigned int n_mid_bot_blocks =
345+ (pBufHost_counter->m_nMidBot + n_threads - 1 ) / n_threads;
346+ const unsigned int n_mid_top_blocks =
347+ (pBufHost_counter->m_nMidTop + n_threads - 1 ) / n_threads;
348+ const auto mid_bot_workdiv =
349+ makeWorkDiv<Acc>(n_mid_bot_blocks, n_threads);
350+ const auto mid_top_workdiv =
351+ makeWorkDiv<Acc>(n_mid_top_blocks, n_threads);
352+
353+ ::alpaka::exec<Acc>(
354+ queue, mid_bot_workdiv, kernels::MakeMidBotLinCircles{},
355+ vecmem::get_data (doublet_buffer_mb),
356+ vecmem::get_data (doublet_counter_buffer), spacepoints_view, g2_view,
357+ vecmem::get_data (mid_bot_lin_circles));
358+
359+ ::alpaka::exec<Acc>(
360+ queue, mid_top_workdiv, kernels::MakeMidTopLinCircles{},
361+ vecmem::get_data (doublet_buffer_mb),
362+ vecmem::get_data (doublet_counter_buffer), spacepoints_view, g2_view,
363+ vecmem::get_data (mid_top_lin_circles));
364+ }
365+
291366 // Set up the triplet counter buffers
292367 device::triplet_counter_spM_collection_types::buffer
293368 triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main };
@@ -311,7 +386,9 @@ edm::seed_collection::buffer seed_finding::operator()(
311386 vecmem::get_data (doublet_buffer_mb),
312387 vecmem::get_data (doublet_buffer_mt),
313388 vecmem::get_data (triplet_counter_spM_buffer),
314- vecmem::get_data (triplet_counter_midBot_buffer));
389+ vecmem::get_data (triplet_counter_midBot_buffer),
390+ vecmem::get_data (mid_bot_lin_circles),
391+ vecmem::get_data (mid_top_lin_circles));
315392
316393 // Calculate the number of threads and thread blocks to run the triplet
317394 // count reduction kernel for.
@@ -353,6 +430,8 @@ edm::seed_collection::buffer seed_finding::operator()(
353430 vecmem::get_data (doublet_buffer_mt),
354431 vecmem::get_data (triplet_counter_spM_buffer),
355432 vecmem::get_data (triplet_counter_midBot_buffer),
433+ vecmem::get_data (mid_bot_lin_circles),
434+ vecmem::get_data (mid_top_lin_circles),
356435 vecmem::get_data (triplet_buffer));
357436
358437 blocksPerGrid =
0 commit comments