Skip to content

Commit c687fab

Browse files
committed
[WIP]
Signed-off-by: Michał Zientkiewicz <[email protected]>
1 parent 0213656 commit c687fab

File tree

2 files changed

+81
-22
lines changed

2 files changed

+81
-22
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "dali/pipeline/util/new_thread_pool.h"
16+
17+
namespace dali {
18+
namespace experimental {
19+
20+
} // namespace experimental
21+
} // namespace dali
22+

dali/pipeline/util/new_thread_pool.h

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <string>
2424
#include <utility>
2525
#include <vector>
26+
#include <mutex>
27+
#include <condition_variable>
2628
#include "dali/core/call_at_exit.h"
2729
#include "dali/core/error_handling.h"
2830
#include "dali/core/multi_error.h"
@@ -53,7 +55,7 @@ class Job {
5355
std::enable_if_t<std::is_convertible_v<Runnable, std::function<void(int)>>>
5456
AddTask(Runnable &&runnable, priority_t priority = {}) {
5557
if (started_)
56-
throw std::logic_error("This has already been started - cannot add more tasks to it");
58+
throw std::logic_error("This job has already been started - cannot add more tasks to it");
5759
auto it = tasks_.emplace(priority, Task());
5860
try {
5961
it->second.func = [this, task = &it->second, f = std::move(runnable)](int tid) {
@@ -150,19 +152,21 @@ class ThreadPoolBase {
150152
Stop();
151153
}
152154

153-
void AddTask(TaskFunc f) {
154-
{
155-
std::lock_guard<std::mutex> g(mtx_);
156-
if (stop_requested_)
157-
throw std::logic_error("The thread pool is stopped and no longer accepts new tasks.");
158-
tasks_.push(std::move(f));
159-
}
160-
cv_.notify_one();
155+
void AddTask(TaskFunc f);
156+
157+
static ThreadPoolBase *this_thread_pool() {
158+
return this_thread_pool_;
159+
}
160+
161+
static int this_thread_idx() {
162+
return thread_idx_;
161163
}
162164

163165
protected:
164-
virtual void OnThreadStart(int thread_idx) {}
165-
virtual void OnThreadStop(int thread_idx) {}
166+
virtual void OnThreadStart(int thread_idx) noexcept {}
167+
virtual void OnThreadStop(int thread_idx) noexcept {}
168+
169+
friend class Job;
166170

167171
void Stop() {
168172
{
@@ -176,33 +180,66 @@ class ThreadPoolBase {
176180

177181
{
178182
std::lock_guard<std::mutex> g(mtx_);
183+
for (auto &task : tasks_) {
184+
}
179185
threads_.clear();
180186
}
181187
}
182188

183-
void Run(int index) noexcept {
184-
OnThreadStart(index);
185-
detail::CallAtExit([&]() { OnThreadStop(index); });
189+
template <typename Condition>
190+
bool WaitOrRunTasks(std::condition_variable cv, Condition &&condition) {
191+
assert(this_thread_pool() == this);
186192
std::unique_lock lock(mtx_);
187-
while (!stop_requested_) {
193+
do {
188194
cv_.wait(lock, [&]() { return stop_requested_ || !tasks_.empty(); });
189-
if (stop_requested_)
190-
break;
191-
TaskFunc t = std::move(tasks_.front());
192-
tasks_.pop();
193-
lock.unlock();
194-
t(index);
195-
lock.lock();
196195
}
196+
197197
}
198198

199+
static thread_local ThreadPoolBase *this_thread_pool_;
200+
static int thread_idx_;
201+
202+
void Run(int index) noexcept;
203+
199204
std::mutex mtx_;
200205
std::condition_variable cv_;
201206
bool stop_requested_ = false;
202207
std::queue<TaskFunc> tasks_;
203208
std::vector<std::thread> threads_;
204209
};
205210

211+
212+
thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
213+
thread_local int ThreadPoolBase::this_thread_index_ = -1;;
214+
215+
inline void ThreadPoolBase::AddTask(TaskFunc f) {
216+
{
217+
std::lock_guard<std::mutex> g(mtx_);
218+
if (stop_requested_)
219+
throw std::logic_error("The thread pool is stopped and no longer accepts new tasks.");
220+
tasks_.push(std::move(f));
221+
}
222+
cv_.notify_one();
223+
}
224+
225+
inline void ThreadPoolBase::Run(int index) noexcept {
226+
ThreadPoolBase *this_thread_pool_ = this;
227+
this_thread_idx_ = index;
228+
OnThreadStart(index);
229+
detail::CallAtExit([&]() { OnThreadStop(index); });
230+
std::unique_lock lock(mtx_);
231+
while (!stop_requested_) {
232+
cv_.wait(lock, [&]() { return stop_requested_ || !tasks_.empty(); });
233+
if (stop_requested_)
234+
break;
235+
TaskFunc t = std::move(tasks_.front());
236+
tasks_.pop();
237+
lock.unlock();
238+
t(index);
239+
lock.lock();
240+
}
241+
}
242+
206243
} // namespace experimental
207244
} // namespace dali
208245

0 commit comments

Comments
 (0)