Skip to content

Commit 1d331cd

Browse files
committed
Improve performance of triplet finding
This commit improves the performance of our seed finding by computing the linear circles for the doublets only once rather than computing them multiple times. Also tweaks launch parameters to improve occupancy.
1 parent d6d6045 commit 1d331cd

File tree

10 files changed

+359
-57
lines changed

10 files changed

+359
-57
lines changed

device/alpaka/src/seeding/seed_finding.cpp

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "traccc/seeding/device/count_triplets.hpp"
2525
#include "traccc/seeding/device/find_doublets.hpp"
2626
#include "traccc/seeding/device/find_triplets.hpp"
27+
#include "traccc/seeding/device/make_mid_bot_lincircles.hpp"
28+
#include "traccc/seeding/device/make_mid_top_lincircles.hpp"
2729
#include "traccc/seeding/device/reduce_triplet_counts.hpp"
2830
#include "traccc/seeding/device/select_seeds.hpp"
2931
#include "traccc/seeding/device/update_triplet_weights.hpp"
@@ -71,6 +73,42 @@ struct FindDoublets {
7173
}
7274
};
7375

76+
// Kernel for running @c traccc::device::make_mid_bot_lincircles
77+
struct MakeMidBotLinCircles {
78+
template <typename TAcc>
79+
ALPAKA_FN_ACC void operator()(
80+
TAcc const& acc,
81+
device::device_doublet_collection_types::const_view mb_doublet_view,
82+
device::doublet_counter_collection_types::const_view doublet_count_view,
83+
edm::spacepoint_collection::const_view spacepoint_view,
84+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
85+
vecmem::data::vector_view<lin_circle> out_view) const {
86+
auto const globalThreadIdx =
87+
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
88+
device::make_mid_bot_lincircles(globalThreadIdx, mb_doublet_view,
89+
doublet_count_view, spacepoint_view,
90+
sp_grid_view, out_view);
91+
}
92+
};
93+
94+
// Kernel for running @c traccc::device::make_mid_top_lincircles
95+
struct MakeMidTopLinCircles {
96+
template <typename TAcc>
97+
ALPAKA_FN_ACC void operator()(
98+
TAcc const& acc,
99+
device::device_doublet_collection_types::const_view mt_doublet_view,
100+
device::doublet_counter_collection_types::const_view doublet_count_view,
101+
edm::spacepoint_collection::const_view spacepoint_view,
102+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
103+
vecmem::data::vector_view<lin_circle> out_view) const {
104+
auto const globalThreadIdx =
105+
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
106+
device::make_mid_top_lincircles(globalThreadIdx, mt_doublet_view,
107+
doublet_count_view, spacepoint_view,
108+
sp_grid_view, out_view);
109+
}
110+
};
111+
74112
// Kernel for running @c traccc::device::count_triplets
75113
struct CountTriplets {
76114
template <typename TAcc>
@@ -82,12 +120,15 @@ struct CountTriplets {
82120
device::device_doublet_collection_types::const_view mb_doublets,
83121
device::device_doublet_collection_types::const_view mt_doublets,
84122
device::triplet_counter_spM_collection_types::view spM_counter,
85-
device::triplet_counter_collection_types::view midBot_counter) const {
123+
device::triplet_counter_collection_types::view midBot_counter,
124+
vecmem::data::vector_view<lin_circle> mb_circles,
125+
vecmem::data::vector_view<lin_circle> mt_circles) const {
86126
auto const globalThreadIdx =
87127
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
88128
device::count_triplets(globalThreadIdx, config, spacepoints, sp_grid,
89129
doublet_counter, mb_doublets, mt_doublets,
90-
spM_counter, midBot_counter);
130+
spM_counter, midBot_counter, mb_circles,
131+
mt_circles);
91132
}
92133
};
93134

@@ -118,13 +159,16 @@ struct FindTriplets {
118159
device::device_doublet_collection_types::const_view mt_doublets,
119160
device::triplet_counter_spM_collection_types::const_view spM_tc,
120161
device::triplet_counter_collection_types::const_view midBot_tc,
162+
vecmem::data::vector_view<lin_circle> mb_circles,
163+
vecmem::data::vector_view<lin_circle> mt_circles,
121164
device::device_triplet_collection_types::view triplet_view) const {
122165

123166
auto const globalThreadIdx =
124167
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0u];
125168
device::find_triplets(globalThreadIdx, config, filter_config,
126169
spacepoints, sp_grid, doublet_counter,
127-
mt_doublets, spM_tc, midBot_tc, triplet_view);
170+
mt_doublets, spM_tc, midBot_tc, mb_circles,
171+
mt_circles, triplet_view);
128172
}
129173
};
130174

