Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions device/alpaka/src/seeding/seed_finding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "traccc/seeding/device/count_triplets.hpp"
#include "traccc/seeding/device/find_doublets.hpp"
#include "traccc/seeding/device/find_triplets.hpp"
#include "traccc/seeding/device/make_mid_bot_lincircles.hpp"
#include "traccc/seeding/device/make_mid_top_lincircles.hpp"
#include "traccc/seeding/device/reduce_triplet_counts.hpp"
#include "traccc/seeding/device/select_seeds.hpp"
#include "traccc/seeding/device/update_triplet_weights.hpp"
Expand Down Expand Up @@ -71,6 +73,42 @@ struct FindDoublets {
}
};

// Kernel for running @c traccc::device::make_mid_bot_lincircles
struct MakeMidBotLinCircles {
template <typename TAcc>
ALPAKA_FN_ACC void operator()(
TAcc const& acc,
device::device_doublet_collection_types::const_view mb_doublet_view,
device::doublet_counter_collection_types::const_view doublet_count_view,
edm::spacepoint_collection::const_view spacepoint_view,
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
vecmem::data::vector_view<lin_circle> out_view) const {
auto const globalThreadIdx =
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
device::make_mid_bot_lincircles(globalThreadIdx, mb_doublet_view,
doublet_count_view, spacepoint_view,
sp_grid_view, out_view);
}
};

// Kernel for running @c traccc::device::make_mid_top_lincircles
struct MakeMidTopLinCircles {
template <typename TAcc>
ALPAKA_FN_ACC void operator()(
TAcc const& acc,
device::device_doublet_collection_types::const_view mt_doublet_view,
device::doublet_counter_collection_types::const_view doublet_count_view,
edm::spacepoint_collection::const_view spacepoint_view,
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
vecmem::data::vector_view<lin_circle> out_view) const {
auto const globalThreadIdx =
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
device::make_mid_top_lincircles(globalThreadIdx, mt_doublet_view,
doublet_count_view, spacepoint_view,
sp_grid_view, out_view);
}
};

// Kernel for running @c traccc::device::count_triplets
struct CountTriplets {
template <typename TAcc>
Expand All @@ -82,12 +120,15 @@ struct CountTriplets {
device::device_doublet_collection_types::const_view mb_doublets,
device::device_doublet_collection_types::const_view mt_doublets,
device::triplet_counter_spM_collection_types::view spM_counter,
device::triplet_counter_collection_types::view midBot_counter) const {
device::triplet_counter_collection_types::view midBot_counter,
vecmem::data::vector_view<lin_circle> mb_circles,
vecmem::data::vector_view<lin_circle> mt_circles) const {
auto const globalThreadIdx =
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
device::count_triplets(globalThreadIdx, config, spacepoints, sp_grid,
doublet_counter, mb_doublets, mt_doublets,
spM_counter, midBot_counter);
spM_counter, midBot_counter, mb_circles,
mt_circles);
}
};

Expand Down Expand Up @@ -118,13 +159,16 @@ struct FindTriplets {
device::device_doublet_collection_types::const_view mt_doublets,
device::triplet_counter_spM_collection_types::const_view spM_tc,
device::triplet_counter_collection_types::const_view midBot_tc,
vecmem::data::vector_view<lin_circle> mb_circles,
vecmem::data::vector_view<lin_circle> mt_circles,
device::device_triplet_collection_types::view triplet_view) const {

auto const globalThreadIdx =
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
device::find_triplets(globalThreadIdx, config, filter_config,
spacepoints, sp_grid, doublet_counter,
mt_doublets, spM_tc, midBot_tc, triplet_view);
mt_doublets, spM_tc, midBot_tc, mb_circles,
mt_circles, triplet_view);
}
};

Expand Down Expand Up @@ -288,6 +332,37 @@ edm::seed_collection::buffer seed_finding::operator()(
vecmem::get_data(doublet_buffer_mb),
vecmem::get_data(doublet_buffer_mt));

vecmem::data::vector_buffer<lin_circle> mid_bot_lin_circles{
pBufHost_counter->m_nMidBot, m_mr.main};
m_copy.setup(mid_bot_lin_circles)->wait();
vecmem::data::vector_buffer<lin_circle> mid_top_lin_circles{
pBufHost_counter->m_nMidTop, m_mr.main};
m_copy.setup(mid_top_lin_circles)->wait();

{
const unsigned int n_threads = 128;
const unsigned int n_mid_bot_blocks =
(pBufHost_counter->m_nMidBot + n_threads - 1) / n_threads;
const unsigned int n_mid_top_blocks =
(pBufHost_counter->m_nMidTop + n_threads - 1) / n_threads;
const auto mid_bot_workdiv =
makeWorkDiv<Acc>(n_mid_bot_blocks, n_threads);
const auto mid_top_workdiv =
makeWorkDiv<Acc>(n_mid_top_blocks, n_threads);

::alpaka::exec<Acc>(
queue, mid_bot_workdiv, kernels::MakeMidBotLinCircles{},
vecmem::get_data(doublet_buffer_mb),
vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view,
vecmem::get_data(mid_bot_lin_circles));

::alpaka::exec<Acc>(
queue, mid_top_workdiv, kernels::MakeMidTopLinCircles{},
vecmem::get_data(doublet_buffer_mb),
vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view,
vecmem::get_data(mid_top_lin_circles));
}

