Skip to content

Commit e7fcc07

Browse files
committed
Improve performance of triplet finding
This commit improves the performance of our seed finding by computing the linear circles for the doublets only once rather than computing them multiple times. Also tweaks launch parameters to improve occupancy.
1 parent bc2115f commit e7fcc07

File tree

10 files changed

+358
-63
lines changed

10 files changed

+358
-63
lines changed

device/alpaka/src/seeding/seed_finding.cpp

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "traccc/seeding/device/count_triplets.hpp"
2525
#include "traccc/seeding/device/find_doublets.hpp"
2626
#include "traccc/seeding/device/find_triplets.hpp"
27+
#include "traccc/seeding/device/make_mid_bot_lincircles.hpp"
28+
#include "traccc/seeding/device/make_mid_top_lincircles.hpp"
2729
#include "traccc/seeding/device/reduce_triplet_counts.hpp"
2830
#include "traccc/seeding/device/select_seeds.hpp"
2931
#include "traccc/seeding/device/update_triplet_weights.hpp"
@@ -71,6 +73,42 @@ struct FindDoublets {
7173
}
7274
};
7375

76+
// Kernel for running @c traccc::device::make_mid_bot_lincircles
77+
struct MakeMidBotLinCircles {
78+
template <typename TAcc>
79+
ALPAKA_FN_ACC void operator()(
80+
TAcc const& acc,
81+
device::device_doublet_collection_types::const_view mb_doublet_view,
82+
device::doublet_counter_collection_types::const_view doublet_count_view,
83+
edm::spacepoint_collection::const_view spacepoint_view,
84+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
85+
vecmem::data::vector_view<lin_circle> out_view) const {
86+
auto const globalThreadIdx =
87+
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
88+
device::make_mid_bot_lincircles(globalThreadIdx, mb_doublet_view,
89+
doublet_count_view, spacepoint_view,
90+
sp_grid_view, out_view);
91+
}
92+
};
93+
94+
// Kernel for running @c traccc::device::make_mid_top_lincircles
95+
struct MakeMidTopLinCircles {
96+
template <typename TAcc>
97+
ALPAKA_FN_ACC void operator()(
98+
TAcc const& acc,
99+
device::device_doublet_collection_types::const_view mt_doublet_view,
100+
device::doublet_counter_collection_types::const_view doublet_count_view,
101+
edm::spacepoint_collection::const_view spacepoint_view,
102+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
103+
vecmem::data::vector_view<lin_circle> out_view) const {
104+
auto const globalThreadIdx =
105+
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
106+
device::make_mid_top_lincircles(globalThreadIdx, mt_doublet_view,
107+
doublet_count_view, spacepoint_view,
108+
sp_grid_view, out_view);
109+
}
110+
};
111+
74112
// Kernel for running @c traccc::device::count_triplets
75113
struct CountTriplets {
76114
template <typename TAcc>
@@ -82,12 +120,15 @@ struct CountTriplets {
82120
device::device_doublet_collection_types::const_view mb_doublets,
83121
device::device_doublet_collection_types::const_view mt_doublets,
84122
device::triplet_counter_spM_collection_types::view spM_counter,
85-
device::triplet_counter_collection_types::view midBot_counter) const {
123+
device::triplet_counter_collection_types::view midBot_counter,
124+
vecmem::data::vector_view<lin_circle> mb_circles,
125+
vecmem::data::vector_view<lin_circle> mt_circles) const {
86126
auto const globalThreadIdx =
87127
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
88128
device::count_triplets(globalThreadIdx, config, spacepoints, sp_grid,
89129
doublet_counter, mb_doublets, mt_doublets,
90-
spM_counter, midBot_counter);
130+
spM_counter, midBot_counter, mb_circles,
131+
mt_circles);
91132
}
92133
};
93134

@@ -118,13 +159,16 @@ struct FindTriplets {
118159
device::device_doublet_collection_types::const_view mt_doublets,
119160
device::triplet_counter_spM_collection_types::const_view spM_tc,
120161
device::triplet_counter_collection_types::const_view midBot_tc,
162+
vecmem::data::vector_view<lin_circle> mb_circles,
163+
vecmem::data::vector_view<lin_circle> mt_circles,
121164
device::device_triplet_collection_types::view triplet_view) const {
122165

123166
auto const globalThreadIdx =
124167
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
125168
device::find_triplets(globalThreadIdx, config, filter_config,
126169
spacepoints, sp_grid, doublet_counter,
127-
mt_doublets, spM_tc, midBot_tc, triplet_view);
170+
mt_doublets, spM_tc, midBot_tc, mb_circles,
171+
mt_circles, triplet_view);
128172
}
129173
};
130174