@@ -288,6 +332,37 @@ edm::seed_collection::buffer seed_finding::operator()(
288332
vecmem::get_data(doublet_buffer_mb),
289333
vecmem::get_data(doublet_buffer_mt));
290334

335+
vecmem::data::vector_buffer<lin_circle> mid_bot_lin_circles{
336+
pBufHost_counter->m_nMidBot, m_mr.main};
337+
m_copy.setup(mid_bot_lin_circles)->wait();
338+
vecmem::data::vector_buffer<lin_circle> mid_top_lin_circles{
339+
pBufHost_counter->m_nMidTop, m_mr.main};
340+
m_copy.setup(mid_top_lin_circles)->wait();
341+
342+
{
    // Pre-compute the linearised circles of the middle-bottom and the
    // middle-top doublets. The resulting buffers are handed to the
    // triplet counting and triplet finding kernels further below.
    const unsigned int n_threads = 128;

    // Round the block counts up so that every doublet gets a thread.
    // NOTE(review): if m_nMidBot or m_nMidTop can be zero the matching
    // block count is zero as well - confirm the back-end accepts that.
    const unsigned int n_mid_bot_blocks =
        (pBufHost_counter->m_nMidBot + n_threads - 1) / n_threads;
    const unsigned int n_mid_top_blocks =
        (pBufHost_counter->m_nMidTop + n_threads - 1) / n_threads;
    const auto mid_bot_workdiv =
        makeWorkDiv<Acc>(n_mid_bot_blocks, n_threads);
    const auto mid_top_workdiv =
        makeWorkDiv<Acc>(n_mid_top_blocks, n_threads);

    // Middle-bottom circles are built from the middle-bottom doublets.
    ::alpaka::exec<Acc>(
        queue, mid_bot_workdiv, kernels::MakeMidBotLinCircles{},
        vecmem::get_data(doublet_buffer_mb),
        vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view,
        vecmem::get_data(mid_bot_lin_circles));

    // Middle-top circles must be built from the middle-top doublets.
    // Passing the middle-bottom buffer here would fill the output with
    // circles of the wrong doublets and bound the kernel by the wrong
    // collection size.
    ::alpaka::exec<Acc>(
        queue, mid_top_workdiv, kernels::MakeMidTopLinCircles{},
        vecmem::get_data(doublet_buffer_mt),
        vecmem::get_data(doublet_counter_buffer), spacepoints_view, g2_view,
        vecmem::get_data(mid_top_lin_circles));
}
365+
291366
// Set up the triplet counter buffers
292367
device::triplet_counter_spM_collection_types::buffer
293368
triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main};
@@ -311,7 +386,9 @@ edm::seed_collection::buffer seed_finding::operator()(
311386
vecmem::get_data(doublet_buffer_mb),
312387
vecmem::get_data(doublet_buffer_mt),
313388
vecmem::get_data(triplet_counter_spM_buffer),
314-
vecmem::get_data(triplet_counter_midBot_buffer));
389+
vecmem::get_data(triplet_counter_midBot_buffer),
390+
vecmem::get_data(mid_bot_lin_circles),
391+
vecmem::get_data(mid_top_lin_circles));
315392

316393
// Calculate the number of threads and thread blocks to run the triplet
317394
// count reduction kernel for.
@@ -353,6 +430,8 @@ edm::seed_collection::buffer seed_finding::operator()(
353430
vecmem::get_data(doublet_buffer_mt),
354431
vecmem::get_data(triplet_counter_spM_buffer),
355432
vecmem::get_data(triplet_counter_midBot_buffer),
433+
vecmem::get_data(mid_bot_lin_circles),
434+
vecmem::get_data(mid_top_lin_circles),
356435
vecmem::get_data(triplet_buffer));
357436

358437
blocksPerGrid =

device/common/include/traccc/edm/device/triplet_counter.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct triplet_counter {
4747
/// The position in which these triplets will be added
4848
unsigned int posTriplets = 0;
4949

50+
/// Index of the bottom-middle doublet
51+
unsigned int m_botMidIdx = 0;
52+
5053
}; // struct triplet_counter
5154

5255
/// Declare all triplet counter collection types

device/common/include/traccc/seeding/device/count_triplets.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ inline void count_triplets(
4848
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
4949
const device_doublet_collection_types::const_view& mid_top_doublet_view,
5050
triplet_counter_spM_collection_types::view spM_tc,
51-
triplet_counter_collection_types::view mb_tc);
51+
triplet_counter_collection_types::view mb_tc,
52+
vecmem::data::vector_view<const lin_circle> mid_bot_circles,
53+
vecmem::data::vector_view<const lin_circle> mid_top_circles);
5254

