// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17+
#pragma once

#include <bthread/countdown_event.h>
#include <cpp/sync_point.h>
#include <fmt/core.h>
#include <gen_cpp/cloud.pb.h>
#include <glog/logging.h>

#include <atomic>
#include <chrono>
#include <ctime>
#include <functional>
#include <future>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/defer.h"
#include "common/simple_thread_pool.h"
32+
33+ namespace doris ::cloud {
34+
35+ template <typename T>
36+ class SyncExecutor {
37+ public:
38+ SyncExecutor (
39+ std::shared_ptr<SimpleThreadPool> pool, std::string name_tag,
40+ std::function<bool (const T&)> cancel = [](const T& /* */ ) { return false ; })
41+ : _pool(std::move(pool)), _cancel(std::move(cancel)), _name_tag(std::move(name_tag)) {}
42+ auto add (std::function<T()> callback) -> SyncExecutor<T>& {
43+ auto task = std::make_unique<Task>(std::move (callback), _cancel, _count);
44+ _count.add_count ();
45+ // The actual task logic would be wrapped by one promise and passed to the threadpool.
46+ // The result would be returned by the future once the task is finished.
47+ // Or the task would be invalid if the whole task is cancelled.
48+ int r = _pool->submit ([this , t = task.get ()]() { (*t)(_stop_token); });
49+ CHECK (r == 0 );
50+ _res.emplace_back (std::move (task));
51+ return *this ;
52+ }
53+ std::vector<T> when_all (bool * finished) {
54+ DORIS_CLOUD_DEFER {
55+ _reset ();
56+ };
57+ timespec current_time;
58+ auto current_time_second = time (nullptr );
59+ current_time.tv_sec = current_time_second + 300 ;
60+ current_time.tv_nsec = 0 ;
61+ // Wait for all tasks to complete
62+ while (0 != _count.timed_wait (current_time)) {
63+ current_time.tv_sec += 300 ;
64+ LOG (WARNING) << _name_tag << " has already taken 5 min, cost: "
65+ << time (nullptr ) - current_time_second << " seconds" ;
66+ }
67+ *finished = !_stop_token;
68+ std::vector<T> res;
69+ res.reserve (_res.size ());
70+ for (auto & task : _res) {
71+ if (!task->valid ()) {
72+ *finished = false ;
73+ return res;
74+ }
75+ size_t max_wait_ms = 10000 ;
76+ TEST_SYNC_POINT_CALLBACK (" SyncExecutor::when_all.set_wait_time" , &max_wait_ms);
77+ // _count.timed_wait has already ensured that all tasks are completed.
78+ // The 10 seconds here is just waiting for the task results to be returned,
79+ // so 10 seconds is more than enough.
80+ auto status = task->wait_for (max_wait_ms);
81+ if (status == std::future_status::ready) {
82+ res.emplace_back (task->get ());
83+ } else {
84+ *finished = false ;
85+ LOG (WARNING) << _name_tag << " task timed out after 10 seconds" ;
86+ return res;
87+ }
88+ }
89+ return res;
90+ }
91+
92+ private:
93+ void _reset () {
94+ _count.reset (0 );
95+ _res.clear ();
96+ _stop_token = false ;
97+ }
98+
99+ private:
100+ class Task {
101+ public:
102+ Task (std::function<T()> callback, std::function<bool (const T&)> cancel,
103+ bthread::CountdownEvent& count)
104+ : _callback(std::move(callback)),
105+ _cancel (std::move(cancel)),
106+ _count(count),
107+ _fut(_pro.get_future()) {}
108+ void operator ()(std::atomic_bool& stop_token) {
109+ DORIS_CLOUD_DEFER {
110+ _count.signal ();
111+ };
112+ if (stop_token) {
113+ _valid = false ;
114+ return ;
115+ }
116+ T t = _callback ();
117+ // We'll return this task result to user even if this task return error
118+ // So we don't set _valid to false here
119+ if (_cancel (t)) {
120+ stop_token = true ;
121+ }
122+ _pro.set_value (std::move (t));
123+ }
124+ std::future_status wait_for (size_t milliseconds) {
125+ return _fut.wait_for (std::chrono::milliseconds (milliseconds));
126+ }
127+ bool valid () { return _valid; }
128+ T get () { return _fut.get (); }
129+
130+ private:
131+ // It's guarantted that the valid function can only be called inside SyncExecutor's `when_all()` function
132+ // and only be called when the _count.timed_wait function returned. So there would be no data race for
133+ // _valid then it doesn't need to be one atomic bool.
134+ bool _valid = true ;
135+ std::function<T()> _callback;
136+ std::function<bool (const T&)> _cancel;
137+ std::promise<T> _pro;
138+ bthread::CountdownEvent& _count;
139+ std::future<T> _fut;
140+ };
141+ std::vector<std::unique_ptr<Task>> _res;
142+ // use CountdownEvent to do periodically log using CountdownEvent::time_wait()
143+ bthread::CountdownEvent _count {0 };
144+ std::atomic_bool _stop_token {false };
145+ std::shared_ptr<SimpleThreadPool> _pool;
146+ std::function<bool (const T&)> _cancel;
147+ std::string _name_tag;
148+ };
149+ } // namespace doris::cloud
0 commit comments