@@ -287,6 +331,37 @@ edm::seed_collection::buffer seed_finding::operator()(
287331
vecmem::get_data(doublet_buffer_mb),
288332
vecmem::get_data(doublet_buffer_mt));
289333

334+
vecmem::data::vector_buffer<lin_circle> mid_bot_lin_circles{
335+
pBufHost_counter->m_nMidBot, m_mr.main};
336+
m_copy.setup(mid_bot_lin_circles)->wait();
337+
vecmem::data::vector_buffer<lin_circle> mid_top_lin_circles{
338+
pBufHost_counter->m_nMidTop, m_mr.main};
339+
m_copy.setup(mid_top_lin_circles)->wait();
340+
341+
{
342+
const unsigned int n_threads = 128;
343+
const unsigned int n_mid_bot_blocks =
344+
(pBufHost_counter->m_nMidBot + n_threads - 1) / n_threads;
345+
const unsigned int n_mid_top_blocks =
346+
(pBufHost_counter->m_nMidTop + n_threads - 1) / n_threads;
347+
const auto mid_bot_workdiv =
348+
makeWorkDiv<Acc>(n_mid_bot_blocks, n_threads);
349+
const auto mid_top_workdiv =
350+
makeWorkDiv<Acc>(n_mid_top_blocks, n_threads);
351+
352+
::alpaka::exec<Acc>(
353+
queue, mid_bot_workdiv, kernels::MakeMidBotLinCircles{},
354+
vecmem::get_data(doublet_buffer_mb),
355+
vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view,
356+
vecmem::get_data(mid_bot_lin_circles));
357+
358+
::alpaka::exec<Acc>(
359+
queue, mid_top_workdiv, kernels::MakeMidTopLinCircles{},
360+
vecmem::get_data(doublet_buffer_mb),
361+
vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view,
362+
vecmem::get_data(mid_top_lin_circles));
363+
}
364+
290365
// Set up the triplet counter buffers
291366
device::triplet_counter_spM_collection_types::buffer
292367
triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main};
@@ -310,7 +385,9 @@ edm::seed_collection::buffer seed_finding::operator()(
310385
vecmem::get_data(doublet_buffer_mb),
311386
vecmem::get_data(doublet_buffer_mt),
312387
vecmem::get_data(triplet_counter_spM_buffer),
313-
vecmem::get_data(triplet_counter_midBot_buffer));
388+
vecmem::get_data(triplet_counter_midBot_buffer),
389+
vecmem::get_data(mid_bot_lin_circles),
390+
vecmem::get_data(mid_top_lin_circles));
314391

315392
// Calculate the number of threads and thread blocks to run the triplet
316393
// count reduction kernel for.
@@ -352,6 +429,8 @@ edm::seed_collection::buffer seed_finding::operator()(
352429
vecmem::get_data(doublet_buffer_mt),
353430
vecmem::get_data(triplet_counter_spM_buffer),
354431
vecmem::get_data(triplet_counter_midBot_buffer),
432+
vecmem::get_data(mid_bot_lin_circles),
433+
vecmem::get_data(mid_top_lin_circles),
355434
vecmem::get_data(triplet_buffer));
356435

357436
blocksPerGrid =

device/common/include/traccc/edm/device/triplet_counter.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct triplet_counter {
4747
/// The position in which these triplets will be added
4848
unsigned int posTriplets = 0;
4949

50+
/// Index of the bottom-middle doublet
51+
unsigned int m_botMidIdx = 0;
52+
5053
}; // struct triplet_counter
5154

5255
/// Declare all triplet counter collection types

device/common/include/traccc/seeding/device/count_triplets.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ inline void count_triplets(
4848
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
4949
const device_doublet_collection_types::const_view& mid_top_doublet_view,
5050
triplet_counter_spM_collection_types::view spM_tc,
51-
triplet_counter_collection_types::view mb_tc);
51+
triplet_counter_collection_types::view mb_tc,
52+
vecmem::data::vector_view<const lin_circle> mid_bot_circles,
53+
vecmem::data::vector_view<const lin_circle> mid_top_circles);
5254

5355
} // namespace traccc::device
5456

device/common/include/traccc/seeding/device/find_triplets.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ inline void find_triplets(
4949
const device_doublet_collection_types::const_view& mid_top_doublet_view,
5050
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
5151
const triplet_counter_collection_types::const_view& tc_view,
52+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
53+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
5254
device_triplet_collection_types::view triplet_view);
5355

5456
} // namespace traccc::device

