Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions device/common/include/traccc/device/array_insertion_mutex.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

#include <cstdint>
#include <tuple>

#include "traccc/definitions/qualifiers.hpp"

namespace traccc::device {
/**
* @brief Encode the state of our parameter insertion mutex.
*/
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
const uint32_t size,
const float max) {
// Assert that the MSB of the size is zero
assert(size <= 0x7FFFFFFF);

const uint32_t hi = size | (locked ? 0x80000000 : 0x0);
const uint32_t lo = std::bit_cast<uint32_t>(max);

return (static_cast<uint64_t>(hi) << 32) | lo;
}

/**
* @brief Decode the state of our parameter insertion mutex.
*/
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
decode_insertion_mutex(const uint64_t val) {
const uint32_t hi = static_cast<uint32_t>(val >> 32);
const uint32_t lo = val & 0xFFFFFFFF;

return {static_cast<bool>(hi & 0x80000000), (hi & 0x7FFFFFFF),
std::bit_cast<float>(lo)};
}
} // namespace traccc::device
43 changes: 7 additions & 36 deletions device/common/include/traccc/finding/device/impl/find_tracks.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif

// Project include(s).
#include "traccc/device/array_insertion_mutex.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
#include "traccc/fitting/kalman_filter/is_line_visitor.hpp"
#include "traccc/fitting/status_codes.hpp"
Expand All @@ -34,35 +35,6 @@

namespace traccc::device {

namespace details {
/**
* @brief Encode the state of our parameter insertion mutex.
*/
TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked,
const uint32_t size,
const float max) {
// Assert that the MSB of the size is zero
assert(size <= 0x7FFFFFFF);

const uint32_t hi = size | (locked ? 0x80000000 : 0x0);
const uint32_t lo = std::bit_cast<uint32_t>(max);

return (static_cast<uint64_t>(hi) << 32) | lo;
}

/**
* @brief Decode the state of our parameter insertion mutex.
*/
TRACCC_HOST_DEVICE inline std::tuple<bool, uint32_t, float>
decode_insertion_mutex(const uint64_t val) {
const uint32_t hi = static_cast<uint32_t>(val >> 32);
const uint32_t lo = val & 0xFFFFFFFF;

return {static_cast<bool>(hi & 0x80000000), (hi & 0x7FFFFFFF),
std::bit_cast<float>(lo)};
}
} // namespace details

template <typename detector_t, concepts::thread_id1 thread_id_t,
concepts::barrier barrier_t>
TRACCC_HOST_DEVICE inline void find_tracks(
Expand Down Expand Up @@ -108,7 +80,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
}

shared_payload.shared_insertion_mutex[thread_id.getLocalThreadIdX()] =
details::encode_insertion_mutex(false, 0, 0.f);
encode_insertion_mutex(false, 0, 0.f);

barrier.blockBarrier();

Expand Down Expand Up @@ -339,8 +311,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
* currently operating on the array guarded.
*/
unsigned long long int assumed = *mutex_ptr;
auto [locked, size, max] =
details::decode_insertion_mutex(assumed);
auto [locked, size, max] = decode_insertion_mutex(assumed);

/*
* If the array is already full _and_ our parameter has a
Expand All @@ -357,7 +328,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
* locked.
*/
if (result.has_value() && !locked) {
desired = details::encode_insertion_mutex(true, size, max);
desired = encode_insertion_mutex(true, size, max);

/*
* Attempt to CAS the mutex with the same value as before
Expand Down Expand Up @@ -496,8 +467,8 @@ TRACCC_HOST_DEVICE inline void find_tracks(
unsigned long long,
vecmem::device_address_space::local>(*mutex_ptr)
.compare_exchange_strong(
desired, details::encode_insertion_mutex(
false, new_size, new_max));
desired, encode_insertion_mutex(false, new_size,
new_max));

assert(cas_result);
}
Expand Down Expand Up @@ -563,7 +534,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
* at which this block will write.
*/
if (in_param_is_live) {
local_num_params = std::get<1>(details::decode_insertion_mutex(
local_num_params = std::get<1>(decode_insertion_mutex(
shared_payload
.shared_insertion_mutex[thread_id.getLocalThreadIdX()]));
/*
Expand Down
Loading