5355
} // namespace traccc::device
5456

device/common/include/traccc/seeding/device/find_triplets.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ inline void find_triplets(
4949
const device_doublet_collection_types::const_view& mid_top_doublet_view,
5050
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
5151
const triplet_counter_collection_types::const_view& tc_view,
52+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
53+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
5254
device_triplet_collection_types::view triplet_view);
5355

5456
} // namespace traccc::device

device/common/include/traccc/seeding/device/impl/count_triplets.ipp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ inline void count_triplets(
2424
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
2525
const device_doublet_collection_types::const_view& mid_top_doublet_view,
2626
triplet_counter_spM_collection_types::view spM_tc_view,
27-
triplet_counter_collection_types::view mb_tc_view) {
27+
triplet_counter_collection_types::view mb_tc_view,
28+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
29+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view) {
2830

2931
// Create device copy of input parameters
3032
const device_doublet_collection_types::const_device mid_bot_doublet_device(
@@ -43,6 +45,10 @@ inline void count_triplets(
4345
const device_doublet_collection_types::const_device mid_top_doublet_device(
4446
mid_top_doublet_view);
4547
const doublet_counter_collection_types::const_device dc_device(dc_view);
48+
const vecmem::device_vector<const lin_circle> mid_bot_circles(
49+
mid_bot_circle_view);
50+
const vecmem::device_vector<const lin_circle> mid_top_circles(
51+
mid_top_circle_view);
4652

4753
// Create device copy of output parameters
4854
triplet_counter_collection_types::device mb_triplet_counter(mb_tc_view);
@@ -61,24 +67,16 @@ inline void count_triplets(
6167
const edm::spacepoint_collection::const_device::const_proxy_type spM =
6268
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
6369
const sp_location spB_loc = mid_bot.sp2;
64-
// bottom spacepoint
65-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
66-
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
6770

6871
// Apply the conformal transformation to middle-bot doublet
69-
traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
70-
details::spacepoint_type::bottom>(spM, spB);
72+
const traccc::lin_circle lb = mid_bot_circles.at(globalIndex);
7173

7274
// Calculate some physical quantities required for triplet compatibility
7375
// check
7476
scalar iSinTheta2 = static_cast<scalar>(1.) + lb.cotTheta() * lb.cotTheta();
7577
scalar scatteringInRegion2 = config.maxScatteringAngle2 * iSinTheta2;
7678
scatteringInRegion2 *= config.sigmaScattering * config.sigmaScattering;
7779

78-
// These two quantities are used as output parameters in
79-
// triplet_finding_helper::isCompatible but their values are irrelevant
80-
scalar curvature, impact_parameter;
81-
8280
// find the reference (start) index of the mid-top doublet container
8381
// item vector, where the doublets are recorded
8482
const unsigned int mt_start_idx = doublet_counts.m_posMidTop;
@@ -87,18 +85,16 @@ inline void count_triplets(
8785
// number of triplets per middle-bot doublet
8886
unsigned int num_triplets_per_mb = 0;
8987

90-
// iterate over mid-top doublets
91-
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
92-
const traccc::sp_location spT_loc = mid_top_doublet_device[i].sp2;
93-
94-
const edm::spacepoint_collection::const_device::const_proxy_type spT =
95-
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
88+
const unsigned int num_mt = mt_end_idx - mt_start_idx;
9689

90+
// iterate over mid-top doublets
91+
for (unsigned int ri = 0; ri < num_mt; ++ri) {
9792
// Apply the conformal transformation to middle-top doublet
98-
traccc::lin_circle lt = doublet_finding_helper::transform_coordinates<
99-
details::spacepoint_type::top>(spM, spT);
93+
const lin_circle& lt = mid_top_circles.at(ri + mt_start_idx);
10094

101-
// Check if mid-bot and mid-top doublets can form a triplet
95+
// These two quantities are used as output parameters in
96+
// triplet_finding_helper::isCompatible but their values are irrelevant
97+
scalar curvature, impact_parameter;
10298
if (triplet_finding_helper::isCompatible(
10399
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
104100
impact_parameter)) {
@@ -114,8 +110,9 @@ inline void count_triplets(
114110
const unsigned int posTriplets =
115111
nTriplets.fetch_add(num_triplets_per_mb);
116112

117-
mb_triplet_counter.push_back(
118-
{spB_loc, counter_link, num_triplets_per_mb, posTriplets});
113+
mb_triplet_counter.push_back({spB_loc, counter_link,
114+
num_triplets_per_mb, posTriplets,
115+
globalIndex});
119116
}
120117
}
121118

device/common/include/traccc/seeding/device/impl/find_triplets.ipp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ inline void find_triplets(
2525
const device_doublet_collection_types::const_view& mid_top_doublet_view,
2626
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
2727
const triplet_counter_collection_types::const_view& tc_view,
28+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
29+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
2830
device_triplet_collection_types::view triplet_view) {
2931

3032
// Check if anything needs to be done.
@@ -44,6 +46,10 @@ inline void find_triplets(
4446
const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view);
4547
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
4648
spM_tc_view);
49+
const vecmem::device_vector<const lin_circle> mid_bot_circles(
50+
mid_bot_circle_view);
51+
const vecmem::device_vector<const lin_circle> mid_top_circles(
52+
mid_top_circle_view);
4753

4854
// Get the current work item information
4955
const triplet_counter mid_bot_counter = triplet_counts.at(globalIndex);
@@ -69,8 +75,8 @@ inline void find_triplets(
6975
device_triplet_collection_types::device triplets(triplet_view);
7076

7177
// Apply the conformal transformation to middle-bot doublet
72-
const traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
73-
details::spacepoint_type::bottom>(spM, spB);
78+
const traccc::lin_circle lb =
79+
mid_bot_circles.at(mid_bot_counter.m_botMidIdx);
7480

7581
// Calculate some physical quantities required for triplet compatibility
7682
// check
@@ -79,10 +85,6 @@ inline void find_triplets(
7985
config.sigmaScattering *
8086
config.sigmaScattering;
8187

82-
// These two quantities are used as output parameters in
83-
// triplet_finding_helper::isCompatible but their values are irrelevant
84-
scalar curvature, impact_parameter;
85-
8688
// find the reference (start) index of the mid-top doublet collection
8789
// item vector, where the doublets are recorded
8890
const unsigned int mt_start_idx = doublet_count.m_posMidTop;
@@ -95,19 +97,15 @@ inline void find_triplets(
9597

9698
// iterate over mid-top doublets
9799
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
98-
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
100+
// Apply the conformal transformation to middle-top doublet
101+
const traccc::lin_circle& lt = mid_top_circles.at(i);
99102

103+
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
100104
const unsigned int spT_idx =
101105
sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx];
102-
const edm::spacepoint_collection::const_device::const_proxy_type spT =
103-
spacepoints.at(spT_idx);
104-
105-
// Apply the conformal transformation to middle-top doublet
106-
const traccc::lin_circle lt =
107-
doublet_finding_helper::transform_coordinates<
108-
details::spacepoint_type::top>(spM, spT);
109106

110107
// Check if mid-bot and mid-top doublets can form a triplet
108+
scalar curvature, impact_parameter;
111109
if (triplet_finding_helper::isCompatible(
112110
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
113111
impact_parameter)) {
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::device {
11+
/**
12+
* @brief Kernel to create middle-bottom linearised circles.
13+
*/
14+
TRACCC_HOST_DEVICE
15+
inline void make_mid_bot_lincircles(
16+
global_index_t tid,
17+
device::device_doublet_collection_types::const_view mb_doublet_view,
18+
device::doublet_counter_collection_types::const_view doublet_count_view,
19+
edm::spacepoint_collection::const_view spacepoint_view,
20+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
21+
vecmem::data::vector_view<lin_circle> out_view) {
22+
23+
const device::device_doublet_collection_types::const_device doublets(
24+
mb_doublet_view);
25+
const device::doublet_counter_collection_types::const_device doublet_counts(
26+
doublet_count_view);
27+
const edm::spacepoint_collection::const_device spacepoints(spacepoint_view);
28+
traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view);
29+
vecmem::device_vector<lin_circle> out(out_view);
30+
31+
if (tid >= doublets.size()) {
32+
return;
33+
}
34+
35+
const device::device_doublet dub = doublets.at(tid);
36+
const unsigned int counter_link = dub.counter_link;
37+
const device::doublet_counter count = doublet_counts.at(counter_link);
38+
const sp_location spM_loc = count.m_spM;
39+
const edm::spacepoint_collection::const_device::const_proxy_type spM =
40+
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
41+
const sp_location spB_loc = dub.sp2;
42+
const edm::spacepoint_collection::const_device::const_proxy_type spB =
43+
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
44+
45+
out.at(tid) = doublet_finding_helper::transform_coordinates<
46+
traccc::details::spacepoint_type::bottom>(spM, spB);
47+
}
48+
} // namespace traccc::device

0 commit comments

Comments
 (0)