Skip to content

Commit 8a78f14

Browse files
authored
Merge pull request #1092 from stephenswat/refactor/array_insertion_mutex
Move insertion mutex code to dedicated header
2 parents e8e0776 + 520baad commit 8a78f14

File tree

2 files changed

+49
-36
lines changed

2 files changed

+49
-36
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
#include <cstdint>
11+
#include <tuple>
12+
13+
#include "traccc/definitions/qualifiers.hpp"
14+
15+
namespace traccc::device {
16+
/**
17+
* @brief Encode the state of our parameter insertion mutex.
18+
*/
19+
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
20+
const uint32_t size,
21+
const float max) {
22+
// Assert that the MSB of the size is zero
23+
assert(size <= 0x7FFFFFFF);
24+
25+
const uint32_t hi = size | (locked ? 0x80000000 : 0x0);
26+
const uint32_t lo = std::bit_cast<uint32_t>(max);
27+
28+
return (static_cast<uint64_t>(hi) << 32) | lo;
29+
}
30+
31+
/**
32+
* @brief Decode the state of our parameter insertion mutex.
33+
*/
34+
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
35+
decode_insertion_mutex(const uint64_t val) {
36+
const uint32_t hi = static_cast<uint32_t>(val >> 32);
37+
const uint32_t lo = val & 0xFFFFFFFF;
38+
39+
return {static_cast<bool>(hi & 0x80000000), (hi & 0x7FFFFFFF),
40+
std::bit_cast<float>(lo)};
41+
}
42+
} // namespace traccc::device

device/common/include/traccc/finding/device/impl/find_tracks.ipp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#endif
2222

2323
// Project include(s).
24+
#include "traccc/device/array_insertion_mutex.hpp"
2425
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
2526
#include "traccc/fitting/kalman_filter/is_line_visitor.hpp"
2627
#include "traccc/fitting/status_codes.hpp"
@@ -34,35 +35,6 @@
3435

3536
namespace traccc::device {
3637

37-
namespace details {
38-
/**
39-
* @brief Encode the state of our parameter insertion mutex.
40-
*/
41-
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
42-
const uint32_t size,
43-
const float max) {
44-
// Assert that the MSB of the size is zero
45-
assert(size <= 0x7FFFFFFF);
46-
47-
const uint32_t hi = size | (locked ? 0x80000000 : 0x0);
48-
const uint32_t lo = std::bit_cast<uint32_t>(max);
49-
50-
return (static_cast<uint64_t>(hi) << 32) | lo;
51-
}
52-
53-
/**
54-
* @brief Decode the state of our parameter insertion mutex.
55-
*/
56-
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
57-
decode_insertion_mutex(const uint64_t val) {
58-
const uint32_t hi = static_cast<uint32_t>(val >> 32);
59-
const uint32_t lo = val & 0xFFFFFFFF;
60-
61-
return {static_cast<bool>(hi & 0x80000000), (hi & 0x7FFFFFFF),
62-
std::bit_cast<float>(lo)};
63-
}
64-
} // namespace details
65-
6638
template <typename detector_t, concepts::thread_id1 thread_id_t,
6739
concepts::barrier barrier_t>
6840
TRACCC_HOST_DEVICE inline void find_tracks(
@@ -108,7 +80,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
10880
}
10981

11082
shared_payload.shared_insertion_mutex[thread_id.getLocalThreadIdX()] =
111-
details::encode_insertion_mutex(false, 0, 0.f);
83+
encode_insertion_mutex(false, 0, 0.f);
11284

11385
barrier.blockBarrier();
11486

@@ -339,8 +311,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
339311
* currently operating on the array guarded.
340312
*/
341313
unsigned long long int assumed = *mutex_ptr;
342-
auto [locked, size, max] =
343-
details::decode_insertion_mutex(assumed);
314+
auto [locked, size, max] = decode_insertion_mutex(assumed);
344315

345316
/*
346317
* If the array is already full _and_ our parameter has a
@@ -357,7 +328,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
357328
* locked.
358329
*/
359330
if (result.has_value() && !locked) {
360-
desired = details::encode_insertion_mutex(true, size, max);
331+
desired = encode_insertion_mutex(true, size, max);
361332

362333
/*
363334
* Attempt to CAS the mutex with the same value as before
@@ -496,8 +467,8 @@ TRACCC_HOST_DEVICE inline void find_tracks(
496467
unsigned long long,
497468
vecmem::device_address_space::local>(*mutex_ptr)
498469
.compare_exchange_strong(
499-
desired, details::encode_insertion_mutex(
500-
false, new_size, new_max));
470+
desired, encode_insertion_mutex(false, new_size,
471+
new_max));
501472

502473
assert(cas_result);
503474
}
@@ -563,7 +534,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
563534
* at which this block will write.
564535
*/
565536
if (in_param_is_live) {
566-
local_num_params = std::get<1>(details::decode_insertion_mutex(
537+
local_num_params = std::get<1>(decode_insertion_mutex(
567538
shared_payload
568539
.shared_insertion_mutex[thread_id.getLocalThreadIdX()]));
569540
/*

0 commit comments

Comments
 (0)