Skip to content

Commit ef0e96f

Browse files
committed
Cache linearised circles in seed finding
This commit improves the performance of our seed finding by computing the linear circles for the doublets only once rather than computing them multiple times.
1 parent bc2115f commit ef0e96f

File tree

6 files changed

+150
-51
lines changed

6 files changed

+150
-51
lines changed

device/common/include/traccc/edm/device/triplet_counter.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct triplet_counter {
4747
/// The position in which these triplets will be added
4848
unsigned int posTriplets = 0;
4949

50+
/// Index of the bottom-middle doublet
51+
unsigned int m_botMidIdx = 0;
52+
5053
}; // struct triplet_counter
5154

5255
/// Declare all triplet counter collection types

device/common/include/traccc/seeding/device/count_triplets.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ inline void count_triplets(
4848
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
4949
const device_doublet_collection_types::const_view& mid_top_doublet_view,
5050
triplet_counter_spM_collection_types::view spM_tc,
51-
triplet_counter_collection_types::view mb_tc);
51+
triplet_counter_collection_types::view mb_tc,
52+
vecmem::data::vector_view<const lin_circle> mid_bot_circles,
53+
vecmem::data::vector_view<const lin_circle> mid_top_circles);
5254

5355
} // namespace traccc::device
5456

device/common/include/traccc/seeding/device/find_triplets.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ inline void find_triplets(
4949
const device_doublet_collection_types::const_view& mid_top_doublet_view,
5050
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
5151
const triplet_counter_collection_types::const_view& tc_view,
52+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
53+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
5254
device_triplet_collection_types::view triplet_view);
5355

5456
} // namespace traccc::device

