Skip to content

Commit 35d6207

Browse files
committed
moving to a single callback
Signed-off-by: niranda perera <niranda.perera@gmail.com>
1 parent 141b711 commit 35d6207

File tree

6 files changed

+234
-313
lines changed

6 files changed

+234
-313
lines changed

cpp/include/rapidsmpf/shuffler/finish_counter.hpp

Lines changed: 51 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
*/
55
#pragma once
66

7+
#include <array>
78
#include <chrono>
89
#include <condition_variable>
910
#include <functional>
1011
#include <mutex>
1112
#include <optional>
1213
#include <unordered_map>
13-
#include <unordered_set>
1414
#include <vector>
1515

1616
#include <rapidsmpf/communicator/communicator.hpp>
@@ -44,13 +44,43 @@ namespace detail {
4444
*/
4545
class FinishCounter {
4646
public:
47+
/**
48+
* @brief Callback function type called when a partition is finished.
49+
*
50+
* The callback receives the partition ID of the finished partition.
51+
*
52+
* @warning A callback must be fast and non-blocking and should not call any of the
53+
* `wait*` methods. And be very careful if acquiring locks. Ideally it should be used
54+
* to signal a separate thread to do the actual processing (e.g. WaitHandler).
55+
*
56+
* @note The callback is supplied once, when the FinishCounter is constructed.
57+
* It is not identified by a FinishedCbId, and there is no method to register
58+
* further callbacks or to cancel it afterwards.
59+
*
60+
* @note Every callback will be called as and when each partition is finished. If
61+
* there were finished partitions before the callback was registered, the callback
62+
* will be called for them immediately by the caller thread. Else, the callback will
63+
* be called by the progress thread (Therefore, it will be called
64+
* `n_local_partitions_` times in total).
65+
*
66+
* @note Caller needs to be careful when using both callbacks and wait* methods
67+
* together.
68+
*/
69+
using FinishedCallback = std::function<void(PartID)>;
70+
4771
/**
4872
* @brief Construct a finish counter.
4973
*
5074
* @param nranks The total number of ranks participating in the shuffle.
5175
* @param local_partitions The partition IDs local to the current rank.
76+
* @param finished_callback The callback to notify when a partition is finished
77+
* (optional).
5278
*/
53-
FinishCounter(Rank nranks, std::vector<PartID> const& local_partitions);
79+
FinishCounter(
80+
Rank nranks,
81+
std::vector<PartID> const& local_partitions,
82+
FinishedCallback&& finished_callback = nullptr
83+
);
5484

5585
~FinishCounter();
5686

@@ -88,77 +118,6 @@ class FinishCounter {
88118
*/
89119
[[nodiscard]] bool all_finished() const;
90120

91-
/**
92-
* @brief Returns whether a partition is finished (non-blocking).
93-
*
94-
* @param pid The partition ID to check.
95-
* @return True if the partition is finished, otherwise False.
96-
*/
97-
[[nodiscard]] bool is_finished(PartID pid) const;
98-
99-
/**
100-
* @brief Callback function type called when a partition is finished.
101-
*
102-
* The callback receives the partition ID of the finished partition.
103-
*
104-
* @warning A callback must be fast and non-blocking and should not call any of the
105-
* `wait*` methods. And be very careful if acquiring locks. Ideally it should be used
106-
* to signal a separate thread to do the actual processing (eg. WaitHand).
107-
*
108-
* @note When a callback is registered, it will be identified by the
109-
* FinishedCbId returned. So, if a callback needs to be preemptively canceled,
110-
* the corresponding identifier needs to be provided.
111-
*
112-
* @note Every callback will be called as and when each partition is finished. If
113-
* there were finished partitions before the callback was registered, the callback
114-
* will be called for them immediately by the caller thread. Else, the callback will
115-
* be called by the progress thread (Therefore, it will be called
116-
* `n_local_partitions_` times in total).
117-
*
118-
* @note Caller needs to be careful when using both callbacks and wait* methods
119-
* together.
120-
*/
121-
using FinishedCallback = std::function<void(PartID)>;
122-
123-
/**
124-
* @brief Type used to identify callbacks.
125-
*/
126-
using FinishedCbId = size_t;
127-
128-
/**
129-
* @brief Register a callback to be notified when any partition is finished.
130-
*
131-
* This function registers a callback that will be called when a partition is finished
132-
* (and for all currently finished partitions). The callback receives partition IDs as
133-
* they complete. If all partitions are already finished, the callback is executed
134-
* immediately for all partitions and invalid_cb_id is returned.
135-
*
136-
* @param cb The callback to invoke when partitions are finished.
137-
*
138-
* @return A unique callback ID that can be used to cancel the callback, or
139-
* invalid_cb_id if the callback was executed immediately.
140-
*/
141-
FinishedCbId register_finished_callback(FinishedCallback&& cb);
142-
143-
/**
144-
* @brief Special constant indicating an invalid or immediately-executed callback ID.
145-
*
146-
* This value is returned by register_finished_callback when the callback is executed
147-
* immediately (e.g., when all partitions are already finished).
148-
*/
149-
static constexpr FinishedCbId invalid_cb_id =
150-
std::numeric_limits<FinishedCbId>::max();
151-
152-
/**
153-
* @brief Cancel a previously registered callback.
154-
*
155-
* This function removes a callback registered with register_finished_callback using
156-
* its ID. It is safe to call this with invalid_cb_id or an already-cancelled ID.
157-
*
158-
* @param callback_id callback ID.
159-
*/
160-
void remove_finished_callback(FinishedCbId callback_id);
161-
162121
/**
163122
* @brief Returns the partition ID of a finished partition that hasn't been waited on
164123
* (blocking). Optionally a timeout (in ms) can be provided.
@@ -250,30 +209,32 @@ class FinishCounter {
250209
// when all ranks have reported their goal that the goalpost is final.
251210
std::unordered_map<PartID, PartitionInfo> goalposts_;
252211

253-
std::vector<PartID> finished_partitions_{}; ///< partition IDs of finished partitions
212+
// mutex to control access between the progress thread and the caller thread on shared
213+
// resources
214+
mutable std::mutex mutex_; // TODO: use a shared_mutex lock?
254215

255-
std::mutex finished_cbs_mutex_; ///< mutex to protect the finished_cbs_ and
256-
///< next_finished_cb_id_
216+
///@brief Handler to implement the wait* methods using callbacks
217+
struct WaitHandler {
218+
std::unordered_set<PartID>
219+
to_wait{}; ///< finished partitions available to wait on
220+
bool active{true};
221+
std::condition_variable cv;
222+
std::mutex mutex;
257223

258-
struct CallbackContainer {
259-
FinishedCbId cb_id; ///< callback ID to identify the callback
224+
~WaitHandler();
260225

261-
// index of the next partition that the callback is interested in. cb will
262-
// called from next_pid_idx to end of finished_partitions_
263-
size_t next_pid_idx;
226+
// Callback to listen on the finished partitions
227+
void on_finished_cb(PartID pid);
264228

265-
FinishedCallback cb;
266-
};
229+
PartID wait_any(std::optional<std::chrono::milliseconds> timeout);
267230

268-
std::vector<CallbackContainer> finished_cbs_{};
269-
FinishedCbId next_finished_cb_id_{0}; ///< next callback ID to assign
231+
void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout);
232+
};
270233

271-
// mutex to control access between the progress thread and the caller thread on shared
272-
// resources
273-
mutable std::mutex mutex_; // TODO: use a shared_mutex lock?
234+
WaitHandler wait_handler_{};
274235

275-
class WaitHandler; ///< Handler to implement the wait* methods using callbacks
276-
std::unique_ptr<WaitHandler> wait_handler_;
236+
FinishedCallback finished_callback_ =
237+
nullptr; ///< callback to notify when a partition is finished
277238
};
278239

279240
} // namespace detail

cpp/include/rapidsmpf/shuffler/shuffler.hpp

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ class Shuffler {
7777
PartitionOwner partition_owner
7878
);
7979

80+
/// @copydoc detail::FinishCounter::FinishedCallback
81+
using FinishedCallback = detail::FinishCounter::FinishedCallback;
82+
8083
/**
8184
* @brief Construct a new shuffler for a single shuffle.
8285
*
@@ -87,6 +90,7 @@ class Shuffler {
8790
* @param total_num_partitions Total number of partitions in the shuffle.
8891
* @param stream The CUDA stream for memory operations.
8992
* @param br Buffer resource used to allocate temporary and the shuffle result.
93+
* @param finished_callback Callback to notify when a partition is finished.
9094
* @param statistics The statistics instance to use (disabled by default).
9195
* @param partition_owner Function to determine partition ownership.
9296
*/
@@ -97,10 +101,46 @@ class Shuffler {
97101
PartID total_num_partitions,
98102
rmm::cuda_stream_view stream,
99103
BufferResource* br,
104+
FinishedCallback&& finished_callback,
100105
std::shared_ptr<Statistics> statistics = Statistics::disabled(),
101106
PartitionOwner partition_owner = round_robin
102107
);
103108

109+
/**
110+
* @brief Construct a new shuffler for a single shuffle.
111+
*
112+
* @param comm The communicator to use.
113+
* @param progress_thread The progress thread to use.
114+
* @param op_id The operation ID of the shuffle. This ID is unique for this operation,
115+
* and should not be reused until all nodes have called `Shuffler::shutdown()`.
116+
* @param total_num_partitions Total number of partitions in the shuffle.
117+
* @param stream The CUDA stream for memory operations.
118+
* @param br Buffer resource used to allocate temporary and the shuffle result.
119+
* @param statistics The statistics instance to use (disabled by default).
120+
* @param partition_owner Function to determine partition ownership.
121+
*/
122+
Shuffler(
123+
std::shared_ptr<Communicator> comm,
124+
std::shared_ptr<ProgressThread> progress_thread,
125+
OpID op_id,
126+
PartID total_num_partitions,
127+
rmm::cuda_stream_view stream,
128+
BufferResource* br,
129+
std::shared_ptr<Statistics> statistics = Statistics::disabled(),
130+
PartitionOwner partition_owner = round_robin
131+
)
132+
: Shuffler(
133+
comm,
134+
progress_thread,
135+
op_id,
136+
total_num_partitions,
137+
stream,
138+
br,
139+
nullptr,
140+
statistics,
141+
partition_owner
142+
) {}
143+
104144
~Shuffler();
105145

106146
/**
@@ -170,18 +210,6 @@ class Shuffler {
170210
*/
171211
[[nodiscard]] bool is_finished(PartID pid) const;
172212

173-
/// @copydoc detail::FinishCounter::FinishedCallback
174-
using FinishedCallback = detail::FinishCounter::FinishedCallback;
175-
176-
/// @copydoc detail::FinishCounter::FinishedCbId
177-
using FinishedCbId = detail::FinishCounter::FinishedCbId;
178-
179-
/// @copydoc detail::FinishCounter::register_finished_callback
180-
FinishedCbId register_finished_callback(FinishedCallback&& cb);
181-
182-
/// @copydoc detail::FinishCounter::remove_finished_callback
183-
void remove_finished_callback(FinishedCbId callback_id);
184-
185213
/**
186214
* @brief Wait for any partition to finish.
187215
*

0 commit comments

Comments
 (0)