From 520baad2e76a5d0fa887d2213e4c88c7d4b6d66f Mon Sep 17 00:00:00 2001 From: Stephen Nicholas Swatman Date: Thu, 24 Jul 2025 17:28:21 +0200 Subject: [PATCH] Move insertion mutex code to dedicated header The code that is used by the `find_tracks` kernel to allow for insertion into sorted arrays without race conditions can be applied for other kernels too, and so it would be nice to lift this code into its own header. --- .../traccc/device/array_insertion_mutex.hpp | 42 ++++++++++++++++++ .../finding/device/impl/find_tracks.ipp | 43 +++---------------- 2 files changed, 49 insertions(+), 36 deletions(-) create mode 100644 device/common/include/traccc/device/array_insertion_mutex.hpp diff --git a/device/common/include/traccc/device/array_insertion_mutex.hpp b/device/common/include/traccc/device/array_insertion_mutex.hpp new file mode 100644 index 0000000000..6e2b6e5d4a --- /dev/null +++ b/device/common/include/traccc/device/array_insertion_mutex.hpp @@ -0,0 +1,42 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +#include +#include + +#include "traccc/definitions/qualifiers.hpp" + +namespace traccc::device { +/** + * @brief Encode the state of our parameter insertion mutex. + */ +TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked, + const uint32_t size, + const float max) { + // Assert that the MSB of the size is zero + assert(size <= 0x7FFFFFFF); + + const uint32_t hi = size | (locked ? 0x80000000 : 0x0); + const uint32_t lo = std::bit_cast(max); + + return (static_cast(hi) << 32) | lo; +} + +/** + * @brief Decode the state of our parameter insertion mutex. + */ +TRACCC_HOST_DEVICE inline std::tuple +decode_insertion_mutex(const uint64_t val) { + const uint32_t hi = static_cast(val >> 32); + const uint32_t lo = val & 0xFFFFFFFF; + + return {static_cast(hi & 0x80000000), (hi & 0x7FFFFFFF), + std::bit_cast(lo)}; +} +} // namespace traccc::device diff --git a/device/common/include/traccc/finding/device/impl/find_tracks.ipp b/device/common/include/traccc/finding/device/impl/find_tracks.ipp index 801ecd8527..ab21f177b8 100644 --- a/device/common/include/traccc/finding/device/impl/find_tracks.ipp +++ b/device/common/include/traccc/finding/device/impl/find_tracks.ipp @@ -21,6 +21,7 @@ #endif // Project include(s). +#include "traccc/device/array_insertion_mutex.hpp" #include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp" #include "traccc/fitting/kalman_filter/is_line_visitor.hpp" #include "traccc/fitting/status_codes.hpp" @@ -34,35 +35,6 @@ namespace traccc::device { -namespace details { -/** - * @brief Encode the state of our parameter insertion mutex. - */ -TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked, - const uint32_t size, - const float max) { - // Assert that the MSB of the size is zero - assert(size <= 0x7FFFFFFF); - - const uint32_t hi = size | (locked ? 0x80000000 : 0x0); - const uint32_t lo = std::bit_cast(max); - - return (static_cast(hi) << 32) | lo; -} - -/** - * @brief Decode the state of our parameter insertion mutex. - */ -TRACCC_HOST_DEVICE inline std::tuple -decode_insertion_mutex(const uint64_t val) { - const uint32_t hi = static_cast(val >> 32); - const uint32_t lo = val & 0xFFFFFFFF; - - return {static_cast(hi & 0x80000000), (hi & 0x7FFFFFFF), - std::bit_cast(lo)}; -} -} // namespace details - template TRACCC_HOST_DEVICE inline void find_tracks( @@ -108,7 +80,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( } shared_payload.shared_insertion_mutex[thread_id.getLocalThreadIdX()] = - details::encode_insertion_mutex(false, 0, 0.f); + encode_insertion_mutex(false, 0, 0.f); barrier.blockBarrier(); @@ -339,8 +311,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * currently operating on the array guarded. */ unsigned long long int assumed = *mutex_ptr; - auto [locked, size, max] = - details::decode_insertion_mutex(assumed); + auto [locked, size, max] = decode_insertion_mutex(assumed); /* * If the array is already full _and_ our parameter has a @@ -357,7 +328,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * locked. */ if (result.has_value() && !locked) { - desired = details::encode_insertion_mutex(true, size, max); + desired = encode_insertion_mutex(true, size, max); /* * Attempt to CAS the mutex with the same value as before @@ -496,8 +467,8 @@ TRACCC_HOST_DEVICE inline void find_tracks( unsigned long long, vecmem::device_address_space::local>(*mutex_ptr) .compare_exchange_strong( - desired, details::encode_insertion_mutex( - false, new_size, new_max)); + desired, encode_insertion_mutex(false, new_size, + new_max)); assert(cas_result); } @@ -563,7 +534,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( * at which this block will write. */ if (in_param_is_live) { - local_num_params = std::get<1>(details::decode_insertion_mutex( + local_num_params = std::get<1>(decode_insertion_mutex( shared_payload .shared_insertion_mutex[thread_id.getLocalThreadIdX()])); /*