diff --git a/device/alpaka/src/seeding/seed_finding.cpp b/device/alpaka/src/seeding/seed_finding.cpp index adf478b115..fd4f1c892d 100644 --- a/device/alpaka/src/seeding/seed_finding.cpp +++ b/device/alpaka/src/seeding/seed_finding.cpp @@ -24,6 +24,8 @@ #include "traccc/seeding/device/count_triplets.hpp" #include "traccc/seeding/device/find_doublets.hpp" #include "traccc/seeding/device/find_triplets.hpp" +#include "traccc/seeding/device/make_mid_bot_lincircles.hpp" +#include "traccc/seeding/device/make_mid_top_lincircles.hpp" #include "traccc/seeding/device/reduce_triplet_counts.hpp" #include "traccc/seeding/device/select_seeds.hpp" #include "traccc/seeding/device/update_triplet_weights.hpp" @@ -71,6 +73,42 @@ struct FindDoublets { } }; +// Kernel for running @c traccc::device::make_mid_bot_lincircles +struct MakeMidBotLinCircles { + template + ALPAKA_FN_ACC void operator()( + TAcc const& acc, + device::device_doublet_collection_types::const_view mb_doublet_view, + device::doublet_counter_collection_types::const_view doublet_count_view, + edm::spacepoint_collection::const_view spacepoint_view, + traccc::details::spacepoint_grid_types::const_view sp_grid_view, + vecmem::data::vector_view out_view) const { + auto const globalThreadIdx = + ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u]; + device::make_mid_bot_lincircles(globalThreadIdx, mb_doublet_view, + doublet_count_view, spacepoint_view, + sp_grid_view, out_view); + } +}; + +// Kernel for running @c traccc::device::make_mid_top_lincircles +struct MakeMidTopLinCircles { + template + ALPAKA_FN_ACC void operator()( + TAcc const& acc, + device::device_doublet_collection_types::const_view mt_doublet_view, + device::doublet_counter_collection_types::const_view doublet_count_view, + edm::spacepoint_collection::const_view spacepoint_view, + traccc::details::spacepoint_grid_types::const_view sp_grid_view, + vecmem::data::vector_view out_view) const { + auto const globalThreadIdx = + ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u]; + device::make_mid_top_lincircles(globalThreadIdx, mt_doublet_view, + doublet_count_view, spacepoint_view, + sp_grid_view, out_view); + } +}; + // Kernel for running @c traccc::device::count_triplets struct CountTriplets { template @@ -82,12 +120,15 @@ struct CountTriplets { device::device_doublet_collection_types::const_view mb_doublets, device::device_doublet_collection_types::const_view mt_doublets, device::triplet_counter_spM_collection_types::view spM_counter, - device::triplet_counter_collection_types::view midBot_counter) const { + device::triplet_counter_collection_types::view midBot_counter, + vecmem::data::vector_view mb_circles, + vecmem::data::vector_view mt_circles) const { auto const globalThreadIdx = ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u]; device::count_triplets(globalThreadIdx, config, spacepoints, sp_grid, doublet_counter, mb_doublets, mt_doublets, - spM_counter, midBot_counter); + spM_counter, midBot_counter, mb_circles, + mt_circles); } }; @@ -118,13 +159,16 @@ struct FindTriplets { device::device_doublet_collection_types::const_view mt_doublets, device::triplet_counter_spM_collection_types::const_view spM_tc, device::triplet_counter_collection_types::const_view midBot_tc, + vecmem::data::vector_view mb_circles, + vecmem::data::vector_view mt_circles, device::device_triplet_collection_types::view triplet_view) const { auto const globalThreadIdx = ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u]; device::find_triplets(globalThreadIdx, config, filter_config, spacepoints, sp_grid, doublet_counter, - mt_doublets, spM_tc, midBot_tc, triplet_view); + mt_doublets, spM_tc, midBot_tc, mb_circles, + mt_circles, triplet_view); } }; @@ -288,6 +332,37 @@ edm::seed_collection::buffer seed_finding::operator()( vecmem::get_data(doublet_buffer_mb), vecmem::get_data(doublet_buffer_mt)); + vecmem::data::vector_buffer mid_bot_lin_circles{ + pBufHost_counter->m_nMidBot, m_mr.main}; + m_copy.setup(mid_bot_lin_circles)->wait(); + vecmem::data::vector_buffer mid_top_lin_circles{ + pBufHost_counter->m_nMidTop, m_mr.main}; + m_copy.setup(mid_top_lin_circles)->wait(); + + { + const unsigned int n_threads = 128; + const unsigned int n_mid_bot_blocks = + (pBufHost_counter->m_nMidBot + n_threads - 1) / n_threads; + const unsigned int n_mid_top_blocks = + (pBufHost_counter->m_nMidTop + n_threads - 1) / n_threads; + const auto mid_bot_workdiv = + makeWorkDiv(n_mid_bot_blocks, n_threads); + const auto mid_top_workdiv = + makeWorkDiv(n_mid_top_blocks, n_threads); + + ::alpaka::exec( + queue, mid_bot_workdiv, kernels::MakeMidBotLinCircles{}, + vecmem::get_data(doublet_buffer_mb), + vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view, + vecmem::get_data(mid_bot_lin_circles)); + + ::alpaka::exec( + queue, mid_top_workdiv, kernels::MakeMidTopLinCircles{}, + vecmem::get_data(doublet_buffer_mb), + vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view, + vecmem::get_data(mid_top_lin_circles)); + } + // Set up the triplet counter buffers device::triplet_counter_spM_collection_types::buffer triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main}; @@ -311,7 +386,9 @@ edm::seed_collection::buffer seed_finding::operator()( vecmem::get_data(doublet_buffer_mb), vecmem::get_data(doublet_buffer_mt), vecmem::get_data(triplet_counter_spM_buffer), - vecmem::get_data(triplet_counter_midBot_buffer)); + vecmem::get_data(triplet_counter_midBot_buffer), + vecmem::get_data(mid_bot_lin_circles), + vecmem::get_data(mid_top_lin_circles)); // Calculate the number of threads and thread blocks to run the triplet // count reduction kernel for. @@ -353,6 +430,8 @@ edm::seed_collection::buffer seed_finding::operator()( vecmem::get_data(doublet_buffer_mt), vecmem::get_data(triplet_counter_spM_buffer), vecmem::get_data(triplet_counter_midBot_buffer), + vecmem::get_data(mid_bot_lin_circles), + vecmem::get_data(mid_top_lin_circles), vecmem::get_data(triplet_buffer)); blocksPerGrid = diff --git a/device/common/include/traccc/edm/device/triplet_counter.hpp b/device/common/include/traccc/edm/device/triplet_counter.hpp index de113a2567..5b716c027c 100644 --- a/device/common/include/traccc/edm/device/triplet_counter.hpp +++ b/device/common/include/traccc/edm/device/triplet_counter.hpp @@ -47,6 +47,9 @@ struct triplet_counter { /// The position in which these triplets will be added unsigned int posTriplets = 0; + /// Index of the bottom-middle doublet + unsigned int m_botMidIdx = 0; + }; // struct triplet_counter /// Declare all triplet counter collection types diff --git a/device/common/include/traccc/seeding/device/count_triplets.hpp b/device/common/include/traccc/seeding/device/count_triplets.hpp index c8dc40fde2..47e0b3c945 100644 --- a/device/common/include/traccc/seeding/device/count_triplets.hpp +++ b/device/common/include/traccc/seeding/device/count_triplets.hpp @@ -48,7 +48,9 @@ inline void count_triplets( const device_doublet_collection_types::const_view& mid_bot_doublet_view, const device_doublet_collection_types::const_view& mid_top_doublet_view, triplet_counter_spM_collection_types::view spM_tc, - triplet_counter_collection_types::view mb_tc); + triplet_counter_collection_types::view mb_tc, + vecmem::data::vector_view mid_bot_circles, + vecmem::data::vector_view mid_top_circles); } // namespace traccc::device diff --git a/device/common/include/traccc/seeding/device/find_triplets.hpp b/device/common/include/traccc/seeding/device/find_triplets.hpp index 647b6f7bdb..5090406b05 100644 --- a/device/common/include/traccc/seeding/device/find_triplets.hpp +++ b/device/common/include/traccc/seeding/device/find_triplets.hpp @@ -49,6 +49,8 @@ inline void find_triplets( const device_doublet_collection_types::const_view& mid_top_doublet_view, const triplet_counter_spM_collection_types::const_view& spM_tc_view, const triplet_counter_collection_types::const_view& tc_view, + vecmem::data::vector_view mid_bot_circle_view, + vecmem::data::vector_view mid_top_circle_view, device_triplet_collection_types::view triplet_view); } // namespace traccc::device diff --git a/device/common/include/traccc/seeding/device/impl/count_triplets.ipp b/device/common/include/traccc/seeding/device/impl/count_triplets.ipp index 53a1a5e219..8d263a1d9f 100644 --- a/device/common/include/traccc/seeding/device/impl/count_triplets.ipp +++ b/device/common/include/traccc/seeding/device/impl/count_triplets.ipp @@ -24,7 +24,9 @@ inline void count_triplets( const device_doublet_collection_types::const_view& mid_bot_doublet_view, const device_doublet_collection_types::const_view& mid_top_doublet_view, triplet_counter_spM_collection_types::view spM_tc_view, - triplet_counter_collection_types::view mb_tc_view) { + triplet_counter_collection_types::view mb_tc_view, + vecmem::data::vector_view mid_bot_circle_view, + vecmem::data::vector_view mid_top_circle_view) { // Create device copy of input parameters const device_doublet_collection_types::const_device mid_bot_doublet_device( @@ -43,6 +45,10 @@ inline void count_triplets( const device_doublet_collection_types::const_device mid_top_doublet_device( mid_top_doublet_view); const doublet_counter_collection_types::const_device dc_device(dc_view); + const vecmem::device_vector mid_bot_circles( + mid_bot_circle_view); + const vecmem::device_vector mid_top_circles( + mid_top_circle_view); // Create device copy of output parameterss triplet_counter_collection_types::device mb_triplet_counter(mb_tc_view); @@ -61,13 +67,9 @@ inline void count_triplets( const edm::spacepoint_collection::const_device::const_proxy_type spM = spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]); const sp_location spB_loc = mid_bot.sp2; - // bottom spacepoint - const edm::spacepoint_collection::const_device::const_proxy_type spB = - spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]); // Apply the conformal transformation to middle-bot doublet - traccc::lin_circle lb = doublet_finding_helper::transform_coordinates< - details::spacepoint_type::bottom>(spM, spB); + const traccc::lin_circle lb = mid_bot_circles.at(globalIndex); // Calculate some physical quantities required for triplet compatibility // check @@ -75,10 +77,6 @@ inline void count_triplets( scalar scatteringInRegion2 = config.maxScatteringAngle2 * iSinTheta2; scatteringInRegion2 *= config.sigmaScattering * config.sigmaScattering; - // These two quantities are used as output parameters in - // triplet_finding_helper::isCompatible but their values are irrelevant - scalar curvature, impact_parameter; - // find the reference (start) index of the mid-top doublet container // item vector, where the doublets are recorded const unsigned int mt_start_idx = doublet_counts.m_posMidTop; @@ -87,18 +85,16 @@ inline void count_triplets( // number of triplets per middle-bot doublet unsigned int num_triplets_per_mb = 0; - // iterate over mid-top doublets - for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) { - const traccc::sp_location spT_loc = mid_top_doublet_device[i].sp2; - - const edm::spacepoint_collection::const_device::const_proxy_type spT = - spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]); + const unsigned int num_mt = mt_end_idx - mt_start_idx; + // iterate over mid-top doublets + for (unsigned int ri = 0; ri < num_mt; ++ri) { // Apply the conformal transformation to middle-top doublet - traccc::lin_circle lt = doublet_finding_helper::transform_coordinates< - details::spacepoint_type::top>(spM, spT); + const lin_circle& lt = mid_top_circles.at(ri + mt_start_idx); - // Check if mid-bot and mid-top doublets can form a triplet + // These two quantities are used as output parameters in + // triplet_finding_helper::isCompatible but their values are irrelevant + scalar curvature, impact_parameter; if (triplet_finding_helper::isCompatible( spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature, impact_parameter)) { @@ -114,8 +110,9 @@ inline void count_triplets( const unsigned int posTriplets = nTriplets.fetch_add(num_triplets_per_mb); - mb_triplet_counter.push_back( - {spB_loc, counter_link, num_triplets_per_mb, posTriplets}); + mb_triplet_counter.push_back({spB_loc, counter_link, + num_triplets_per_mb, posTriplets, + globalIndex}); } } diff --git a/device/common/include/traccc/seeding/device/impl/find_triplets.ipp b/device/common/include/traccc/seeding/device/impl/find_triplets.ipp index 3475876154..e32add14af 100644 --- a/device/common/include/traccc/seeding/device/impl/find_triplets.ipp +++ b/device/common/include/traccc/seeding/device/impl/find_triplets.ipp @@ -25,6 +25,8 @@ inline void find_triplets( const device_doublet_collection_types::const_view& mid_top_doublet_view, const triplet_counter_spM_collection_types::const_view& spM_tc_view, const triplet_counter_collection_types::const_view& tc_view, + vecmem::data::vector_view mid_bot_circle_view, + vecmem::data::vector_view mid_top_circle_view, device_triplet_collection_types::view triplet_view) { // Check if anything needs to be done. @@ -44,6 +46,10 @@ inline void find_triplets( const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view); const triplet_counter_spM_collection_types::const_device triplet_counts_spM( spM_tc_view); + const vecmem::device_vector mid_bot_circles( + mid_bot_circle_view); + const vecmem::device_vector mid_top_circles( + mid_top_circle_view); // Get the current work item information const triplet_counter mid_bot_counter = triplet_counts.at(globalIndex); @@ -69,8 +75,8 @@ inline void find_triplets( device_triplet_collection_types::device triplets(triplet_view); // Apply the conformal transformation to middle-bot doublet - const traccc::lin_circle lb = doublet_finding_helper::transform_coordinates< - details::spacepoint_type::bottom>(spM, spB); + const traccc::lin_circle lb = + mid_bot_circles.at(mid_bot_counter.m_botMidIdx); // Calculate some physical quantities required for triplet compatibility // check @@ -79,10 +85,6 @@ inline void find_triplets( config.sigmaScattering * config.sigmaScattering; - // These two quantities are used as output parameters in - // triplet_finding_helper::isCompatible but their values are irrelevant - scalar curvature, impact_parameter; - // find the reference (start) index of the mid-top doublet collection // item vector, where the doublets are recorded const unsigned int mt_start_idx = doublet_count.m_posMidTop; @@ -95,19 +97,15 @@ inline void find_triplets( // iterate over mid-top doublets for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) { - const sp_location spT_loc = mid_top_doublet_device[i].sp2; + // Apply the conformal transformation to middle-top doublet + const traccc::lin_circle& lt = mid_top_circles.at(i); + const sp_location spT_loc = mid_top_doublet_device[i].sp2; const unsigned int spT_idx = sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]; - const edm::spacepoint_collection::const_device::const_proxy_type spT = - spacepoints.at(spT_idx); - - // Apply the conformal transformation to middle-top doublet - const traccc::lin_circle lt = - doublet_finding_helper::transform_coordinates< - details::spacepoint_type::top>(spM, spT); // Check if mid-bot and mid-top doublets can form a triplet + scalar curvature, impact_parameter; if (triplet_finding_helper::isCompatible( spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature, impact_parameter)) { diff --git a/device/common/include/traccc/seeding/device/make_mid_bot_lincircles.hpp b/device/common/include/traccc/seeding/device/make_mid_bot_lincircles.hpp new file mode 100644 index 0000000000..94e2589e05 --- /dev/null +++ b/device/common/include/traccc/seeding/device/make_mid_bot_lincircles.hpp @@ -0,0 +1,48 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +namespace traccc::device { +/** + * @brief Kernel to create middle-bottom linearised circles. + */ +TRACCC_HOST_DEVICE +inline void make_mid_bot_lincircles( + global_index_t tid, + device::device_doublet_collection_types::const_view mb_doublet_view, + device::doublet_counter_collection_types::const_view doublet_count_view, + edm::spacepoint_collection::const_view spacepoint_view, + traccc::details::spacepoint_grid_types::const_view sp_grid_view, + vecmem::data::vector_view out_view) { + + const device::device_doublet_collection_types::const_device doublets( + mb_doublet_view); + const device::doublet_counter_collection_types::const_device doublet_counts( + doublet_count_view); + const edm::spacepoint_collection::const_device spacepoints(spacepoint_view); + traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view); + vecmem::device_vector out(out_view); + + if (tid >= doublets.size()) { + return; + } + + const device::device_doublet dub = doublets.at(tid); + const unsigned int counter_link = dub.counter_link; + const device::doublet_counter count = doublet_counts.at(counter_link); + const sp_location spM_loc = count.m_spM; + const edm::spacepoint_collection::const_device::const_proxy_type spM = + spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]); + const sp_location spB_loc = dub.sp2; + const edm::spacepoint_collection::const_device::const_proxy_type spB = + spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]); + + out.at(tid) = doublet_finding_helper::transform_coordinates< + traccc::details::spacepoint_type::bottom>(spM, spB); +} +} // namespace traccc::device diff --git a/device/common/include/traccc/seeding/device/make_mid_top_lincircles.hpp b/device/common/include/traccc/seeding/device/make_mid_top_lincircles.hpp new file mode 100644 index 0000000000..8530b528fb --- /dev/null +++ b/device/common/include/traccc/seeding/device/make_mid_top_lincircles.hpp @@ -0,0 +1,48 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +namespace traccc::device { +/** + * @brief Kernel to create middle-bottom linearised circles. + */ +TRACCC_HOST_DEVICE +inline void make_mid_top_lincircles( + global_index_t tid, + device::device_doublet_collection_types::const_view mt_doublet_view, + device::doublet_counter_collection_types::const_view doublet_count_view, + edm::spacepoint_collection::const_view spacepoint_view, + traccc::details::spacepoint_grid_types::const_view sp_grid_view, + vecmem::data::vector_view out_view) { + + const device::device_doublet_collection_types::const_device doublets( + mt_doublet_view); + const device::doublet_counter_collection_types::const_device doublet_counts( + doublet_count_view); + const edm::spacepoint_collection::const_device spacepoints(spacepoint_view); + traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view); + vecmem::device_vector out(out_view); + + if (tid >= doublets.size()) { + return; + } + + const device::device_doublet dub = doublets.at(tid); + const unsigned int counter_link = dub.counter_link; + const device::doublet_counter count = doublet_counts.at(counter_link); + const sp_location spM_loc = count.m_spM; + const edm::spacepoint_collection::const_device::const_proxy_type spM = + spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]); + const sp_location spT_loc = dub.sp2; + const edm::spacepoint_collection::const_device::const_proxy_type spT = + spacepoints.at(sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]); + + out.at(tid) = doublet_finding_helper::transform_coordinates< + traccc::details::spacepoint_type::top>(spM, spT); +} +} // namespace traccc::device diff --git a/device/cuda/src/seeding/seed_finding.cu b/device/cuda/src/seeding/seed_finding.cu index 98d2f947ae..45fc56c796 100644 --- a/device/cuda/src/seeding/seed_finding.cu +++ b/device/cuda/src/seeding/seed_finding.cu @@ -20,10 +20,13 @@ #include "traccc/edm/device/doublet_counter.hpp" #include "traccc/edm/device/seeding_global_counter.hpp" #include "traccc/edm/device/triplet_counter.hpp" +#include "traccc/seeding/detail/spacepoint_type.hpp" #include "traccc/seeding/device/count_doublets.hpp" #include "traccc/seeding/device/count_triplets.hpp" #include "traccc/seeding/device/find_doublets.hpp" #include "traccc/seeding/device/find_triplets.hpp" +#include "traccc/seeding/device/make_mid_bot_lincircles.hpp" +#include "traccc/seeding/device/make_mid_top_lincircles.hpp" #include "traccc/seeding/device/reduce_triplet_counts.hpp" #include "traccc/seeding/device/select_seeds.hpp" #include "traccc/seeding/device/update_triplet_weights.hpp" @@ -65,8 +68,32 @@ __global__ void find_doublets( sp_grid, doublet_counter, mb_doublets, mt_doublets); } +__global__ void make_mid_bot_lincircles( + device::device_doublet_collection_types::const_view mb_doublet_view, + device::doublet_counter_collection_types::const_view doublet_count_view, + edm::spacepoint_collection::const_view spacepoint_view, + traccc::details::spacepoint_grid_types::const_view sp_grid_view, + vecmem::data::vector_view out_view) { + + device::make_mid_bot_lincircles(details::global_index1(), mb_doublet_view, + doublet_count_view, spacepoint_view, + sp_grid_view, out_view); +} + +__global__ void make_mid_top_lincircles( + device::device_doublet_collection_types::const_view mt_doublet_view, + device::doublet_counter_collection_types::const_view doublet_count_view, + edm::spacepoint_collection::const_view spacepoint_view, + traccc::details::spacepoint_grid_types::const_view sp_grid_view, + vecmem::data::vector_view out_view) { + + device::make_mid_top_lincircles(details::global_index1(), mt_doublet_view, + doublet_count_view, spacepoint_view, + sp_grid_view, out_view); +} + /// CUDA kernel for running @c traccc::device::count_triplets -__global__ void count_triplets( +__global__ __launch_bounds__(128) void count_triplets( seedfinder_config config, edm::spacepoint_collection::const_view spacepoints, traccc::details::spacepoint_grid_types::const_view sp_grid, @@ -74,11 +101,14 @@ __global__ void count_triplets( device::device_doublet_collection_types::const_view mb_doublets, device::device_doublet_collection_types::const_view mt_doublets, device::triplet_counter_spM_collection_types::view spM_counter, - device::triplet_counter_collection_types::view midBot_counter) { + device::triplet_counter_collection_types::view midBot_counter, + vecmem::data::vector_view midBot_circles, + vecmem::data::vector_view midTop_circles) { device::count_triplets(details::global_index1(), config, spacepoints, sp_grid, doublet_counter, mb_doublets, mt_doublets, - spM_counter, midBot_counter); + spM_counter, midBot_counter, midBot_circles, + midTop_circles); } /// CUDA kernel for running @c traccc::device::reduce_triplet_counts @@ -92,7 +122,7 @@ __global__ void reduce_triplet_counts( } /// CUDA kernel for running @c traccc::device::find_triplets -__global__ void find_triplets( +__global__ __launch_bounds__(128) void find_triplets( seedfinder_config config, seedfilter_config filter_config, edm::spacepoint_collection::const_view spacepoints, traccc::details::spacepoint_grid_types::const_view sp_grid, @@ -100,11 +130,14 @@ __global__ void find_triplets( device::device_doublet_collection_types::const_view mt_doublets, device::triplet_counter_spM_collection_types::const_view spM_tc, device::triplet_counter_collection_types::const_view midBot_tc, + vecmem::data::vector_view midBot_circles, + vecmem::data::vector_view midTop_circles, device::device_triplet_collection_types::view triplet_view) { device::find_triplets(details::global_index1(), config, filter_config, spacepoints, sp_grid, doublet_counter, mt_doublets, - spM_tc, midBot_tc, triplet_view); + spM_tc, midBot_tc, midBot_circles, midTop_circles, + triplet_view); } /// CUDA kernel for running @c traccc::device::update_triplet_weights @@ -260,6 +293,33 @@ edm::seed_collection::buffer seed_finding::operator()( doublet_counter_buffer, doublet_buffer_mb, doublet_buffer_mt); TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); + vecmem::data::vector_buffer midBotLinCircles{ + globalCounter_host->m_nMidBot, m_mr.main}; + m_copy.setup(midBotLinCircles)->wait(); + vecmem::data::vector_buffer midTopLinCircles{ + globalCounter_host->m_nMidTop, m_mr.main}; + m_copy.setup(midBotLinCircles)->wait(); + + { + const unsigned int nThreads = 128; + const unsigned int nMidBotBlocks = + (globalCounter_host->m_nMidBot + nThreads - 1) / nThreads; + const unsigned int nMidTopBlocks = + (globalCounter_host->m_nMidTop + nThreads - 1) / nThreads; + + kernels:: + make_mid_bot_lincircles<<>>( + doublet_buffer_mb, doublet_counter_buffer, spacepoints_view, + g2_view, midBotLinCircles); + TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); + + kernels:: + make_mid_top_lincircles<<>>( + doublet_buffer_mt, doublet_counter_buffer, spacepoints_view, + g2_view, midTopLinCircles); + TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); + } + // Set up the triplet counter buffers device::triplet_counter_spM_collection_types::buffer triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main}; @@ -273,7 +333,7 @@ edm::seed_collection::buffer seed_finding::operator()( // Calculate the number of threads and thread blocks to run the doublet // counting kernel for. - const unsigned int nTripletCountThreads = m_warp_size * 2; + const unsigned int nTripletCountThreads = 128; const unsigned int nTripletCountBlocks = (globalCounter_host->m_nMidBot + nTripletCountThreads - 1) / nTripletCountThreads; @@ -283,7 +343,7 @@ edm::seed_collection::buffer seed_finding::operator()( stream>>>( m_seedfinder_config, spacepoints_view, g2_view, doublet_counter_buffer, doublet_buffer_mb, doublet_buffer_mt, triplet_counter_spM_buffer, - triplet_counter_midBot_buffer); + triplet_counter_midBot_buffer, midBotLinCircles, midTopLinCircles); TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); // Calculate the number of threads and thread blocks to run the triplet @@ -320,7 +380,7 @@ edm::seed_collection::buffer seed_finding::operator()( // Calculate the number of threads and thread blocks to run the triplet // finding kernel for. - const unsigned int nTripletFindThreads = m_warp_size * 2; + const unsigned int nTripletFindThreads = 128; const unsigned int nTripletFindBlocks = (*size_staging_ptr + nTripletFindThreads - 1) / nTripletFindThreads; @@ -330,7 +390,7 @@ edm::seed_collection::buffer seed_finding::operator()( m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view, doublet_counter_buffer, doublet_buffer_mt, triplet_counter_spM_buffer, triplet_counter_midBot_buffer, - triplet_buffer); + midBotLinCircles, midTopLinCircles, triplet_buffer); TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); // Calculate the number of threads and thread blocks to run the weight diff --git a/device/sycl/src/seeding/seed_finding.sycl b/device/sycl/src/seeding/seed_finding.sycl index cf90d62e1d..cce55f878f 100644 --- a/device/sycl/src/seeding/seed_finding.sycl +++ b/device/sycl/src/seeding/seed_finding.sycl @@ -23,6 +23,8 @@ #include "traccc/seeding/device/count_triplets.hpp" #include "traccc/seeding/device/find_doublets.hpp" #include "traccc/seeding/device/find_triplets.hpp" +#include "traccc/seeding/device/make_mid_bot_lincircles.hpp" +#include "traccc/seeding/device/make_mid_top_lincircles.hpp" #include "traccc/seeding/device/reduce_triplet_counts.hpp" #include "traccc/seeding/device/select_seeds.hpp" #include "traccc/seeding/device/update_triplet_weights.hpp" @@ -42,6 +44,14 @@ class count_doublets; /// Class identifying the kernel running @c traccc::device::find_doublets class find_doublets; +/// Class identifying the kernel running @c +/// traccc::device::make_mid_bot_lincircles +class make_mid_bot_lincircles; + +/// Class identifying the kernel running @c +/// traccc::device::make_mid_top_lincircles +class make_mid_top_lincircles; + /// Class identifying the kernel running @c traccc::device::count_triplets class count_triplets; @@ -181,6 +191,55 @@ edm::seed_collection::buffer seed_finding::operator()( }); }); + // Wait here for the find_doublets kernel to finish + find_doublets_kernel.wait_and_throw(); + + // Create buffers to compute lin circles into + vecmem::data::vector_buffer midBotLinCircles{ + globalCounter_host->m_nMidBot, m_mr.main}; + m_copy.setup(midBotLinCircles)->wait(); + vecmem::data::vector_buffer midTopLinCircles{ + globalCounter_host->m_nMidTop, m_mr.main}; + m_copy.setup(midBotLinCircles)->wait(); + + ::sycl::event makeMidBotLinCirclesKernel, makeMidTopLinCirclesKernel; + + { + const unsigned int nThreads = 128; + const auto midBotRange = details::calculate1DimNdRange( + globalCounter_host->m_nMidBot, nThreads); + const auto midTopRange = details::calculate1DimNdRange( + globalCounter_host->m_nMidBot, nThreads); + + makeMidBotLinCirclesKernel = + details::get_queue(m_queue).submit([&](::sycl::handler& h) { + h.parallel_for( + midBotRange, + [spacepoints_view, g2_view, doublet_counter_view, mb_view, + mid_bot_circles_view = vecmem::get_data(midBotLinCircles)]( + ::sycl::nd_item<1> item) { + device::make_mid_bot_lincircles( + details::global_index(item), mb_view, + doublet_counter_view, spacepoints_view, g2_view, + mid_bot_circles_view); + }); + }); + + makeMidTopLinCirclesKernel = + details::get_queue(m_queue).submit([&](::sycl::handler& h) { + h.parallel_for( + midTopRange, + [spacepoints_view, g2_view, doublet_counter_view, mt_view, + mid_top_circles_view = vecmem::get_data(midTopLinCircles)]( + ::sycl::nd_item<1> item) { + device::make_mid_bot_lincircles( + details::global_index(item), mt_view, + doublet_counter_view, spacepoints_view, g2_view, + mid_top_circles_view); + }); + }); + } + // Set up the triplet counter buffers and their views device::triplet_counter_spM_collection_types::buffer triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main}; @@ -202,8 +261,8 @@ edm::seed_collection::buffer seed_finding::operator()( const auto tripletCountRange = details::calculate1DimNdRange( globalCounter_host->m_nMidBot, tripletCountLocalSize); - // Wait here for the find_doublets kernel to finish - find_doublets_kernel.wait_and_throw(); + makeMidBotLinCirclesKernel.wait_and_throw(); + makeMidTopLinCirclesKernel.wait_and_throw(); // Count the number of triplets that we need to produce. auto count_triplets_kernel = @@ -212,12 +271,15 @@ edm::seed_collection::buffer seed_finding::operator()( tripletCountRange, [config = m_seedfinder_config, spacepoints_view, g2_view, doublet_counter_view, mb_view, mt_view, - triplet_counter_spM_view, - triplet_counter_midBot_view](::sycl::nd_item<1> item) { + triplet_counter_spM_view, triplet_counter_midBot_view, + mid_bot_circles_view = vecmem::get_data(midBotLinCircles), + mid_top_circles_view = vecmem::get_data(midTopLinCircles)]( + ::sycl::nd_item<1> item) { device::count_triplets( details::global_index(item), config, spacepoints_view, g2_view, doublet_counter_view, mb_view, mt_view, - triplet_counter_spM_view, triplet_counter_midBot_view); + triplet_counter_spM_view, triplet_counter_midBot_view, + mid_bot_circles_view, mid_top_circles_view); }); }); @@ -272,13 +334,16 @@ edm::seed_collection::buffer seed_finding::operator()( [config = m_seedfinder_config, filter_config = m_seedfilter_config, spacepoints_view, g2_view, doublet_counter_view, mt_view, triplet_counter_spM_view, - triplet_counter_midBot_view, - triplet_view](::sycl::nd_item<1> item) { + triplet_counter_midBot_view, triplet_view, + mid_bot_circles_view = vecmem::get_data(midBotLinCircles), + mid_top_circles_view = vecmem::get_data(midTopLinCircles)]( + ::sycl::nd_item<1> item) { device::find_triplets( details::global_index(item), config, filter_config, spacepoints_view, g2_view, doublet_counter_view, mt_view, triplet_counter_spM_view, - triplet_counter_midBot_view, triplet_view); + triplet_counter_midBot_view, mid_bot_circles_view, + mid_top_circles_view, triplet_view); }); });