diff --git a/device/common/include/traccc/device/array_insertion_mutex.hpp b/device/common/include/traccc/device/array_insertion_mutex.hpp new file mode 100644 index 0000000000..6e2b6e5d4a --- /dev/null +++ b/device/common/include/traccc/device/array_insertion_mutex.hpp @@ -0,0 +1,42 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +#include +#include + +#include "traccc/definitions/qualifiers.hpp" + +namespace traccc::device { +/** + * @brief Encode the state of our parameter insertion mutex. + */ +TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked, + const uint32_t size, + const float max) { + // Assert that the MSB of the size is zero + assert(size <= 0x7FFFFFFF); + + const uint32_t hi = size | (locked ? 0x80000000 : 0x0); + const uint32_t lo = std::bit_cast(max); + + return (static_cast(hi) << 32) | lo; +} + +/** + * @brief Decode the state of our parameter insertion mutex. + */ +TRACCC_HOST_DEVICE inline std::tuple +decode_insertion_mutex(const uint64_t val) { + const uint32_t hi = static_cast(val >> 32); + const uint32_t lo = val & 0xFFFFFFFF; + + return {static_cast(hi & 0x80000000), (hi & 0x7FFFFFFF), + std::bit_cast(lo)}; +} +} // namespace traccc::device diff --git a/device/common/include/traccc/finding/device/impl/find_tracks.ipp b/device/common/include/traccc/finding/device/impl/find_tracks.ipp index 801ecd8527..ab21f177b8 100644 --- a/device/common/include/traccc/finding/device/impl/find_tracks.ipp +++ b/device/common/include/traccc/finding/device/impl/find_tracks.ipp @@ -21,6 +21,7 @@ #endif // Project include(s). +#include "traccc/device/array_insertion_mutex.hpp" #include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp" #include "traccc/fitting/kalman_filter/is_line_visitor.hpp" #include "traccc/fitting/status_codes.hpp" @@ -34,35 +35,6 @@ namespace traccc::device { -namespace details { -/** - * @brief Encode the state of our parameter insertion mutex. - */ -TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked, - const uint32_t size, - const float max) { - // Assert that the MSB of the size is zero - assert(size <= 0x7FFFFFFF); - - const uint32_t hi = size | (locked ? 0x80000000 : 0x0); - const uint32_t lo = std::bit_cast(max); - - return (static_cast(hi) << 32) | lo; -} - -/** - * @brief Decode the state of our parameter insertion mutex. - */ -TRACCC_HOST_DEVICE inline std::tuple -decode_insertion_mutex(const uint64_t val) { - const uint32_t hi = static_cast(val >> 32); - const uint32_t lo = val & 0xFFFFFFFF; - - return {static_cast(hi & 0x80000000), (hi & 0x7FFFFFFF), - std::bit_cast(lo)}; -} -} // namespace details - template TRACCC_HOST_DEVICE inline void find_tracks( @@ -108,7 +80,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( } shared_payload.shared_insertion_mutex[thread_id.getLocalThreadIdX()] = - details::encode_insertion_mutex(false, 0, 0.f); + encode_insertion_mutex(false, 0, 0.f); barrier.blockBarrier(); @@ -339,8 +311,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * currently operating on the array guarded. */ unsigned long long int assumed = *mutex_ptr; - auto [locked, size, max] = - details::decode_insertion_mutex(assumed); + auto [locked, size, max] = decode_insertion_mutex(assumed); /* * If the array is already full _and_ our parameter has a @@ -357,7 +328,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * locked. */ if (result.has_value() && !locked) { - desired = details::encode_insertion_mutex(true, size, max); + desired = encode_insertion_mutex(true, size, max); /* * Attempt to CAS the mutex with the same value as before @@ -496,8 +467,8 @@ TRACCC_HOST_DEVICE inline void find_tracks( unsigned long long, vecmem::device_address_space::local>(*mutex_ptr) .compare_exchange_strong( - desired, details::encode_insertion_mutex( - false, new_size, new_max)); + desired, encode_insertion_mutex(false, new_size, + new_max)); assert(cas_result); } @@ -563,7 +534,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * at which this block will write. */ if (in_param_is_live) { - local_num_params = std::get<1>(details::decode_insertion_mutex( + local_num_params = std::get<1>(decode_insertion_mutex( shared_payload .shared_insertion_mutex[thread_id.getLocalThreadIdX()])); /*