@@ -41,17 +41,28 @@ TRACCC_DEVICE inline bool find_valid_index(
4141 return false ;
4242}
4343
44- __global__ void rearrange_tracks (device::rearrange_tracks_payload payload) {
44+ __launch_bounds__ (1024 ) __global__
45+ void rearrange_tracks (device::rearrange_tracks_payload payload) {
4546
4647 if (*(payload.terminate ) == 1 || *(payload.n_updated_tracks ) == 0 ) {
4748 return ;
4849 }
4950
50- auto gid = threadIdx .x + blockIdx .x * blockDim .x ;
51+ auto gid = threadIdx .x / nThreads_per_track +
52+ blockIdx .x * (blockDim .x / nThreads_per_track);
5153 const unsigned int n_accepted = *(payload.n_accepted );
5254
53- if (gid >= n_accepted) {
54- return ;
55+ auto N = *(payload.n_updated_tracks );
56+
57+ int neff_threads = (N + nThreads_per_track - 1 ) / nThreads_per_track;
58+
59+ if (neff_threads > nThreads_per_track) {
60+ neff_threads = nThreads_per_track;
61+ }
62+
63+ bool is_valid_thread = true ;
64+ if (threadIdx .x % nThreads_per_track >= neff_threads || gid >= n_accepted) {
65+ is_valid_thread = false ;
5566 }
5667
5768 vecmem::device_vector<const unsigned int > sorted_ids (
@@ -68,116 +79,139 @@ __global__ void rearrange_tracks(device::rearrange_tracks_payload payload) {
6879 vecmem::device_vector<unsigned int > temp_sorted_ids (
6980 payload.temp_sorted_ids_view );
7081
71- const auto tid = sorted_ids[gid];
72- auto rel_sh_ref = rel_shared[tid];
73- auto pval_ref = pvals[tid];
74- int shifted_idx = static_cast <int >(gid);
75- auto N = *(payload.n_updated_tracks );
82+ __shared__ int shifted_indices[1024 ];
83+ auto & shifted_idx = shifted_indices[threadIdx .x / nThreads_per_track];
84+ unsigned int tid = std::numeric_limits<unsigned int >::max ();
7685
77- if (is_updated[tid] ) {
86+ if (is_valid_thread ) {
7887
79- if (gid > 0 ) {
88+ tid = sorted_ids[gid];
89+ auto rel_sh_ref = rel_shared[tid];
90+ auto pval_ref = pvals[tid];
8091
81- unsigned int left = 0 ;
82- unsigned int right = gid;
92+ shifted_idx = static_cast <int >(gid);
8393
84- bool first_iteration = true ;
85- while (right > left) {
94+ int stride = (N + neff_threads - 1 ) / neff_threads;
8695
87- const bool find_left =
88- find_valid_index (left, 0 , gid, sorted_ids, is_updated );
96+ int ini_idx = stride * ( threadIdx . x % nThreads_per_track);
97+ int fin_idx = std::min (ini_idx + stride, static_cast < int >(N) );
8998
90- if (!find_left) {
91- break ;
92- }
99+ if (is_updated[tid]) {
93100
94- const bool find_right =
95- find_valid_index (right, 0 , gid, sorted_ids, is_updated);
101+ if (gid > 0 ) {
96102
97- if (!find_right) {
98- break ;
99- }
103+ unsigned int left = 0 ;
104+ unsigned int right = gid;
100105
101- if (first_iteration) {
102- auto rel_sh = rel_shared[sorted_ids[right]];
103- auto pval = pvals[sorted_ids[right]];
106+ bool first_iteration = true ;
104107
105- if (rel_sh < rel_sh_ref ||
106- (rel_sh == rel_sh_ref && pval >= pval_ref)) {
107- left = gid;
108- break ;
109- }
110- }
108+ if (threadIdx .x % nThreads_per_track == 0 ) {
109+
110+ while (right > left) {
111+
112+ const bool find_left = find_valid_index (
113+ left, 0 , gid, sorted_ids, is_updated);
114+
115+ if (!find_left) {
116+ break ;
117+ }
118+
119+ const bool find_right = find_valid_index (
120+ right, 0 , gid, sorted_ids, is_updated);
121+
122+ if (!find_right) {
123+ break ;
124+ }
111125
112- first_iteration = false ;
126+ if (first_iteration) {
127+ const auto right_idx = sorted_ids[right];
128+ auto rel_sh = rel_shared[right_idx];
129+ auto pval = pvals[right_idx];
113130
114- unsigned int mid = left + (right - left) / 2 ;
131+ if (rel_sh < rel_sh_ref ||
132+ (rel_sh == rel_sh_ref && pval >= pval_ref)) {
133+ left = gid;
134+ break ;
135+ }
136+ }
115137
116- const bool find_mid = find_valid_index (mid, left, right - 1 ,
117- sorted_ids, is_updated);
138+ first_iteration = false ;
118139
119- if (find_mid) {
140+ unsigned int mid = left + (right - left) / 2 ;
120141
121- auto rel_sh = rel_shared[sorted_ids[mid]];
122- auto pval = pvals[sorted_ids[ mid]] ;
142+ const bool find_mid = find_valid_index (
143+ mid, left, right - 1 , sorted_ids, is_updated) ;
123144
124- if (rel_sh < rel_sh_ref ||
125- (rel_sh == rel_sh_ref && pval >= pval_ref)) {
145+ if (find_mid) {
126146
127- left = mid + 1 ;
128- } else {
129- right = mid;
147+ const auto mid_idx = sorted_ids[mid];
148+ auto rel_sh = rel_shared[mid_idx];
149+ auto pval = pvals[mid_idx];
150+
151+ if (rel_sh < rel_sh_ref ||
152+ (rel_sh == rel_sh_ref && pval >= pval_ref)) {
153+
154+ left = mid + 1 ;
155+ } else {
156+ right = mid;
157+ }
158+ }
130159 }
131- }
132- }
133160
134- int delta = 0 ;
161+ int delta = delta =
162+ gid - left - (prefix_sums[gid] - prefix_sums[left]);
135163
136- if (is_updated[sorted_ids[left]]) {
137- delta = gid - left - (prefix_sums[gid] - prefix_sums[left]);
138- } else {
139- delta = gid - left - (prefix_sums[gid] - prefix_sums[left] - 1 );
140- }
164+ if (!is_updated[sorted_ids[left]]) {
165+ delta++;
166+ }
141167
142- shifted_idx -= delta;
143- }
168+ atomicAdd (&shifted_idx, -delta);
169+ }
170+ }
144171
145- for (int i = 0 ; i < N ; i++) {
172+ for (int i = ini_idx ; i < fin_idx ; i++) {
146173
147- auto id = updated_tracks[i];
174+ auto id = updated_tracks[i];
148175
149- if (inverted_ids[id] < gid) {
150- shifted_idx--;
176+ if (inverted_ids[id] < gid) {
177+ atomicAdd (&shifted_idx, -1 );
178+ }
151179 }
152- }
153180
154- int offset = 0 ;
155- for (int i = 0 ; i < N; i++) {
156- if (updated_tracks[i] == tid) {
157- offset = i;
158- break ;
181+ int offset = 0 ;
182+ for (int i = ini_idx; i < fin_idx; i++) {
183+ if (updated_tracks[i] == tid) {
184+ offset = i;
185+ break ;
186+ }
159187 }
160- }
161- shifted_idx += offset;
188+ if (offset != 0 ) {
189+ atomicAdd (&shifted_idx, offset);
190+ }
191+ } else {
162192
163- } else {
164- for (int i = 0 ; i < N; i++) {
193+ for (int i = ini_idx; i < fin_idx; i++) {
165194
166- auto id = updated_tracks[i];
167- auto rel_sh = rel_shared[id];
168- auto pval = pvals[id];
195+ auto id = updated_tracks[i];
196+ auto rel_sh = rel_shared[id];
197+ auto pval = pvals[id];
169198
170- if (inverted_ids[id] > gid) {
171- if (rel_sh < rel_sh_ref) {
172- shifted_idx++;
173- } else if (rel_sh == rel_sh_ref && pval > pval_ref) {
174- shifted_idx++;
199+ if (inverted_ids[id] > gid) {
200+ if (rel_sh < rel_sh_ref) {
201+ atomicAdd (&shifted_idx, 1 );
202+ } else if (rel_sh == rel_sh_ref && pval > pval_ref) {
203+ atomicAdd (&shifted_idx, 1 );
204+ }
175205 }
176206 }
177207 }
178208 }
179209
180- temp_sorted_ids.at (shifted_idx) = tid;
210+ __syncthreads ();
211+
212+ if (is_valid_thread && (threadIdx .x % nThreads_per_track) == 0 ) {
213+ temp_sorted_ids.at (shifted_idx) = tid;
214+ }
181215}
182216
183217} // namespace traccc::cuda::kernels
0 commit comments