|
4 | 4 | */ |
5 | 5 | #pragma once |
6 | 6 |
|
| 7 | +#include <array> |
7 | 8 | #include <chrono> |
8 | 9 | #include <condition_variable> |
9 | 10 | #include <functional> |
10 | 11 | #include <mutex> |
11 | 12 | #include <optional> |
12 | 13 | #include <unordered_map> |
13 | | -#include <unordered_set> |
14 | 14 | #include <vector> |
15 | 15 |
|
16 | 16 | #include <rapidsmpf/communicator/communicator.hpp> |
@@ -44,13 +44,43 @@ namespace detail { |
44 | 44 | */ |
45 | 45 | class FinishCounter { |
46 | 46 | public: |
| 47 | + /** |
| 48 | + * @brief Callback function type called when a partition is finished. |
| 49 | + * |
| 50 | + * The callback receives the partition ID of the finished partition. |
| 51 | + * |
| 52 | + * @warning A callback must be fast and non-blocking and should not call any of the |
| 53 | + * `wait*` methods. And be very careful if acquiring locks. Ideally it should be used |
| 54 | + * to signal a separate thread to do the actual processing (eg. WaitHand). |
| 55 | + * |
| 56 | + * @note When a callback is registered, it will be identified by the |
| 57 | + * FinishedCbId returned. So, if a callback needs to be preemptively canceled, |
| 58 | + * the corresponding identifier needs to be provided. |
| 59 | + * |
| 60 | + * @note Every callback will be called as and when each partition is finished. If |
| 61 | + * there were finished partitions before the callback was registered, the callback |
| 62 | + * will be called for them immediately by the caller thread. Else, the callback will |
| 63 | + * be called by the progress thread (Therefore, it will be called |
| 64 | + * `n_local_partitions_` times in total). |
| 65 | + * |
| 66 | + * @note Caller needs to be careful when using both callbacks and wait* methods |
| 67 | + * together. |
| 68 | + */ |
| 69 | + using FinishedCallback = std::function<void(PartID)>; |
| 70 | + |
47 | 71 | /** |
48 | 72 | * @brief Construct a finish counter. |
49 | 73 | * |
50 | 74 | * @param nranks The total number of ranks participating in the shuffle. |
51 | 75 | * @param local_partitions The partition IDs local to the current rank. |
| 76 | + * @param finished_callback The callback to notify when a partition is finished |
| 77 | + * (optional). |
52 | 78 | */ |
53 | | - FinishCounter(Rank nranks, std::vector<PartID> const& local_partitions); |
| 79 | + FinishCounter( |
| 80 | + Rank nranks, |
| 81 | + std::vector<PartID> const& local_partitions, |
| 82 | + FinishedCallback&& finished_callback = nullptr |
| 83 | + ); |
54 | 84 |
|
55 | 85 | ~FinishCounter(); |
56 | 86 |
|
@@ -88,77 +118,6 @@ class FinishCounter { |
88 | 118 | */ |
89 | 119 | [[nodiscard]] bool all_finished() const; |
90 | 120 |
|
91 | | - /** |
92 | | - * @brief Returns whether a partition is finished (non-blocking). |
93 | | - * |
94 | | - * @param pid The partition ID to check. |
95 | | - * @return True if the partition is finished, otherwise False. |
96 | | - */ |
97 | | - [[nodiscard]] bool is_finished(PartID pid) const; |
98 | | - |
99 | | - /** |
100 | | - * @brief Callback function type called when a partition is finished. |
101 | | - * |
102 | | - * The callback receives the partition ID of the finished partition. |
103 | | - * |
104 | | - * @warning A callback must be fast and non-blocking and should not call any of the |
105 | | - * `wait*` methods. And be very careful if acquiring locks. Ideally it should be used |
106 | | - * to signal a separate thread to do the actual processing (eg. WaitHand). |
107 | | - * |
108 | | - * @note When a callback is registered, it will be identified by the |
109 | | - * FinishedCbId returned. So, if a callback needs to be preemptively canceled, |
110 | | - * the corresponding identifier needs to be provided. |
111 | | - * |
112 | | - * @note Every callback will be called as and when each partition is finished. If |
113 | | - * there were finished partitions before the callback was registered, the callback |
114 | | - * will be called for them immediately by the caller thread. Else, the callback will |
115 | | - * be called by the progress thread (Therefore, it will be called |
116 | | - * `n_local_partitions_` times in total). |
117 | | - * |
118 | | - * @note Caller needs to be careful when using both callbacks and wait* methods |
119 | | - * together. |
120 | | - */ |
121 | | - using FinishedCallback = std::function<void(PartID)>; |
122 | | - |
123 | | - /** |
124 | | - * @brief Type used to identify callbacks. |
125 | | - */ |
126 | | - using FinishedCbId = size_t; |
127 | | - |
128 | | - /** |
129 | | - * @brief Register a callback to be notified when any partition is finished. |
130 | | - * |
131 | | - * This function registers a callback that will be called when a partition is finished |
132 | | - * (and for all currently finished partitions). The callback receives partition IDs as |
133 | | - * they complete. If all partitions are already finished, the callback is executed |
134 | | - * immediately for all partitions and invalid_cb_id is returned. |
135 | | - * |
136 | | - * @param cb The callback to invoke when partitions are finished. |
137 | | - * |
138 | | - * @return A unique callback ID that can be used to cancel the callback, or |
139 | | - * invalid_cb_id if the callback was executed immediately. |
140 | | - */ |
141 | | - FinishedCbId register_finished_callback(FinishedCallback&& cb); |
142 | | - |
143 | | - /** |
144 | | - * @brief Special constant indicating an invalid or immediately-executed callback ID. |
145 | | - * |
146 | | - * This value is returned by register_finished_callback when the callback is executed |
147 | | - * immediately (e.g., when all partitions are already finished). |
148 | | - */ |
149 | | - static constexpr FinishedCbId invalid_cb_id = |
150 | | - std::numeric_limits<FinishedCbId>::max(); |
151 | | - |
152 | | - /** |
153 | | - * @brief Cancel a previously registered callback. |
154 | | - * |
155 | | - * This function removes a callback registered with register_finished_callback using |
156 | | - * its ID. It is safe to call this with invalid_cb_id or an already-cancelled ID. |
157 | | - * |
158 | | - * @param callback_id callback ID. |
159 | | - */ |
160 | | - void remove_finished_callback(FinishedCbId callback_id); |
161 | | - |
162 | 121 | /** |
163 | 122 | * @brief Returns the partition ID of a finished partition that hasn't been waited on |
164 | 123 | * (blocking). Optionally a timeout (in ms) can be provided. |
@@ -250,30 +209,32 @@ class FinishCounter { |
250 | 209 | // when all ranks has reported their goal that the goalpost is final. |
251 | 210 | std::unordered_map<PartID, PartitionInfo> goalposts_; |
252 | 211 |
|
253 | | - std::vector<PartID> finished_partitions_{}; ///< partition IDs of finished partitions |
| 212 | + // mutex to control access between the progress thread and the caller thread on shared |
| 213 | + // resources |
| 214 | + mutable std::mutex mutex_; // TODO: use a shared_mutex lock? |
254 | 215 |
|
255 | | - std::mutex finished_cbs_mutex_; ///< mutex to protect the finished_cbs_ and |
256 | | - ///< next_finished_cb_id_ |
| 216 | + ///@brief Handler to implement the wait* methods using callbacks |
| 217 | + struct WaitHandler { |
| 218 | + std::unordered_set<PartID> |
| 219 | + to_wait{}; ///< finished partitions available to wait on |
| 220 | + bool active{true}; |
| 221 | + std::condition_variable cv; |
| 222 | + std::mutex mutex; |
257 | 223 |
|
258 | | - struct CallbackContainer { |
259 | | - FinishedCbId cb_id; ///< callback ID to identify the callback |
| 224 | + ~WaitHandler(); |
260 | 225 |
|
261 | | - // index of the next partition that the callback is interested in. cb will |
262 | | - // called from next_pid_idx to end of finished_partitions_ |
263 | | - size_t next_pid_idx; |
| 226 | + // Callback to listen on the finished partitions |
| 227 | + void on_finished_cb(PartID pid); |
264 | 228 |
|
265 | | - FinishedCallback cb; |
266 | | - }; |
| 229 | + PartID wait_any(std::optional<std::chrono::milliseconds> timeout); |
267 | 230 |
|
268 | | - std::vector<CallbackContainer> finished_cbs_{}; |
269 | | - FinishedCbId next_finished_cb_id_{0}; ///< next callback ID to assign |
| 231 | + void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout); |
| 232 | + }; |
270 | 233 |
|
271 | | - // mutex to control access between the progress thread and the caller thread on shared |
272 | | - // resources |
273 | | - mutable std::mutex mutex_; // TODO: use a shared_mutex lock? |
| 234 | + WaitHandler wait_handler_{}; |
274 | 235 |
|
275 | | - class WaitHandler; ///< Handler to implement the wait* methods using callbacks |
276 | | - std::unique_ptr<WaitHandler> wait_handler_; |
| 236 | + FinishedCallback finished_callback_ = |
| 237 | + nullptr; ///< callback to notify when a partition is finished |
277 | 238 | }; |
278 | 239 |
|
279 | 240 | } // namespace detail |
|
0 commit comments