// Set up the triplet counter buffers
device::triplet_counter_spM_collection_types::buffer
triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main};
Expand All @@ -311,7 +386,9 @@ edm::seed_collection::buffer seed_finding::operator()(
vecmem::get_data(doublet_buffer_mb),
vecmem::get_data(doublet_buffer_mt),
vecmem::get_data(triplet_counter_spM_buffer),
vecmem::get_data(triplet_counter_midBot_buffer));
vecmem::get_data(triplet_counter_midBot_buffer),
vecmem::get_data(mid_bot_lin_circles),
vecmem::get_data(mid_top_lin_circles));

// Calculate the number of threads and thread blocks to run the triplet
// count reduction kernel for.
Expand Down Expand Up @@ -353,6 +430,8 @@ edm::seed_collection::buffer seed_finding::operator()(
vecmem::get_data(doublet_buffer_mt),
vecmem::get_data(triplet_counter_spM_buffer),
vecmem::get_data(triplet_counter_midBot_buffer),
vecmem::get_data(mid_bot_lin_circles),
vecmem::get_data(mid_top_lin_circles),
vecmem::get_data(triplet_buffer));

blocksPerGrid =
Expand Down
3 changes: 3 additions & 0 deletions device/common/include/traccc/edm/device/triplet_counter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ struct triplet_counter {
/// The position in which these triplets will be added
unsigned int posTriplets = 0;

/// Index of the bottom-middle doublet
unsigned int m_botMidIdx = 0;

}; // struct triplet_counter

/// Declare all triplet counter collection types
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ inline void count_triplets(
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
const device_doublet_collection_types::const_view& mid_top_doublet_view,
triplet_counter_spM_collection_types::view spM_tc,
triplet_counter_collection_types::view mb_tc);
triplet_counter_collection_types::view mb_tc,
vecmem::data::vector_view<const lin_circle> mid_bot_circles,
vecmem::data::vector_view<const lin_circle> mid_top_circles);

} // namespace traccc::device

