Skip to content

Commit 520baad

Browse files
committed
Move insertion mutex code to dedicated header
The code that is used by the `find_tracks` kernel to allow for insertion into sorted arrays without race conditions can be applied for other kernels too, and so it would be nice to lift this code into its own header.
1 parent e8e0776 commit 520baad

File tree

2 files changed

+49
-36
lines changed

2 files changed

+49
-36
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
#include <cstdint>
11+
#include <tuple>
12+
13+
#include "traccc/definitions/qualifiers.hpp"
14+
15+
namespace traccc::device {
16+
/**
17+
* @brief Encode the state of our parameter insertion mutex.
18+
*/
19+
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
20+
const uint32_t size,
21+
const float max) {
22+
// Assert that the MSB of the size is zero
23+
assert(size <= 0x7FFFFFFF);
24+
25+
const uint32_t hi = size | (locked ? 0x80000000 : 0x0);
26+
const uint32_t lo = std::bit_cast<uint32_t>(max);
27+
28+
return (static_cast<uint64_t>(hi) << 32) | lo;
29+
}
30+
31+
/**
32+
* @brief Decode the state of our parameter insertion mutex.
33+
*/
34+
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
35+
decode_insertion_mutex(const uint64_t val) {
36+
const uint32_t hi = static_cast<uint32_t>(val >> 32);
37+
const uint32_t lo = val & 0xFFFFFFFF;
38+
39+
return {static_cast<bool>(hi & 0x80000000), (hi & 0x7FFFFFFF),
40+
std::bit_cast<float>(lo)};
41+
}
42+
} // namespace traccc::device

device/common/include/traccc/finding/device/impl/find_tracks.ipp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#endif
2222

2323
// Project include(s).
24+
#include "traccc/device/array_insertion_mutex.hpp"
2425
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
2526
#include "traccc/fitting/kalman_filter/is_line_visitor.hpp"
2627
#include "traccc/fitting/status_codes.hpp"
@@ -34,35 +35,6 @@
3435

3536
namespace traccc::device {
3637

37-
namespace details {
38-
/**
39-
* @brief Encode the state of our parameter insertion mutex.
40-
*/
41-
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
42-
const uint32_t size,
43-
const float max) {
44-
// Assert that the MSB of the size is zero
45-
assert(size <= 0x7FFFFFFF);
46-
47-
const uint32_t hi = size | (locked ? 0x80000000 : 0x0);
48-
const uint32_t lo = std::bit_cast<uint32_t>(max);
49-
50-
return (static_cast<uint64_t>(hi) << 32) | lo;
51-
}
52-
53-
/**
54-
* @brief Decode the state of our parameter insertion mutex.
55-
*/
56-
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
57-
decode_insertion_mutex(const uint64_t val) {
58-
const uint32_t hi = static_cast<uint32_t>(val >> 32);
59-
const uint32_t lo = val & 0xFFFFFFFF;
60-
61-
return {static_cast<bool>(hi & 0x80000000), (hi & 0x7FFFFFFF),
62-
std::bit_cast<float>(lo)};
63-
}
64-
} // namespace details
65-
6638
template <typename detector_t, concepts::thread_id1 thread_id_t,
6739
concepts::barrier barrier_t>
6840
TRACCC_HOST_DEVICE inline void find_tracks(
@@ -108,7 +80,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
10880
}
10981

11082
shared_payload.shared_insertion_mutex[thread_id.getLocalThreadIdX()] =
111-
details::encode_insertion_mutex(false, 0, 0.f);
83+
encode_insertion_mutex(false, 0, 0.f);
11284

11385
barrier.blockBarrier();
11486

@@ -339,8 +311,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
339311
* currently operating on the array guarded.
340312
*/
341313
unsigned long long int assumed = *mutex_ptr;
342-
auto [locked, size, max] =
343-
details::decode_insertion_mutex(assumed);
314+
auto [locked, size, max] = decode_insertion_mutex(assumed);
344315

345316
/*
346317
* If the array is already full _and_ our parameter has a
@@ -357,7 +328,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
357328
* locked.
358329
*/
359330
if (result.has_value() && !locked) {
360-
desired = details::encode_insertion_mutex(true, size, max);
331+
desired = encode_insertion_mutex(true, size, max);
361332

362333
/*
363334
* Attempt to CAS the mutex with the same value as before
@@ -496,8 +467,8 @@ TRACCC_HOST_DEVICE inline void find_tracks(
496467
unsigned long long,
497468
vecmem::device_address_space::local>(*mutex_ptr)
498469
.compare_exchange_strong(
499-
desired, details::encode_insertion_mutex(
500-
false, new_size, new_max));
470+
desired, encode_insertion_mutex(false, new_size,
471+
new_max));
501472

502473
assert(cas_result);
503474
}
@@ -563,7 +534,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
563534
* at which this block will write.
564535
*/
565536
if (in_param_is_live) {
566-
local_num_params = std::get<1>(details::decode_insertion_mutex(
537+
local_num_params = std::get<1>(decode_insertion_mutex(
567538
shared_payload
568539
.shared_insertion_mutex[thread_id.getLocalThreadIdX()]));
569540
/*

0 commit comments

Comments
 (0)