|
24 | 24 |
|
25 | 25 | #include <atomic> |
26 | 26 | #include <chrono> |
| 27 | +#include <functional> |
27 | 28 | #include <thread> |
| 29 | +#include <type_traits> |
28 | 30 |
|
29 | 31 | namespace rocprofiler |
30 | 32 | { |
31 | 33 | namespace rocattach |
32 | 34 | { |
33 | | -template <typename T> |
| 35 | +// Blocks until predicate(flag) == true or timeout_ms milliseconds have elapsed. |
| 36 | +// Returns true if predicate(flag) was true |
| 37 | +// Returns false if timeout occurred |
| 38 | +template <typename Tp, typename PredicateT> |
34 | 39 | bool |
35 | | -wait_for(std::atomic<T>& flag, T condition, size_t timeout_ms, bool equal) |
| 40 | +wait_for(std::atomic<Tp>& flag, size_t timeout_ms, PredicateT&& predicate) |
36 | 41 | { |
37 | | - auto cond_check = [&]() { |
38 | | - if(equal) return flag.load() == condition; |
39 | | - return flag.load() != condition; |
40 | | - }; |
| 42 | + static_assert(std::is_invocable<PredicateT, std::atomic<Tp>&>::value, "Invalid predicate"); |
| 43 | + using predicate_return_type = typename std::invoke_result<PredicateT, std::atomic<Tp>&>::type; |
| 44 | + static_assert(std::is_same<predicate_return_type, bool>::value, |
| 45 | + "Predicate must return boolean"); |
| 46 | + |
41 | 47 | auto start_time = std::chrono::steady_clock::now(); |
42 | 48 | auto timeout_duration = std::chrono::milliseconds(timeout_ms); |
43 | 49 | auto end_time = start_time + timeout_duration; |
44 | 50 | while(std::chrono::steady_clock::now() < end_time) |
45 | 51 | { |
46 | | - if(cond_check()) |
| 52 | + if(std::invoke(std::forward<PredicateT>(predicate), std::forward<std::atomic<Tp>&>(flag))) |
47 | 53 | { |
48 | 54 | return true; |
49 | 55 | } |
50 | 56 | std::this_thread::yield(); |
51 | 57 | } |
52 | 58 | // Last chance check in case we were scheduled after timeout |
53 | | - return cond_check(); |
| 59 | + return std::invoke(std::forward<PredicateT>(predicate), std::forward<std::atomic<Tp>&>(flag)); |
54 | 60 | } |
55 | | -// Blocks until flag is NOT equal to condition or timeout_ms milliseconds have elapsed. |
| 61 | +// Blocks until flag is NOT equal to value or timeout_ms milliseconds have elapsed. |
56 | 62 | // Returns true if the flag is not equal |
57 | 63 | // Returns false if timeout occurred |
58 | 64 | template <typename T> |
59 | 65 | bool |
60 | | -wait_for_ne(std::atomic<T>& flag, T condition, size_t timeout_ms) |
| 66 | +wait_for_ne(std::atomic<T>& flag, T value, size_t timeout_ms) |
61 | 67 | { |
62 | | - return wait_for(flag, condition, timeout_ms, false); |
| 68 | + auto predicate = [value](std::atomic<T>& a) { return a.load() != value; }; |
| 69 | + return wait_for(flag, timeout_ms, predicate); |
63 | 70 | } |
64 | | -// Blocks until flag is equal to condition or timeout_ms milliseconds have elapsed. |
| 71 | +// Blocks until flag is equal to value or timeout_ms milliseconds have elapsed. |
65 | 72 | // Returns true if the flag is equal |
66 | 73 | // Returns false if timeout occurred |
67 | 74 | template <typename T> |
68 | 75 | bool |
69 | | -wait_for_eq(std::atomic<T>& flag, T condition, size_t timeout_ms) |
| 76 | +wait_for_eq(std::atomic<T>& flag, T value, size_t timeout_ms) |
70 | 77 | { |
71 | | - return wait_for(flag, condition, timeout_ms, true); |
| 78 | + auto predicate = [value](std::atomic<T>& a) { return a.load() == value; }; |
| 79 | + return wait_for(flag, timeout_ms, predicate); |
72 | 80 | } |
73 | | - |
74 | 81 | } // namespace rocattach |
75 | 82 | } // namespace rocprofiler |
0 commit comments