Skip to content

Commit 438dc6f

Browse files
committed
reverting wait handler
Signed-off-by: niranda perera <niranda.perera@gmail.com>
1 parent 28b810f commit 438dc6f

File tree

3 files changed

+60
-73
lines changed

3 files changed

+60
-73
lines changed

cpp/include/rapidsmpf/shuffler/finish_counter.hpp

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,9 @@ class FinishCounter {
120120
*
121121
* @return The partition ID of a finished partition.
122122
*
123-
* @throws std::runtime_error If all partitions have already been waited on or if
124-
* timeout was set and no partitions have been finished by the expiration.
123+
* @throws std::out_of_range If all partitions have already been waited on.
124+
* @throws std::runtime_error If timeout was set and no partitions have been finished
125+
* by the expiration.
125126
*/
126127
PartID wait_any(std::optional<std::chrono::milliseconds> timeout = {});
127128

@@ -136,8 +137,9 @@ class FinishCounter {
136137
* @param pid The desired partition ID.
137138
* @param timeout Optional timeout (ms) to wait.
138139
*
139-
* @throws std::runtime_error If the desired partition is unavailable or if timeout
140-
* was set and requested partition has not been finished by the expiration.
140+
* @throws std::out_of_range If the desired partition is unavailable.
141+
* @throws std::runtime_error If timeout was set and requested partition has not been
142+
* finished by the expiration.
141143
*/
142144
void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout = {});
143145

@@ -149,6 +151,9 @@ class FinishCounter {
149151

150152
private:
151153
Rank const nranks_;
154+
PartID
155+
n_unfinished_partitions_; ///< aux counter to track the number of unfinished
156+
///< partitions (without using the goalposts.empty())
152157

153158
/// @brief Information about a local partition.
154159
struct PartitionInfo {
@@ -199,29 +204,8 @@ class FinishCounter {
199204
// when all ranks have reported their goal that the goalpost is final.
200205
std::unordered_map<PartID, PartitionInfo> goalposts_;
201206

202-
// mutex to control access between the progress thread and the caller thread on shared
203-
// resources
204207
mutable std::mutex mutex_; // TODO: use a shared_mutex lock?
205-
206-
///@brief Handler to implement the wait* methods using callbacks
207-
struct WaitHandler {
208-
std::unordered_set<PartID>
209-
to_wait{}; ///< finished partitions available to wait on
210-
bool active{true};
211-
std::condition_variable cv;
212-
std::mutex mutex;
213-
214-
~WaitHandler();
215-
216-
// Callback to listen on the finished partitions
217-
void on_finished_cb(PartID pid);
218-
219-
PartID wait_any(std::optional<std::chrono::milliseconds> timeout);
220-
221-
void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout);
222-
};
223-
224-
WaitHandler wait_handler_{};
208+
mutable std::condition_variable wait_cv_;
225209

226210
FinishedCallback finished_callback_ =
227211
nullptr; ///< callback to notify when a partition is finished

cpp/src/shuffler/finish_counter.cpp

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -47,50 +47,13 @@ void wait_for_if_timeout_else_wait(
4747

4848
} // namespace
4949

50-
FinishCounter::WaitHandler::~WaitHandler() {
51-
{
52-
std::lock_guard lock(mutex);
53-
active = false;
54-
}
55-
cv.notify_all(); // notify any waiting threads
56-
}
57-
58-
void FinishCounter::WaitHandler::on_finished_cb(PartID pid) {
59-
{
60-
std::lock_guard lock(mutex);
61-
to_wait.emplace(pid);
62-
}
63-
cv.notify_all();
64-
}
65-
66-
PartID FinishCounter::WaitHandler::wait_any(
67-
std::optional<std::chrono::milliseconds> timeout
68-
) {
69-
std::unique_lock lock(mutex);
70-
wait_for_if_timeout_else_wait(lock, cv, timeout, [&] {
71-
return !active || !to_wait.empty();
72-
});
73-
RAPIDSMPF_EXPECTS(active, "wait callback already finished", std::runtime_error);
74-
return to_wait.extract(to_wait.begin()).value();
75-
}
76-
77-
void FinishCounter::WaitHandler::wait_on(
78-
PartID pid, std::optional<std::chrono::milliseconds> timeout
79-
) {
80-
std::unique_lock lock(mutex);
81-
wait_for_if_timeout_else_wait(lock, cv, timeout, [&] {
82-
return !active || to_wait.contains(pid);
83-
});
84-
RAPIDSMPF_EXPECTS(active, "wait callback already finished", std::runtime_error);
85-
to_wait.erase(pid);
86-
}
87-
8850
FinishCounter::FinishCounter(
8951
Rank nranks,
9052
std::vector<PartID> const& local_partitions,
9153
FinishedCallback&& finished_callback
9254
)
9355
: nranks_{nranks},
56+
n_unfinished_partitions_{static_cast<PartID>(local_partitions.size())},
9457
finished_callback_{std::forward<FinishedCallback>(finished_callback)} {
9558
// Initially, none of the partitions are ready to wait on.
9659
goalposts_.reserve(local_partitions.size());
@@ -103,40 +66,75 @@ FinishCounter::~FinishCounter() = default;
10366

10467
bool FinishCounter::all_finished() const {
10568
std::unique_lock<std::mutex> lock(mutex_);
106-
return goalposts_.empty();
69+
// we cannot rely on goalposts_.empty() alone because goalposts_ is also consumed by the wait* methods
70+
return n_unfinished_partitions_ == 0 || goalposts_.empty();
10771
}
10872

10973
void FinishCounter::move_goalpost(PartID pid, ChunkID nchunks) {
11074
std::unique_lock<std::mutex> lock(mutex_);
111-
auto& p_info = goalposts_.at(pid);
75+
auto& p_info = goalposts_[pid];
11276
p_info.move_goalpost(nchunks, nranks_);
11377
}
11478

11579
void FinishCounter::add_finished_chunk(PartID pid) {
11680
std::unique_lock<std::mutex> lock(mutex_);
117-
auto& p_info = goalposts_.at(pid);
81+
auto& p_info = goalposts_[pid];
11882

11983
p_info.add_finished_chunk(nranks_);
12084

12185
if (p_info.is_finished(nranks_)) {
122-
std::ignore = goalposts_.erase(pid);
86+
RAPIDSMPF_EXPECTS(
87+
n_unfinished_partitions_ > 0, "all partitions have been finished"
88+
); // TODO: use a debug flag
89+
n_unfinished_partitions_--;
12390
lock.unlock();
12491

125-
wait_handler_.on_finished_cb(pid);
126-
if (finished_callback_) {
92+
wait_cv_.notify_all(); // notify any waiting threads
93+
94+
if (finished_callback_) { // notify the callback
12795
finished_callback_(pid);
12896
}
12997
}
13098
}
13199

132100
PartID FinishCounter::wait_any(std::optional<std::chrono::milliseconds> timeout) {
133-
return wait_handler_.wait_any(std::move(timeout));
101+
PartID finished_key{std::numeric_limits<PartID>::max()};
102+
103+
std::unique_lock<std::mutex> lock(mutex_);
104+
wait_for_if_timeout_else_wait(lock, wait_cv_, timeout, [&] {
105+
return goalposts_.empty()
106+
|| std::ranges::any_of(goalposts_, [&](auto const& item) {
107+
auto done = item.second.is_finished(nranks_);
108+
if (done) {
109+
finished_key = item.first;
110+
}
111+
return done;
112+
});
113+
});
114+
115+
RAPIDSMPF_EXPECTS(
116+
finished_key != std::numeric_limits<PartID>::max(),
117+
"no more partitions to wait on",
118+
std::out_of_range
119+
);
120+
121+
// We extract the partition to avoid returning the same partition twice.
122+
goalposts_.erase(finished_key);
123+
return finished_key;
134124
}
135125

136126
void FinishCounter::wait_on(
137127
PartID pid, std::optional<std::chrono::milliseconds> timeout
138128
) {
139-
wait_handler_.wait_on(pid, std::move(timeout));
129+
std::unique_lock<std::mutex> lock(mutex_);
130+
wait_for_if_timeout_else_wait(lock, wait_cv_, timeout, [&] {
131+
auto it = goalposts_.find(pid);
132+
RAPIDSMPF_EXPECTS(
133+
it != goalposts_.end(), "PartID has already been extracted", std::out_of_range
134+
);
135+
return it->second.is_finished(nranks_);
136+
});
137+
goalposts_.erase(pid);
140138
}
141139

142140
std::string detail::FinishCounter::str() const {

cpp/tests/test_shuffler.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,12 @@ INSTANTIATE_TEST_SUITE_P(
226226
),
227227
testing::Values(1, 2, 5, 10), // total_num_partitions
228228
testing::Values(1, 9, 100, 100'000) // total_num_rows
229-
)
229+
),
230+
[](const testing::TestParamInfo<MemoryAvailable_NumPartition::ParamType>& info) {
231+
return std::to_string(info.index) + "__nparts_"
232+
+ std::to_string(std::get<1>(info.param)) + "__nrows_"
233+
+ std::to_string(std::get<2>(info.param));
234+
}
230235
);
231236

232237
// both insert and insert_finished ungrouped

0 commit comments

Comments
 (0)