device/common/include/traccc/seeding/device/impl/count_triplets.ipp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ inline void count_triplets(
2424
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
2525
const device_doublet_collection_types::const_view& mid_top_doublet_view,
2626
triplet_counter_spM_collection_types::view spM_tc_view,
27-
triplet_counter_collection_types::view mb_tc_view) {
27+
triplet_counter_collection_types::view mb_tc_view,
28+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
29+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view) {
2830

2931
// Create device copy of input parameters
3032
const device_doublet_collection_types::const_device mid_bot_doublet_device(
@@ -43,6 +45,10 @@ inline void count_triplets(
4345
const device_doublet_collection_types::const_device mid_top_doublet_device(
4446
mid_top_doublet_view);
4547
const doublet_counter_collection_types::const_device dc_device(dc_view);
48+
const vecmem::device_vector<const lin_circle> mid_bot_circles(
49+
mid_bot_circle_view);
50+
const vecmem::device_vector<const lin_circle> mid_top_circles(
51+
mid_top_circle_view);
4652

4753
// Create device copy of output parameterss
4854
triplet_counter_collection_types::device mb_triplet_counter(mb_tc_view);
@@ -61,24 +67,16 @@ inline void count_triplets(
6167
const edm::spacepoint_collection::const_device::const_proxy_type spM =
6268
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
6369
const sp_location spB_loc = mid_bot.sp2;
64-
// bottom spacepoint
65-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
66-
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
6770

6871
// Apply the conformal transformation to middle-bot doublet
69-
traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
70-
details::spacepoint_type::bottom>(spM, spB);
72+
const traccc::lin_circle lb = mid_bot_circles.at(globalIndex);
7173

7274
// Calculate some physical quantities required for triplet compatibility
7375
// check
7476
scalar iSinTheta2 = static_cast<scalar>(1.) + lb.cotTheta() * lb.cotTheta();
7577
scalar scatteringInRegion2 = config.maxScatteringAngle2 * iSinTheta2;
7678
scatteringInRegion2 *= config.sigmaScattering * config.sigmaScattering;
7779

78-
// These two quantities are used as output parameters in
79-
// triplet_finding_helper::isCompatible but their values are irrelevant
80-
scalar curvature, impact_parameter;
81-
8280
// find the reference (start) index of the mid-top doublet container
8381
// item vector, where the doublets are recorded
8482
const unsigned int mt_start_idx = doublet_counts.m_posMidTop;
@@ -87,18 +85,16 @@ inline void count_triplets(
8785
// number of triplets per middle-bot doublet
8886
unsigned int num_triplets_per_mb = 0;
8987

90-
// iterate over mid-top doublets
91-
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
92-
const traccc::sp_location spT_loc = mid_top_doublet_device[i].sp2;
93-
94-
const edm::spacepoint_collection::const_device::const_proxy_type spT =
95-
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
88+
const unsigned int num_mt = mt_end_idx - mt_start_idx;
9689

90+
// iterate over mid-top doublets
91+
for (unsigned int ri = 0; ri < num_mt; ++ri) {
9792
// Apply the conformal transformation to middle-top doublet
98-
traccc::lin_circle lt = doublet_finding_helper::transform_coordinates<
99-
details::spacepoint_type::top>(spM, spT);
93+
const lin_circle& lt = mid_top_circles.at(ri + mt_start_idx);
10094

101-
// Check if mid-bot and mid-top doublets can form a triplet
95+
// These two quantities are used as output parameters in
96+
// triplet_finding_helper::isCompatible but their values are irrelevant
97+
scalar curvature, impact_parameter;
10298
if (triplet_finding_helper::isCompatible(
10399
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
104100
impact_parameter)) {
@@ -114,8 +110,9 @@ inline void count_triplets(
114110
const unsigned int posTriplets =
115111
nTriplets.fetch_add(num_triplets_per_mb);
116112

117-
mb_triplet_counter.push_back(
118-
{spB_loc, counter_link, num_triplets_per_mb, posTriplets});
113+
mb_triplet_counter.push_back({spB_loc, counter_link,
114+
num_triplets_per_mb, posTriplets,
115+
globalIndex});
119116
}
120117
}
121118

device/common/include/traccc/seeding/device/impl/find_triplets.ipp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ inline void find_triplets(
2525
const device_doublet_collection_types::const_view& mid_top_doublet_view,
2626
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
2727
const triplet_counter_collection_types::const_view& tc_view,
28+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
29+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
2830
device_triplet_collection_types::view triplet_view) {
2931

3032
// Check if anything needs to be done.
@@ -44,6 +46,10 @@ inline void find_triplets(
4446
const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view);
4547
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
4648
spM_tc_view);
49+
const vecmem::device_vector<const lin_circle> mid_bot_circles(
50+
mid_bot_circle_view);
51+
const vecmem::device_vector<const lin_circle> mid_top_circles(
52+
mid_top_circle_view);
4753

4854
// Get the current work item information
4955
const triplet_counter mid_bot_counter = triplet_counts.at(globalIndex);
@@ -53,22 +59,17 @@ inline void find_triplets(
5359
doublet_counts.at(mid_bot_counter.spM_counter_link);
5460

5561
const sp_location spM_loc = spM_counter.spM;
56-
const sp_location spB_loc = mid_bot_counter.spB;
5762

5863
// middle spacepoint
5964
const edm::spacepoint_collection::const_device::const_proxy_type spM =
6065
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
6166

62-
// bottom spacepoint
63-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
64-
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
65-
6667
// Set up the device result collection
6768
device_triplet_collection_types::device triplets(triplet_view);
6869

6970
// Apply the conformal transformation to middle-bot doublet
70-
const traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
71-
details::spacepoint_type::bottom>(spM, spB);
71+
const traccc::lin_circle lb =
72+
mid_bot_circles.at(mid_bot_counter.m_botMidIdx);
7273

7374
// Calculate some physical quantities required for triplet compatibility
7475
// check
@@ -77,10 +78,6 @@ inline void find_triplets(
7778
config.sigmaScattering *
7879
config.sigmaScattering;
7980

80-
// These two quantities are used as output parameters in
81-
// triplet_finding_helper::isCompatible but their values are irrelevant
82-
scalar curvature, impact_parameter;
83-
8481
// find the reference (start) index of the mid-top doublet collection
8582
// item vector, where the doublets are recorded
8683
const unsigned int mt_start_idx = doublet_count.m_posMidTop;
@@ -93,24 +90,18 @@ inline void find_triplets(
9390

9491
// iterate over mid-top doublets
9592
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
96-
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
97-
98-
const edm::spacepoint_collection::const_device::const_proxy_type spT =
99-
spacepoints.at(sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
100-
10193
// Apply the conformal transformation to middle-top doublet
102-
const traccc::lin_circle lt =
103-
doublet_finding_helper::transform_coordinates<
104-
details::spacepoint_type::top>(spM, spT);
94+
const traccc::lin_circle& lt = mid_top_circles.at(i);
10595

10696
// Check if mid-bot and mid-top doublets can form a triplet
97+
scalar curvature, impact_parameter;
10798
if (triplet_finding_helper::isCompatible(
10899
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
109100
impact_parameter)) {
110101

111102
// Add triplet to jagged vector
112103
triplets.at(posTriplets++) = device_triplet(
113-
{spT_loc, globalIndex, curvature,
104+
{mid_top_doublet_device[i].sp2, globalIndex, curvature,
114105
-impact_parameter * filter_config.impactWeightFactor,
115106
lb.Zo()});
116107
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::device {
11+
/**
12+
* @brief Kernel to create middle-bottom linearised circles.
13+
*/
14+
TRACCC_HOST_DEVICE
15+
inline void make_mid_bot_lincircles(
16+
global_index_t tid,
17+
device::device_doublet_collection_types::const_view mb_doublet_view,
18+
device::doublet_counter_collection_types::const_view doublet_count_view,
19+
edm::spacepoint_collection::const_view spacepoint_view,
20+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
21+
vecmem::data::vector_view<lin_circle> out_view) {
22+
23+
const device::device_doublet_collection_types::const_device doublets(
24+
mb_doublet_view);
25+
const device::doublet_counter_collection_types::const_device doublet_counts(
26+
doublet_count_view);
27+
const edm::spacepoint_collection::const_device spacepoints(spacepoint_view);
28+
traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view);
29+
vecmem::device_vector<lin_circle> out(out_view);
30+
31+
if (tid >= doublets.size()) {
32+
return;
33+
}
34+
35+
const device::device_doublet dub = doublets.at(tid);
36+
const unsigned int counter_link = dub.counter_link;
37+
const device::doublet_counter count = doublet_counts.at(counter_link);
38+
const sp_location spM_loc = count.m_spM;
39+
const edm::spacepoint_collection::const_device::const_proxy_type spM =
40+
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
41+
const sp_location spB_loc = dub.sp2;
42+
const edm::spacepoint_collection::const_device::const_proxy_type spB =
43+
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
44+
45+
out.at(tid) = doublet_finding_helper::transform_coordinates<
46+
traccc::details::spacepoint_type::bottom>(spM, spB);
47+
}
48+
} // namespace traccc::device

0 commit comments

Comments
 (0)