Expand Down
2 changes: 2 additions & 0 deletions device/common/include/traccc/seeding/device/find_triplets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ inline void find_triplets(
const device_doublet_collection_types::const_view& mid_top_doublet_view,
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
const triplet_counter_collection_types::const_view& tc_view,
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
device_triplet_collection_types::view triplet_view);

} // namespace traccc::device
Expand Down
39 changes: 18 additions & 21 deletions device/common/include/traccc/seeding/device/impl/count_triplets.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ inline void count_triplets(
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
const device_doublet_collection_types::const_view& mid_top_doublet_view,
triplet_counter_spM_collection_types::view spM_tc_view,
triplet_counter_collection_types::view mb_tc_view) {
triplet_counter_collection_types::view mb_tc_view,
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
vecmem::data::vector_view<const lin_circle> mid_top_circle_view) {

// Create device copy of input parameters
const device_doublet_collection_types::const_device mid_bot_doublet_device(
Expand All @@ -43,6 +45,10 @@ inline void count_triplets(
const device_doublet_collection_types::const_device mid_top_doublet_device(
mid_top_doublet_view);
const doublet_counter_collection_types::const_device dc_device(dc_view);
const vecmem::device_vector<const lin_circle> mid_bot_circles(
mid_bot_circle_view);
const vecmem::device_vector<const lin_circle> mid_top_circles(
mid_top_circle_view);

// Create device copy of output parameterss
triplet_counter_collection_types::device mb_triplet_counter(mb_tc_view);
Expand All @@ -61,24 +67,16 @@ inline void count_triplets(
const edm::spacepoint_collection::const_device::const_proxy_type spM =
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
const sp_location spB_loc = mid_bot.sp2;
// bottom spacepoint
const edm::spacepoint_collection::const_device::const_proxy_type spB =
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);

// Apply the conformal transformation to middle-bot doublet
traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
details::spacepoint_type::bottom>(spM, spB);
const traccc::lin_circle lb = mid_bot_circles.at(globalIndex);

// Calculate some physical quantities required for triplet compatibility
// check
scalar iSinTheta2 = static_cast<scalar>(1.) + lb.cotTheta() * lb.cotTheta();
scalar scatteringInRegion2 = config.maxScatteringAngle2 * iSinTheta2;
scatteringInRegion2 *= config.sigmaScattering * config.sigmaScattering;

// These two quantities are used as output parameters in
// triplet_finding_helper::isCompatible but their values are irrelevant
scalar curvature, impact_parameter;

// find the reference (start) index of the mid-top doublet container
// item vector, where the doublets are recorded
const unsigned int mt_start_idx = doublet_counts.m_posMidTop;
Expand All @@ -87,18 +85,16 @@ inline void count_triplets(
// number of triplets per middle-bot doublet
unsigned int num_triplets_per_mb = 0;

// iterate over mid-top doublets
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
const traccc::sp_location spT_loc = mid_top_doublet_device[i].sp2;

const edm::spacepoint_collection::const_device::const_proxy_type spT =
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
const unsigned int num_mt = mt_end_idx - mt_start_idx;

// iterate over mid-top doublets
for (unsigned int ri = 0; ri < num_mt; ++ri) {
// Apply the conformal transformation to middle-top doublet
traccc::lin_circle lt = doublet_finding_helper::transform_coordinates<
details::spacepoint_type::top>(spM, spT);
const lin_circle& lt = mid_top_circles.at(ri + mt_start_idx);

// Check if mid-bot and mid-top doublets can form a triplet
// These two quantities are used as output parameters in
// triplet_finding_helper::isCompatible but their values are irrelevant
scalar curvature, impact_parameter;
if (triplet_finding_helper::isCompatible(
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
impact_parameter)) {
Expand All @@ -114,8 +110,9 @@ inline void count_triplets(
const unsigned int posTriplets =
nTriplets.fetch_add(num_triplets_per_mb);

mb_triplet_counter.push_back(
{spB_loc, counter_link, num_triplets_per_mb, posTriplets});
mb_triplet_counter.push_back({spB_loc, counter_link,
num_triplets_per_mb, posTriplets,
globalIndex});
}
}

Expand Down
26 changes: 12 additions & 14 deletions device/common/include/traccc/seeding/device/impl/find_triplets.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ inline void find_triplets(
const device_doublet_collection_types::const_view& mid_top_doublet_view,
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
const triplet_counter_collection_types::const_view& tc_view,
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
device_triplet_collection_types::view triplet_view) {

// Check if anything needs to be done.
Expand All @@ -44,6 +46,10 @@ inline void find_triplets(
const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view);
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
spM_tc_view);
const vecmem::device_vector<const lin_circle> mid_bot_circles(
mid_bot_circle_view);
const vecmem::device_vector<const lin_circle> mid_top_circles(
mid_top_circle_view);

// Get the current work item information
const triplet_counter mid_bot_counter = triplet_counts.at(globalIndex);
Expand All @@ -69,8 +75,8 @@ inline void find_triplets(
device_triplet_collection_types::device triplets(triplet_view);

// Apply the conformal transformation to middle-bot doublet
const traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
details::spacepoint_type::bottom>(spM, spB);
const traccc::lin_circle lb =
mid_bot_circles.at(mid_bot_counter.m_botMidIdx);

// Calculate some physical quantities required for triplet compatibility
// check
Expand All @@ -79,10 +85,6 @@ inline void find_triplets(
config.sigmaScattering *
config.sigmaScattering;

// These two quantities are used as output parameters in
// triplet_finding_helper::isCompatible but their values are irrelevant
scalar curvature, impact_parameter;

// find the reference (start) index of the mid-top doublet collection
// item vector, where the doublets are recorded
const unsigned int mt_start_idx = doublet_count.m_posMidTop;
Expand All @@ -95,19 +97,15 @@ inline void find_triplets(

// iterate over mid-top doublets
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
// Apply the conformal transformation to middle-top doublet
const traccc::lin_circle& lt = mid_top_circles.at(i);

const sp_location spT_loc = mid_top_doublet_device[i].sp2;
const unsigned int spT_idx =
sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx];
const edm::spacepoint_collection::const_device::const_proxy_type spT =
spacepoints.at(spT_idx);

// Apply the conformal transformation to middle-top doublet
const traccc::lin_circle lt =
doublet_finding_helper::transform_coordinates<
details::spacepoint_type::top>(spM, spT);

// Check if mid-bot and mid-top doublets can form a triplet
scalar curvature, impact_parameter;
if (triplet_finding_helper::isCompatible(
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
impact_parameter)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

namespace traccc::device {
/**
* @brief Kernel to create middle-bottom linearised circles.
*/
TRACCC_HOST_DEVICE
inline void make_mid_bot_lincircles(
global_index_t tid,
device::device_doublet_collection_types::const_view mb_doublet_view,
device::doublet_counter_collection_types::const_view doublet_count_view,
edm::spacepoint_collection::const_view spacepoint_view,
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
vecmem::data::vector_view<lin_circle> out_view) {

const device::device_doublet_collection_types::const_device doublets(
mb_doublet_view);
const device::doublet_counter_collection_types::const_device doublet_counts(
doublet_count_view);
const edm::spacepoint_collection::const_device spacepoints(spacepoint_view);
traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view);
vecmem::device_vector<lin_circle> out(out_view);

if (tid >= doublets.size()) {
return;
}

const device::device_doublet dub = doublets.at(tid);
const unsigned int counter_link = dub.counter_link;
const device::doublet_counter count = doublet_counts.at(counter_link);
const sp_location spM_loc = count.m_spM;
const edm::spacepoint_collection::const_device::const_proxy_type spM =
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
const sp_location spB_loc = dub.sp2;
const edm::spacepoint_collection::const_device::const_proxy_type spB =
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);

out.at(tid) = doublet_finding_helper::transform_coordinates<
traccc::details::spacepoint_type::bottom>(spM, spB);
}
} // namespace traccc::device
Loading
Loading