device/common/include/traccc/seeding/device/impl/count_triplets.ipp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ inline void count_triplets(
2424
const device_doublet_collection_types::const_view& mid_bot_doublet_view,
2525
const device_doublet_collection_types::const_view& mid_top_doublet_view,
2626
triplet_counter_spM_collection_types::view spM_tc_view,
27-
triplet_counter_collection_types::view mb_tc_view) {
27+
triplet_counter_collection_types::view mb_tc_view,
28+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
29+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view) {
2830

2931
// Create device copy of input parameters
3032
const device_doublet_collection_types::const_device mid_bot_doublet_device(
@@ -43,6 +45,10 @@ inline void count_triplets(
4345
const device_doublet_collection_types::const_device mid_top_doublet_device(
4446
mid_top_doublet_view);
4547
const doublet_counter_collection_types::const_device dc_device(dc_view);
48+
const vecmem::device_vector<const lin_circle> mid_bot_circles(
49+
mid_bot_circle_view);
50+
const vecmem::device_vector<const lin_circle> mid_top_circles(
51+
mid_top_circle_view);
4652

4753
// Create device copy of output parameterss
4854
triplet_counter_collection_types::device mb_triplet_counter(mb_tc_view);
@@ -61,24 +67,16 @@ inline void count_triplets(
6167
const edm::spacepoint_collection::const_device::const_proxy_type spM =
6268
spacepoints.at(sp_device.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
6369
const sp_location spB_loc = mid_bot.sp2;
64-
// bottom spacepoint
65-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
66-
spacepoints.at(sp_device.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
6770

6871
// Apply the conformal transformation to middle-bot doublet
69-
traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
70-
details::spacepoint_type::bottom>(spM, spB);
72+
const traccc::lin_circle lb = mid_bot_circles.at(globalIndex);
7173

7274
// Calculate some physical quantities required for triplet compatibility
7375
// check
7476
scalar iSinTheta2 = static_cast<scalar>(1.) + lb.cotTheta() * lb.cotTheta();
7577
scalar scatteringInRegion2 = config.maxScatteringAngle2 * iSinTheta2;
7678
scatteringInRegion2 *= config.sigmaScattering * config.sigmaScattering;
7779

78-
// These two quantities are used as output parameters in
79-
// triplet_finding_helper::isCompatible but their values are irrelevant
80-
scalar curvature, impact_parameter;
81-
8280
// find the reference (start) index of the mid-top doublet container
8381
// item vector, where the doublets are recorded
8482
const unsigned int mt_start_idx = doublet_counts.m_posMidTop;
@@ -87,18 +85,16 @@ inline void count_triplets(
8785
// number of triplets per middle-bot doublet
8886
unsigned int num_triplets_per_mb = 0;
8987

90-
// iterate over mid-top doublets
91-
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
92-
const traccc::sp_location spT_loc = mid_top_doublet_device[i].sp2;
93-
94-
const edm::spacepoint_collection::const_device::const_proxy_type spT =
95-
spacepoints.at(sp_device.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
88+
const unsigned int num_mt = mt_end_idx - mt_start_idx;
9689

90+
// iterate over mid-top doublets
91+
for (unsigned int ri = 0; ri < num_mt; ++ri) {
9792
// Apply the conformal transformation to middle-top doublet
98-
traccc::lin_circle lt = doublet_finding_helper::transform_coordinates<
99-
details::spacepoint_type::top>(spM, spT);
93+
const lin_circle& lt = mid_top_circles.at(ri + mt_start_idx);
10094

101-
// Check if mid-bot and mid-top doublets can form a triplet
95+
// These two quantities are used as output parameters in
96+
// triplet_finding_helper::isCompatible but their values are irrelevant
97+
scalar curvature, impact_parameter;
10298
if (triplet_finding_helper::isCompatible(
10399
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
104100
impact_parameter)) {
@@ -114,8 +110,9 @@ inline void count_triplets(
114110
const unsigned int posTriplets =
115111
nTriplets.fetch_add(num_triplets_per_mb);
116112

117-
mb_triplet_counter.push_back(
118-
{spB_loc, counter_link, num_triplets_per_mb, posTriplets});
113+
mb_triplet_counter.push_back({spB_loc, counter_link,
114+
num_triplets_per_mb, posTriplets,
115+
globalIndex});
119116
}
120117
}
121118

device/common/include/traccc/seeding/device/impl/find_triplets.ipp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ inline void find_triplets(
2525
const device_doublet_collection_types::const_view& mid_top_doublet_view,
2626
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
2727
const triplet_counter_collection_types::const_view& tc_view,
28+
vecmem::data::vector_view<const lin_circle> mid_bot_circle_view,
29+
vecmem::data::vector_view<const lin_circle> mid_top_circle_view,
2830
device_triplet_collection_types::view triplet_view) {
2931

3032
// Check if anything needs to be done.
@@ -44,6 +46,10 @@ inline void find_triplets(
4446
const traccc::details::spacepoint_grid_types::const_device sp_grid(sp_view);
4547
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
4648
spM_tc_view);
49+
const vecmem::device_vector<const lin_circle> mid_bot_circles(
50+
mid_bot_circle_view);
51+
const vecmem::device_vector<const lin_circle> mid_top_circles(
52+
mid_top_circle_view);
4753

4854
// Get the current work item information
4955
const triplet_counter mid_bot_counter = triplet_counts.at(globalIndex);
@@ -53,22 +59,17 @@ inline void find_triplets(
5359
doublet_counts.at(mid_bot_counter.spM_counter_link);
5460

5561
const sp_location spM_loc = spM_counter.spM;
56-
const sp_location spB_loc = mid_bot_counter.spB;
5762

5863
// middle spacepoint
5964
const edm::spacepoint_collection::const_device::const_proxy_type spM =
6065
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
6166

62-
// bottom spacepoint
63-
const edm::spacepoint_collection::const_device::const_proxy_type spB =
64-
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
65-
6667
// Set up the device result collection
6768
device_triplet_collection_types::device triplets(triplet_view);
6869

6970
// Apply the conformal transformation to middle-bot doublet
70-
const traccc::lin_circle lb = doublet_finding_helper::transform_coordinates<
71-
details::spacepoint_type::bottom>(spM, spB);
71+
const traccc::lin_circle lb =
72+
mid_bot_circles.at(mid_bot_counter.m_botMidIdx);
7273

7374
// Calculate some physical quantities required for triplet compatibility
7475
// check
@@ -77,10 +78,6 @@ inline void find_triplets(
7778
config.sigmaScattering *
7879
config.sigmaScattering;
7980

80-
// These two quantities are used as output parameters in
81-
// triplet_finding_helper::isCompatible but their values are irrelevant
82-
scalar curvature, impact_parameter;
83-
8481
// find the reference (start) index of the mid-top doublet collection
8582
// item vector, where the doublets are recorded
8683
const unsigned int mt_start_idx = doublet_count.m_posMidTop;
@@ -93,24 +90,18 @@ inline void find_triplets(
9390

9491
// iterate over mid-top doublets
9592
for (unsigned int i = mt_start_idx; i < mt_end_idx; ++i) {
96-
const sp_location spT_loc = mid_top_doublet_device[i].sp2;
97-
98-
const edm::spacepoint_collection::const_device::const_proxy_type spT =
99-
spacepoints.at(sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
100-
10193
// Apply the conformal transformation to middle-top doublet
102-
const traccc::lin_circle lt =
103-
doublet_finding_helper::transform_coordinates<
104-
details::spacepoint_type::top>(spM, spT);
94+
const traccc::lin_circle& lt = mid_top_circles.at(i);
10595

10696
// Check if mid-bot and mid-top doublets can form a triplet
97+
scalar curvature, impact_parameter;
10798
if (triplet_finding_helper::isCompatible(
10899
spM, lb, lt, config, iSinTheta2, scatteringInRegion2, curvature,
109100
impact_parameter)) {
110101

111102
// Add triplet to jagged vector
112103
triplets.at(posTriplets++) = device_triplet(
113-
{spT_loc, globalIndex, curvature,
104+
{mid_top_doublet_device[i].sp2, globalIndex, curvature,
114105
-impact_parameter * filter_config.impactWeightFactor,
115106
lb.Zo()});
116107
}

device/cuda/src/seeding/seed_finding.cu

Lines changed: 113 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "traccc/edm/device/doublet_counter.hpp"
2121
#include "traccc/edm/device/seeding_global_counter.hpp"
2222
#include "traccc/edm/device/triplet_counter.hpp"
23+
#include "traccc/seeding/detail/spacepoint_type.hpp"
2324
#include "traccc/seeding/device/count_doublets.hpp"
2425
#include "traccc/seeding/device/count_triplets.hpp"
2526
#include "traccc/seeding/device/find_doublets.hpp"
@@ -65,20 +66,93 @@ __global__ void find_doublets(
6566
sp_grid, doublet_counter, mb_doublets, mt_doublets);
6667
}
6768

69+
__global__ void make_mid_bot_lincircles(
70+
device::device_doublet_collection_types::const_view mb_doublet_view,
71+
device::doublet_counter_collection_types::const_view doublet_count_view,
72+
edm::spacepoint_collection::const_view spacepoint_view,
73+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
74+
vecmem::data::vector_view<lin_circle> out_view) {
75+
76+
const device::device_doublet_collection_types::const_device doublets(
77+
mb_doublet_view);
78+
const device::doublet_counter_collection_types::const_device doublet_counts(
79+
doublet_count_view);
80+
const edm::spacepoint_collection::const_device spacepoints(spacepoint_view);
81+
traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view);
82+
vecmem::device_vector<lin_circle> out(out_view);
83+
84+
unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
85+
86+
if (tid >= doublets.size()) {
87+
return;
88+
}
89+
90+
const device::device_doublet dub = doublets.at(tid);
91+
const unsigned int counter_link = dub.counter_link;
92+
const device::doublet_counter count = doublet_counts.at(counter_link);
93+
const sp_location spM_loc = count.m_spM;
94+
const edm::spacepoint_collection::const_device::const_proxy_type spM =
95+
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
96+
const sp_location spB_loc = dub.sp2;
97+
const edm::spacepoint_collection::const_device::const_proxy_type spB =
98+
spacepoints.at(sp_grid.bin(spB_loc.bin_idx)[spB_loc.sp_idx]);
99+
100+
out.at(tid) = doublet_finding_helper::transform_coordinates<
101+
traccc::details::spacepoint_type::bottom>(spM, spB);
102+
}
103+
104+
__global__ void make_mid_top_lincircles(
105+
device::device_doublet_collection_types::const_view mt_doublet_view,
106+
device::doublet_counter_collection_types::const_view doublet_count_view,
107+
edm::spacepoint_collection::const_view spacepoint_view,
108+
traccc::details::spacepoint_grid_types::const_view sp_grid_view,
109+
vecmem::data::vector_view<lin_circle> out_view) {
110+
111+
const device::device_doublet_collection_types::const_device doublets(
112+
mt_doublet_view);
113+
const device::doublet_counter_collection_types::const_device doublet_counts(
114+
doublet_count_view);
115+
const edm::spacepoint_collection::const_device spacepoints(spacepoint_view);
116+
traccc::details::spacepoint_grid_types::const_device sp_grid(sp_grid_view);
117+
vecmem::device_vector<lin_circle> out(out_view);
118+
119+
unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
120+
121+
if (tid >= doublets.size()) {
122+
return;
123+
}
124+
125+
const device::device_doublet dub = doublets.at(tid);
126+
const unsigned int counter_link = dub.counter_link;
127+
const device::doublet_counter count = doublet_counts.at(counter_link);
128+
const sp_location spM_loc = count.m_spM;
129+
const edm::spacepoint_collection::const_device::const_proxy_type spM =
130+
spacepoints.at(sp_grid.bin(spM_loc.bin_idx)[spM_loc.sp_idx]);
131+
const sp_location spT_loc = dub.sp2;
132+
const edm::spacepoint_collection::const_device::const_proxy_type spT =
133+
spacepoints.at(sp_grid.bin(spT_loc.bin_idx)[spT_loc.sp_idx]);
134+
135+
out.at(tid) = doublet_finding_helper::transform_coordinates<
136+
traccc::details::spacepoint_type::top>(spM, spT);
137+
}
138+
68139
/// CUDA kernel for running @c traccc::device::count_triplets
69-
__global__ void count_triplets(
140+
__global__ __launch_bounds__(128) void count_triplets(
70141
seedfinder_config config,
71142
edm::spacepoint_collection::const_view spacepoints,
72143
traccc::details::spacepoint_grid_types::const_view sp_grid,
73144
device::doublet_counter_collection_types::const_view doublet_counter,
74145
device::device_doublet_collection_types::const_view mb_doublets,
75146
device::device_doublet_collection_types::const_view mt_doublets,
76147
device::triplet_counter_spM_collection_types::view spM_counter,
77-
device::triplet_counter_collection_types::view midBot_counter) {
148+
device::triplet_counter_collection_types::view midBot_counter,
149+
vecmem::data::vector_view<const lin_circle> midBot_circles,
150+
vecmem::data::vector_view<const lin_circle> midTop_circles) {
78151

79152
device::count_triplets(details::global_index1(), config, spacepoints,
80153
sp_grid, doublet_counter, mb_doublets, mt_doublets,
81-
spM_counter, midBot_counter);
154+
spM_counter, midBot_counter, midBot_circles,
155+
midTop_circles);
82156
}
83157

84158
/// CUDA kernel for running @c traccc::device::reduce_triplet_counts
@@ -92,19 +166,22 @@ __global__ void reduce_triplet_counts(
92166
}
93167

94168
/// CUDA kernel for running @c traccc::device::find_triplets
95-
__global__ void find_triplets(
169+
__global__ __launch_bounds__(128) void find_triplets(
96170
seedfinder_config config, seedfilter_config filter_config,
97171
edm::spacepoint_collection::const_view spacepoints,
98172
traccc::details::spacepoint_grid_types::const_view sp_grid,
99173
device::doublet_counter_collection_types::const_view doublet_counter,
100174
device::device_doublet_collection_types::const_view mt_doublets,
101175
device::triplet_counter_spM_collection_types::const_view spM_tc,
102176
device::triplet_counter_collection_types::const_view midBot_tc,
177+
vecmem::data::vector_view<const lin_circle> midBot_circles,
178+
vecmem::data::vector_view<const lin_circle> midTop_circles,
103179
device::device_triplet_collection_types::view triplet_view) {
104180

105181
device::find_triplets(details::global_index1(), config, filter_config,
106182
spacepoints, sp_grid, doublet_counter, mt_doublets,
107-
spM_tc, midBot_tc, triplet_view);
183+
spM_tc, midBot_tc, midBot_circles, midTop_circles,
184+
triplet_view);
108185
}
109186

110187
/// CUDA kernel for running @c traccc::device::update_triplet_weights
@@ -252,6 +329,33 @@ edm::seed_collection::buffer seed_finding::operator()(
252329
doublet_counter_buffer, doublet_buffer_mb, doublet_buffer_mt);
253330
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
254331

332+
vecmem::data::vector_buffer<lin_circle> midBotLinCircles{
333+
globalCounter_host->m_nMidBot, m_mr.main};
334+
m_copy.setup(midBotLinCircles)->wait();
335+
vecmem::data::vector_buffer<lin_circle> midTopLinCircles{
336+
globalCounter_host->m_nMidTop, m_mr.main};
337+
m_copy.setup(midBotLinCircles)->wait();
338+
339+
{
340+
const unsigned int nThreads = 128;
341+
const unsigned int nMidBotBlocks =
342+
(globalCounter_host->m_nMidBot + nThreads - 1) / nThreads;
343+
const unsigned int nMidTopBlocks =
344+
(globalCounter_host->m_nMidTop + nThreads - 1) / nThreads;
345+
346+
kernels::
347+
make_mid_bot_lincircles<<<nMidBotBlocks, nThreads, 0, stream>>>(
348+
doublet_buffer_mb, doublet_counter_buffer, spacepoints_view,
349+
g2_view, midBotLinCircles);
350+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
351+
352+
kernels::
353+
make_mid_top_lincircles<<<nMidTopBlocks, nThreads, 0, stream>>>(
354+
doublet_buffer_mt, doublet_counter_buffer, spacepoints_view,
355+
g2_view, midTopLinCircles);
356+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
357+
}
358+
255359
// Set up the triplet counter buffers
256360
device::triplet_counter_spM_collection_types::buffer
257361
triplet_counter_spM_buffer = {doublet_counter_buffer_size, m_mr.main};
@@ -265,7 +369,7 @@ edm::seed_collection::buffer seed_finding::operator()(
265369

266370
// Calculate the number of threads and thread blocks to run the doublet
267371
// counting kernel for.
268-
const unsigned int nTripletCountThreads = m_warp_size * 2;
372+
const unsigned int nTripletCountThreads = 128;
269373
const unsigned int nTripletCountBlocks =
270374
(globalCounter_host->m_nMidBot + nTripletCountThreads - 1) /
271375
nTripletCountThreads;
@@ -275,7 +379,7 @@ edm::seed_collection::buffer seed_finding::operator()(
275379
stream>>>(
276380
m_seedfinder_config, spacepoints_view, g2_view, doublet_counter_buffer,
277381
doublet_buffer_mb, doublet_buffer_mt, triplet_counter_spM_buffer,
278-
triplet_counter_midBot_buffer);
382+
triplet_counter_midBot_buffer, midBotLinCircles, midTopLinCircles);
279383
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
280384

281385
// Calculate the number of threads and thread blocks to run the triplet
@@ -309,7 +413,7 @@ edm::seed_collection::buffer seed_finding::operator()(
309413

310414
// Calculate the number of threads and thread blocks to run the triplet
311415
// finding kernel for.
312-
const unsigned int nTripletFindThreads = m_warp_size * 2;
416+
const unsigned int nTripletFindThreads = 128;
313417
const unsigned int nTripletFindBlocks =
314418
(m_copy.get_size(triplet_counter_midBot_buffer) + nTripletFindThreads -
315419
1) /
@@ -321,7 +425,7 @@ edm::seed_collection::buffer seed_finding::operator()(
321425
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
322426
doublet_counter_buffer, doublet_buffer_mt,
323427
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
324-
triplet_buffer);
428+
midBotLinCircles, midTopLinCircles, triplet_buffer);
325429
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
326430

327431
// Calculate the number of threads and thread blocks to run the weight

0 commit comments

Comments
 (0)