@@ -47,50 +47,13 @@ void wait_for_if_timeout_else_wait(
4747
4848} // namespace
4949
50- FinishCounter::WaitHandler::~WaitHandler () {
51- {
52- std::lock_guard lock (mutex);
53- active = false ;
54- }
55- cv.notify_all (); // notify any waiting threads
56- }
57-
58- void FinishCounter::WaitHandler::on_finished_cb (PartID pid) {
59- {
60- std::lock_guard lock (mutex);
61- to_wait.emplace (pid);
62- }
63- cv.notify_all ();
64- }
65-
66- PartID FinishCounter::WaitHandler::wait_any (
67- std::optional<std::chrono::milliseconds> timeout
68- ) {
69- std::unique_lock lock (mutex);
70- wait_for_if_timeout_else_wait (lock, cv, timeout, [&] {
71- return !active || !to_wait.empty ();
72- });
73- RAPIDSMPF_EXPECTS (active, " wait callback already finished" , std::runtime_error);
74- return to_wait.extract (to_wait.begin ()).value ();
75- }
76-
77- void FinishCounter::WaitHandler::wait_on (
78- PartID pid, std::optional<std::chrono::milliseconds> timeout
79- ) {
80- std::unique_lock lock (mutex);
81- wait_for_if_timeout_else_wait (lock, cv, timeout, [&] {
82- return !active || to_wait.contains (pid);
83- });
84- RAPIDSMPF_EXPECTS (active, " wait callback already finished" , std::runtime_error);
85- to_wait.erase (pid);
86- }
87-
8850FinishCounter::FinishCounter (
8951 Rank nranks,
9052 std::vector<PartID> const & local_partitions,
9153 FinishedCallback&& finished_callback
9254)
9355 : nranks_{nranks},
56+ n_unfinished_partitions_{static_cast <PartID>(local_partitions.size ())},
9457 finished_callback_{std::forward<FinishedCallback>(finished_callback)} {
9558 // Initially, none of the partitions are ready to wait on.
9659 goalposts_.reserve (local_partitions.size ());
@@ -103,40 +66,75 @@ FinishCounter::~FinishCounter() = default;
10366
10467bool FinishCounter::all_finished () const {
10568 std::unique_lock<std::mutex> lock (mutex_);
106- return goalposts_.empty ();
69+ // we can not use the goalposts.empty() because its being consumed by wait* methods
70+ return n_unfinished_partitions_ == 0 || goalposts_.empty ();
10771}
10872
10973void FinishCounter::move_goalpost (PartID pid, ChunkID nchunks) {
11074 std::unique_lock<std::mutex> lock (mutex_);
111- auto & p_info = goalposts_. at ( pid) ;
75+ auto & p_info = goalposts_[ pid] ;
11276 p_info.move_goalpost (nchunks, nranks_);
11377}
11478
11579void FinishCounter::add_finished_chunk (PartID pid) {
11680 std::unique_lock<std::mutex> lock (mutex_);
117- auto & p_info = goalposts_. at ( pid) ;
81+ auto & p_info = goalposts_[ pid] ;
11882
11983 p_info.add_finished_chunk (nranks_);
12084
12185 if (p_info.is_finished (nranks_)) {
122- std::ignore = goalposts_.erase (pid);
86+ RAPIDSMPF_EXPECTS (
87+ n_unfinished_partitions_ > 0 , " all partitions have been finished"
88+ ); // TODO: use a debug flag
89+ n_unfinished_partitions_--;
12390 lock.unlock ();
12491
125- wait_handler_.on_finished_cb (pid);
126- if (finished_callback_) {
92+ wait_cv_.notify_all (); // notify any waiting threads
93+
94+ if (finished_callback_) { // notify the callback
12795 finished_callback_ (pid);
12896 }
12997 }
13098}
13199
132100PartID FinishCounter::wait_any (std::optional<std::chrono::milliseconds> timeout) {
133- return wait_handler_.wait_any (std::move (timeout));
101+ PartID finished_key{std::numeric_limits<PartID>::max ()};
102+
103+ std::unique_lock<std::mutex> lock (mutex_);
104+ wait_for_if_timeout_else_wait (lock, wait_cv_, timeout, [&] {
105+ return goalposts_.empty ()
106+ || std::ranges::any_of (goalposts_, [&](auto const & item) {
107+ auto done = item.second .is_finished (nranks_);
108+ if (done) {
109+ finished_key = item.first ;
110+ }
111+ return done;
112+ });
113+ });
114+
115+ RAPIDSMPF_EXPECTS (
116+ finished_key != std::numeric_limits<PartID>::max (),
117+ " no more partitions to wait on" ,
118+ std::out_of_range
119+ );
120+
121+ // We extract the partition to avoid returning the same partition twice.
122+ goalposts_.erase (finished_key);
123+ return finished_key;
134124}
135125
136126void FinishCounter::wait_on (
137127 PartID pid, std::optional<std::chrono::milliseconds> timeout
138128) {
139- wait_handler_.wait_on (pid, std::move (timeout));
129+ std::unique_lock<std::mutex> lock (mutex_);
130+ wait_for_if_timeout_else_wait (lock, wait_cv_, timeout, [&] {
131+ auto it = goalposts_.find (pid);
132+ RAPIDSMPF_EXPECTS (
133+ it != goalposts_.end (), " PartID has already been extracted" , std::out_of_range
134+ );
135+ return it->second .is_finished (nranks_);
136+ });
137+ goalposts_.erase (pid);
140138}
141139
142140std::string detail::FinishCounter::str () const {
0 commit comments