Skip to content

Commit ad2508d

Browse files
committed
Test
1 parent 9457b4f commit ad2508d

File tree

4 files changed

+271
-17
lines changed

4 files changed

+271
-17
lines changed

core/include/traccc/seeding/detail/seeding_config.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,9 @@ struct seedfilter_config {
194194
float deltaInvHelixDiameter = 0.00003f / unit<float>::mm;
195195
// the impact parameter (d0) is multiplied by this factor and subtracted
196196
// from weight
197-
float impactWeightFactor = 1.f;
197+
float impactWeightFactor = 0.5f;
198198
// seed weight increased by this value if a compatible seed has been found.
199-
float compatSeedWeight = 200.f;
199+
float compatSeedWeight = 2.5f;
200200
// minimum distance between compatible seeds to be considered for weight
201201
// boost
202202
float deltaRMin = 5.f * unit<float>::mm;

device/common/include/traccc/seeding/device/impl/select_seeds.ipp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ inline void select_seeds(
6767
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
6868
const triplet_counter_collection_types::const_view& tc_view,
6969
const device_triplet_collection_types::const_view& triplet_view,
70-
device_triplet* data, edm::seed_collection::view seed_view) {
70+
device_triplet* data,
71+
const vecmem::data::vector_view<const unsigned int> votes_per_triplet_view,
72+
edm::seed_collection::view seed_view) {
7173

7274
// Check if anything needs to be done.
7375
const triplet_counter_spM_collection_types::const_device triplet_counts_spM(
@@ -85,6 +87,8 @@ inline void select_seeds(
8587
sp_view);
8688

8789
const device_triplet_collection_types::const_device triplets(triplet_view);
90+
const vecmem::device_vector<const unsigned int> votes_per_triplet(
91+
votes_per_triplet_view);
8892
edm::seed_collection::device seeds_device(seed_view);
8993

9094
// Current work item = middle spacepoint
@@ -99,9 +103,14 @@ inline void select_seeds(
99103

100104
const unsigned int end_triplets_spM =
101105
spM_counter.posTriplets + spM_counter.m_nTriplets;
106+
102107
// iterate over the triplets in the bin
103108
for (unsigned int i = spM_counter.posTriplets; i < end_triplets_spM; ++i) {
104-
device_triplet aTriplet = triplets[i];
109+
const device_triplet& aTriplet = triplets[i];
110+
111+
if (votes_per_triplet.at(i) < 3) {
112+
continue;
113+
}
105114

106115
// spacepoints bottom and top for this triplet
107116
const unsigned int spB_idx = aTriplet.spB;
@@ -111,16 +120,6 @@ inline void select_seeds(
111120
const edm::spacepoint_collection::const_device::const_proxy_type spT =
112121
spacepoints.at(spT_idx);
113122

114-
// update weight of triplet
115-
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,
116-
aTriplet.weight);
117-
118-
// check if it is a good triplet
119-
if (!seed_selecting_helper::single_seed_cut(filter_config, spM, spB,
120-
spT, aTriplet.weight)) {
121-
continue;
122-
}
123-
124123
// if the number of good triplets is larger than the threshold,
125124
// the triplet with the lowest weight is removed
126125
if (n_triplets_per_spM >= finder_config.maxSeedsPerSpM) {

device/common/include/traccc/seeding/device/select_seeds.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ inline void select_seeds(
4242
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
4343
const triplet_counter_collection_types::const_view& tc_view,
4444
const device_triplet_collection_types::const_view& triplet_view,
45-
triplet* data, edm::seed_collection::view seed_view);
45+
triplet* data,
46+
const vecmem::data::vector_view<const unsigned int> votes_per_triplet_view,
47+
edm::seed_collection::view seed_view);
4648

4749
} // namespace traccc::device
4850

device/cuda/src/seeding/seed_finding.cu

Lines changed: 255 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "traccc/edm/device/doublet_counter.hpp"
2121
#include "traccc/edm/device/seeding_global_counter.hpp"
2222
#include "traccc/edm/device/triplet_counter.hpp"
23+
#include "traccc/seeding/detail/singlet.hpp"
2324
#include "traccc/seeding/device/count_doublets.hpp"
2425
#include "traccc/seeding/device/count_triplets.hpp"
2526
#include "traccc/seeding/device/find_doublets.hpp"
@@ -37,6 +38,195 @@
3738

3839
namespace traccc::cuda {
3940
namespace kernels {
41+
namespace detail {
/**
 * @brief Pack the state of the per-spacepoint insertion mutex into 64 bits.
 *
 * Layout of the returned word:
 *  - bit 63      : set when @p locked is true
 *  - bits 62..32 : @p size (must fit into 31 bits)
 *  - bits 31..0  : the raw bit pattern of @p max
 *
 * @param locked Whether the mutex is currently held
 * @param size   Number of entries currently stored
 * @param max    Floating point value stored alongside the size
 * @return The packed 64-bit state
 */
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
                                                          const uint32_t size,
                                                          const float max) {

    constexpr uint32_t lock_flag = 0x80000000;
    constexpr uint32_t size_mask = 0x7FFFFFFF;

    // The most significant bit of the upper word is reserved for the lock
    // flag, so the size must not occupy it.
    assert(size <= size_mask);

    uint32_t upper = size;
    if (locked) {
        upper |= lock_flag;
    }
    const uint64_t lower = std::bit_cast<uint32_t>(max);

    return (static_cast<uint64_t>(upper) << 32) | lower;
}

/**
 * @brief Unpack a word produced by @c encode_insertion_mutex.
 *
 * @param val The packed 64-bit state
 * @return A tuple of (locked flag, size, float value), in that order
 */
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
decode_insertion_mutex(const uint64_t val) {

    constexpr uint32_t lock_flag = 0x80000000;
    constexpr uint32_t size_mask = 0x7FFFFFFF;

    const uint32_t upper = static_cast<uint32_t>(val >> 32);
    const uint32_t lower = static_cast<uint32_t>(val & 0xFFFFFFFF);

    const bool is_locked = (upper & lock_flag) != 0;
    const uint32_t stored_size = upper & size_mask;
    const float stored_value = std::bit_cast<float>(lower);

    return {is_locked, stored_size, stored_value};
}
}  // namespace detail
67+
68+
/**
 * @brief Record, for every spacepoint, the highest-weight triplets using it.
 *
 * Each thread handles the triplet whose index equals its global thread index
 * and tries to insert that triplet into the slot lists of its bottom, middle
 * and top spacepoints, in that order (states 0, 1 and 2).
 *
 * Every spacepoint owns @p max_num_triplets_per_spacepoint consecutive slots
 * in @p triplet_index_view / @p triplet_weight_view, starting at
 * `spacepoint_idx * max_num_triplets_per_spacepoint`. Access to those slots
 * is guarded by one 64-bit word per spacepoint in @p insertion_mutex_view,
 * encoded with @c detail::encode_insertion_mutex as (locked, number of used
 * slots, lowest stored weight). When all slots are used, a new triplet
 * replaces the lowest-weight entry, and only if its weight is strictly
 * larger than that entry's weight.
 *
 * All threads of a block must reach the @c __syncthreads_or call in the loop
 * condition, so threads without a valid triplet keep iterating without doing
 * any work until the whole block is finished.
 *
 * NOTE(review): the slot contents are handed from one lock holder to the
 * next solely through the compare-and-swap on the mutex word. Confirm that
 * the default memory order of @c vecmem::device_atomic_ref gives the
 * ordering this needs on CUDA.
 *
 * @param triplet_view         The triplets to distribute
 * @param grid_view            Not used by this kernel; kept for the callers
 * @param insertion_mutex_view One state word per spacepoint; the host code
 *                             in this file zeroes it before the launch
 * @param triplet_index_view   Output: triplet indices per spacepoint slot
 * @param triplet_weight_view  Output: triplet weights per spacepoint slot
 * @param max_num_triplets_per_spacepoint Number of slots per spacepoint
 */
__global__ inline void gather_best_triplets_per_spacepoint(
    const device::device_triplet_collection_types::const_view triplet_view,
    const traccc::details::spacepoint_grid_types::const_view grid_view,
    vecmem::data::vector_view<unsigned long long int> insertion_mutex_view,
    vecmem::data::vector_view<unsigned int> triplet_index_view,
    vecmem::data::vector_view<float> triplet_weight_view,
    const unsigned int max_num_triplets_per_spacepoint) {

    // The grid is never consulted here; the parameter only remains so that
    // existing launch sites keep compiling.
    (void)grid_view;

    // Without any slot there is nowhere to store a triplet. The value is a
    // kernel argument, so every thread takes this exit together and no
    // thread is left waiting at the block-wide barrier below.
    if (max_num_triplets_per_spacepoint == 0) {
        return;
    }

    const device::device_triplet_collection_types::const_device triplets(
        triplet_view);

    vecmem::device_vector<unsigned long long int> insertion_mutex(
        insertion_mutex_view);
    vecmem::device_vector<unsigned int> triplet_index(triplet_index_view);
    vecmem::device_vector<float> triplet_weight(triplet_weight_view);

    const unsigned int triplet_idx = blockIdx.x * blockDim.x + threadIdx.x;

    scalar weight = 0;
    bool need_to_write = true;
    // 0, 1, 2: inserting into the bottom, middle, top spacepoint; 3: done.
    unsigned int current_state = 0;
    unsigned int spacepoint_idx = 0;

    // Spacepoint index of this thread's triplet for the given state.
    auto get_idx_for_state = [&triplets, triplet_idx](unsigned int state) {
        const device::device_triplet& this_triplet = triplets.at(triplet_idx);
        if (state == 0) {
            return this_triplet.spB;
        } else if (state == 1) {
            return this_triplet.spM;
        } else {
            return this_triplet.spT;
        }
    };

    if (triplet_idx < triplets.size()) {
        weight = triplets.at(triplet_idx).weight;
        spacepoint_idx = get_idx_for_state(0);
    } else {
        // Surplus thread: nothing to insert, but it still has to take part
        // in the block-wide votes of the loop below.
        need_to_write = false;
        current_state = 3;
    }

    // Every iteration makes at most one locking attempt per thread. The
    // block keeps iterating for as long as any of its threads has work left.
    while (__syncthreads_or(current_state < 3 || need_to_write)) {
        if (current_state < 3 || need_to_write) {
            if (need_to_write) {
                vecmem::device_atomic_ref<unsigned long long int> mutex(
                    insertion_mutex.at(spacepoint_idx));

                unsigned long long int assumed = mutex.load();
                unsigned long long int desired_set = 0;

                auto [locked, size, worst] =
                    detail::decode_insertion_mutex(assumed);

                // All slots are taken and this triplet does not beat the
                // weakest stored one: give up on this spacepoint.
                if (size >= max_num_triplets_per_spacepoint &&
                    weight <= worst) {
                    need_to_write = false;
                }

                bool holds_lock = false;

                // Try to take the lock. If the word changed since it was
                // loaded, the exchange fails and the thread tries again in
                // the next iteration.
                if (need_to_write && !locked) {
                    desired_set =
                        detail::encode_insertion_mutex(true, size, worst);
                    if (mutex.compare_exchange_strong(assumed, desired_set)) {
                        holds_lock = true;
                    }
                }

                if (holds_lock) {
                    unsigned int new_size;
                    const unsigned int offset =
                        spacepoint_idx * max_num_triplets_per_spacepoint;
                    // Initialised so that the writes below never use an
                    // indeterminate slot, even if no comparison in the scan
                    // succeeds (for example with NaN weights).
                    unsigned int out_idx = 0;

                    if (size == max_num_triplets_per_spacepoint) {
                        // Full: overwrite the entry with the lowest weight.
                        new_size = size;
                        scalar worst_weight =
                            std::numeric_limits<scalar>::max();
                        for (unsigned int i = 0; i < size; ++i) {
                            if (triplet_weight.at(offset + i) < worst_weight) {
                                worst_weight = triplet_weight.at(offset + i);
                                out_idx = i;
                            }
                        }
                    } else {
                        // Free slots left: append.
                        new_size = size + 1;
                        out_idx = size;
                    }

                    triplet_index.at(offset + out_idx) = triplet_idx;
                    triplet_weight.at(offset + out_idx) = weight;

                    // Recompute the lowest stored weight for the next state
                    // word.
                    scalar new_worst = std::numeric_limits<scalar>::max();

                    for (unsigned int i = 0; i < new_size; ++i) {
                        new_worst =
                            std::min(new_worst, triplet_weight.at(offset + i));
                    }

                    // Release the lock and publish the new size and lowest
                    // weight in one step. The word must still hold the
                    // locked value written above.
                    [[maybe_unused]] bool cas_result =
                        mutex.compare_exchange_strong(
                            desired_set, detail::encode_insertion_mutex(
                                             false, new_size, new_worst));
                    assert(cas_result);

                    need_to_write = false;
                }
            }

            // Finished with the current spacepoint: move on to the next one
            // of the triplet, or mark the thread as done.
            if (!need_to_write && current_state < 3) {
                if (current_state < 2) {
                    spacepoint_idx = get_idx_for_state(++current_state);
                    need_to_write = true;
                } else {
                    ++current_state;
                }
            }
        }
    }
}
191+
192+
/**
 * @brief Turn the per-spacepoint triplet slots into votes per triplet.
 *
 * Each thread handles one slot: the global thread index is split into a
 * spacepoint index (quotient) and a slot index inside that spacepoint
 * (remainder) using @p max_num_triplets_per_spacepoint. If the slot is in
 * use, the triplet stored in it receives
 * `(size + max_num_triplets_per_spacepoint - 1) / size` votes, where `size`
 * is the number of used slots of the spacepoint. This is the integer ceiling
 * of `max_num_triplets_per_spacepoint / size`.
 *
 * @param insertion_mutex_view   One state word per spacepoint, decoded with
 *                               @c detail::decode_insertion_mutex
 * @param triplet_index_view     Triplet index stored in every slot
 * @param votes_per_triplet_view Output: vote counter per triplet, increased
 *                               atomically
 * @param max_num_triplets_per_spacepoint Number of slots per spacepoint
 */
__global__ inline void gather_spacepoint_votes(
    const vecmem::data::vector_view<const unsigned long long int>
        insertion_mutex_view,
    const vecmem::data::vector_view<const unsigned int> triplet_index_view,
    vecmem::data::vector_view<unsigned int> votes_per_triplet_view,
    const unsigned int max_num_triplets_per_spacepoint) {

    // No slots means no votes. This also keeps the division and the modulo
    // below from dividing by zero.
    if (max_num_triplets_per_spacepoint == 0) {
        return;
    }

    const unsigned int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int spacepoint_idx =
        thread_idx / max_num_triplets_per_spacepoint;
    const unsigned int slot_idx = thread_idx % max_num_triplets_per_spacepoint;

    const vecmem::device_vector<const unsigned long long int> insertion_mutex(
        insertion_mutex_view);
    const vecmem::device_vector<const unsigned int> triplet_index(
        triplet_index_view);
    vecmem::device_vector<unsigned int> votes_per_triplet(
        votes_per_triplet_view);

    // Threads beyond the last spacepoint have nothing to do.
    if (spacepoint_idx >= insertion_mutex.size()) {
        return;
    }

    // Only the number of used slots is needed from the state word.
    auto [locked, size, worst] =
        detail::decode_insertion_mutex(insertion_mutex.at(spacepoint_idx));

    // Spacepoint without any stored triplet; returning here also avoids the
    // division by zero below.
    if (size == 0) {
        return;
    }

    const unsigned int num_votes =
        (size + max_num_triplets_per_spacepoint - 1) / size;

    // Only used slots hold a valid triplet index.
    if (slot_idx < size) {
        vecmem::device_atomic_ref<unsigned int>(
            votes_per_triplet.at(triplet_index.at(thread_idx)))
            .fetch_add(num_votes);
    }
}
40230

41231
/// CUDA kernel for running @c traccc::device::count_doublets
42232
__global__ void count_doublets(
@@ -134,6 +324,7 @@ __global__ void select_seeds(
134324
device::triplet_counter_spM_collection_types::const_view spM_tc,
135325
device::triplet_counter_collection_types::const_view midBot_tc,
136326
device::device_triplet_collection_types::view triplet_view,
327+
const vecmem::data::vector_view<const unsigned int> votes_per_triplet_view,
137328
edm::seed_collection::view seed_view) {
138329

139330
// Array for temporary storage of triplets for comparing within seed
@@ -145,7 +336,7 @@ __global__ void select_seeds(
145336

146337
device::select_seeds(details::global_index1(), finder_config, filter_config,
147338
spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
148-
dataPos, seed_view);
339+
dataPos, votes_per_triplet_view, seed_view);
149340
}
150341

151342
} // namespace kernels
@@ -184,6 +375,10 @@ edm::seed_collection::buffer seed_finding::operator()(
184375
return {0, m_mr.main};
185376
}
186377

378+
// Number of spacepoints including those not in the grid.
379+
const unsigned int num_spacepoints_total =
380+
m_copy.get_size(spacepoints_view);
381+
187382
// Set up the doublet counter buffer.
188383
device::doublet_counter_collection_types::buffer doublet_counter_buffer = {
189384
num_spacepoints, m_mr.main, vecmem::data::buffer_type::resizable};
@@ -341,6 +536,64 @@ edm::seed_collection::buffer seed_finding::operator()(
341536
triplet_buffer);
342537
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
343538

539+
vecmem::data::vector_buffer<unsigned int> votes_per_triplet_buffer(
540+
globalCounter_host->m_nTriplets, m_mr.main);
541+
m_copy.setup(votes_per_triplet_buffer)->wait();
542+
m_copy.memset(votes_per_triplet_buffer, 0)->wait();
543+
544+
{
545+
unsigned int num_votes_per_sp = 3;
546+
547+
vecmem::data::vector_buffer<unsigned int>
548+
best_triplets_per_spacepoint_index_buffer(
549+
num_votes_per_sp * num_spacepoints_total, m_mr.main);
550+
m_copy.setup(best_triplets_per_spacepoint_index_buffer)->wait();
551+
552+
vecmem::data::vector_buffer<unsigned long long int>
553+
best_triplets_per_spacepoint_insertion_mutex_buffer(
554+
num_spacepoints_total, m_mr.main);
555+
m_copy.setup(best_triplets_per_spacepoint_insertion_mutex_buffer)
556+
->wait();
557+
m_copy.memset(best_triplets_per_spacepoint_insertion_mutex_buffer, 0)
558+
->wait();
559+
560+
{
561+
vecmem::data::vector_buffer<float>
562+
best_triplets_per_spacepoint_weight_buffer(
563+
num_votes_per_sp * num_spacepoints_total, m_mr.main);
564+
m_copy.setup(best_triplets_per_spacepoint_weight_buffer)->wait();
565+
566+
unsigned int num_threads = 64;
567+
unsigned int num_blocks =
568+
(globalCounter_host->m_nTriplets + num_threads - 1) /
569+
num_threads;
570+
571+
kernels::gather_best_triplets_per_spacepoint<<<
572+
num_blocks, num_threads, 0, stream>>>(
573+
triplet_buffer, g2_view,
574+
best_triplets_per_spacepoint_insertion_mutex_buffer,
575+
best_triplets_per_spacepoint_index_buffer,
576+
best_triplets_per_spacepoint_weight_buffer, num_votes_per_sp);
577+
578+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
579+
}
580+
581+
{
582+
unsigned int num_threads = 512;
583+
unsigned int num_blocks =
584+
(num_votes_per_sp * num_spacepoints_total + num_threads - 1) /
585+
num_threads;
586+
587+
kernels::
588+
gather_spacepoint_votes<<<num_blocks, num_threads, 0, stream>>>(
589+
best_triplets_per_spacepoint_insertion_mutex_buffer,
590+
best_triplets_per_spacepoint_index_buffer,
591+
votes_per_triplet_buffer, num_votes_per_sp);
592+
593+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
594+
}
595+
}
596+
344597
// Create result object: collection of seeds
345598
edm::seed_collection::buffer seed_buffer(
346599
globalCounter_host->m_nTriplets, m_mr.main,
@@ -362,7 +615,7 @@ edm::seed_collection::buffer seed_finding::operator()(
362615
stream>>>(
363616
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
364617
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
365-
triplet_buffer, seed_buffer);
618+
triplet_buffer, votes_per_triplet_buffer, seed_buffer);
366619
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
367620

368621
return seed_buffer;

0 commit comments

Comments
 (0)