Skip to content

Commit c5cb79c

Browse files
committed
Add incremental job. Validate in nvimgcodec.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
1 parent ced7133 commit c5cb79c

File tree

6 files changed

+481
-53
lines changed

6 files changed

+481
-53
lines changed

dali/core/exec/thread_pool_base.cc

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818
namespace dali {
1919

2020
Job::~Job() noexcept(false) {
21-
if (!tasks_.empty() && !waited_for_) {
22-
std::lock_guard<std::mutex> g(mtx_);
23-
if (!tasks_.empty() && !waited_for_) {
24-
throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for.");
25-
}
21+
if (!tasks_.empty() && !waited_for_) {
22+
throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for.");
2623
}
2724
}
2825

@@ -33,9 +30,8 @@ void Job::Wait() {
3330
if (waited_for_)
3431
throw std::logic_error("This job has already been waited for.");
3532

36-
auto ready = [&]() { return num_pending_tasks_ == 0; };
37-
3833
if (ThreadPoolBase::this_thread_pool() != nullptr) {
34+
auto ready = [&]() { return num_pending_tasks_ == 0; };
3935
bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
4036
waited_for_ = true;
4137
if (!result)
@@ -87,26 +83,90 @@ void Job::Scrap() {
8783

8884
///////////////////////////////////////////////////////////////////////////
8985

86+
IncrementalJob::~IncrementalJob() noexcept(false) {
87+
if (!tasks_.empty() && !waited_for_) {
88+
throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for.");
89+
}
90+
}
91+
92+
void IncrementalJob::Wait() {
93+
if (!executor_)
94+
throw std::logic_error("This job hasn't been run - cannot wait for it.");
95+
96+
if (waited_for_)
97+
throw std::logic_error("This job has already been waited for.");
98+
99+
if (ThreadPoolBase::this_thread_pool() != nullptr) {
100+
auto ready = [&]() { return num_pending_tasks_ == 0; };
101+
bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
102+
waited_for_ = true;
103+
if (!result)
104+
throw std::logic_error("The thread pool was stopped");
105+
} else {
106+
int old = num_pending_tasks_.load(std::memory_order_acquire);
107+
while (old > 0) {
108+
num_pending_tasks_.wait(old, std::memory_order_acquire);
109+
old = num_pending_tasks_.load(std::memory_order_acquire);
110+
}
111+
waited_for_ = true;
112+
}
113+
114+
// note - this vector is not allocated unless there were exceptions thrown
115+
std::vector<std::exception_ptr> errors;
116+
for (auto &x : tasks_) {
117+
if (x.error)
118+
errors.push_back(std::move(x.error));
119+
}
120+
if (errors.size() == 1)
121+
std::rethrow_exception(errors[0]);
122+
else if (errors.size() > 1)
123+
throw MultipleErrors(std::move(errors));
124+
}
125+
126+
void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
127+
if (executor_ && executor_ != &tp)
128+
throw std::logic_error("This job is already running in a different executor.");
129+
executor_ = &tp;
130+
{
131+
auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
132+
auto batch = tp.BulkAdd();
133+
for (; it != tasks_.end(); ++it) {
134+
batch.Add(std::move(it->func));
135+
last_task_run_ = it;
136+
}
137+
}
138+
if (wait && !tasks_.empty())
139+
Wait();
140+
}
141+
142+
void IncrementalJob::Scrap() {
143+
if (executor_)
144+
throw std::logic_error("Cannot scrap a job that has already been started");
145+
tasks_.clear();
146+
}
147+
148+
///////////////////////////////////////////////////////////////////////////
149+
90150
thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
91151
thread_local int ThreadPoolBase::this_thread_idx_ = -1;;
92152

93-
void ThreadPoolBase::Init(int num_threads) {
153+
void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
94154
if (shutdown_pending_)
95155
throw std::logic_error("The thread pool is being shut down.");
96156
std::lock_guard<std::mutex> g(mtx_);
97157
if (!threads_.empty())
98158
throw std::logic_error("The thread pool is already started!");
99159
threads_.reserve(num_threads);
100160
for (int i = 0; i < num_threads; i++)
101-
threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i));
161+
threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i, on_thread_start));
102162
}
103163

104-
void ThreadPoolBase::Shutdown() {
105-
if (shutdown_pending_)
164+
void ThreadPoolBase::Shutdown(bool join) {
165+
if (shutdown_pending_ && !join)
106166
return;
107167
{
108168
std::lock_guard<std::mutex> g(mtx_);
109-
if (shutdown_pending_)
169+
if (shutdown_pending_ && !join)
110170
return;
111171
shutdown_pending_ = true;
112172
cv_.notify_all();
@@ -132,11 +192,14 @@ void ThreadPoolBase::AddTask(TaskFunc &&f) {
132192
cv_.notify_one();
133193
}
134194

135-
void ThreadPoolBase::Run(int index) noexcept {
195+
void ThreadPoolBase::Run(
196+
int index,
197+
const std::function<OnThreadStartFn> &on_thread_start) noexcept {
136198
this_thread_pool_ = this;
137199
this_thread_idx_ = index;
138-
OnThreadStart(index);
139-
detail::CallAtExit([&]() { OnThreadStop(index); });
200+
std::any scope;
201+
if (on_thread_start)
202+
scope = on_thread_start(index);
140203
std::unique_lock lock(mtx_);
141204
while (!shutdown_pending_ || !tasks_.empty()) {
142205
cv_.wait(lock, [&]() { return shutdown_pending_ || !tasks_.empty(); });

dali/core/exec/thread_pool_base_test.cc

Lines changed: 136 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
// limitations under the License.
1414

1515
#include <gtest/gtest.h>
16+
#include <iostream>
1617
#include "dali/core/exec/thread_pool_base.h"
1718
#include "dali/core/format.h"
19+
#include "dali/test/timing.h"
1820

1921
namespace dali {
2022

@@ -34,7 +36,15 @@ TEST(NewThreadPool, Scrap) {
3436
});
3537
}
3638

37-
TEST(NewThreadPool, ErrorNotStarted) {
39+
TEST(NewThreadPool, IncrementalJobScrap) {
40+
EXPECT_NO_THROW({
41+
IncrementalJob job;
42+
job.AddTask([]() {});
43+
job.Scrap();
44+
});
45+
}
46+
47+
TEST(NewThreadPool, ErrorJobNotStarted) {
3848
try {
3949
Job job;
4050
job.AddTask([]() {});
@@ -45,6 +55,16 @@ TEST(NewThreadPool, ErrorNotStarted) {
4555
GTEST_FAIL() << "Expected a logic error.";
4656
}
4757

58+
TEST(NewThreadPool, ErrorIncrementalJobNotStarted) {
59+
try {
60+
IncrementalJob job;
61+
job.AddTask([]() {});
62+
} catch (std::logic_error &e) {
63+
EXPECT_NE(nullptr, strstr(e.what(), "The job is not empty"));
64+
return;
65+
}
66+
GTEST_FAIL() << "Expected a logic error.";
67+
}
4868

4969
TEST(NewThreadPool, RunJobInSeries) {
5070
Job job;
@@ -84,9 +104,78 @@ TEST(NewThreadPool, RunJobInThreadPool) {
84104
EXPECT_EQ(c, 3);
85105
}
86106

107+
TEST(NewThreadPool, RunIncrementalJobInThreadPool) {
108+
ThreadPoolBase tp(4);
109+
IncrementalJob job;
110+
std::atomic_int a = 0, b = 0, c = 0;
111+
job.AddTask([&]() {
112+
a += 1;
113+
});
114+
job.AddTask([&]() {
115+
b += 2;
116+
});
117+
job.Run(tp, false);
87118

88-
TEST(NewThreadPool, RethrowMultipleErrors) {
89-
Job job;
119+
for (int i = 0; (a.load() != 1 || b.load() != 2) && i < 100000; i++)
120+
std::this_thread::sleep_for(std::chrono::microseconds(10));
121+
ASSERT_TRUE(a.load() == 1 && b.load() == 2) << "The job didn't start.";
122+
123+
job.AddTask([&]() {
124+
c += 3;
125+
});
126+
job.Run(tp, true);
127+
EXPECT_EQ(a.load(), 1);
128+
EXPECT_EQ(b.load(), 2);
129+
EXPECT_EQ(c.load(), 3);
130+
}
131+
132+
133+
TEST(NewThreadPool, RunLargeIncrementalJobInThreadPool) {
134+
ThreadPoolBase tp(4);
135+
const int max_attempts = 10;
136+
for (int attempt = 0; attempt < max_attempts; attempt++) {
137+
IncrementalJob job;
138+
std::atomic_int acc = 0;
139+
const int total_tasks = 40000;
140+
const int batch_size = 100;
141+
for (int i = 0; i < total_tasks; i += batch_size) {
142+
for (int j = i; j < i + batch_size; j++) {
143+
job.AddTask([&, j] {
144+
acc += j;
145+
});
146+
}
147+
job.Run(tp, false);
148+
if (i == 0) {
149+
for (int spin = 0; acc.load() == 0 && spin < 100000; spin++)
150+
std::this_thread::sleep_for(std::chrono::microseconds(10));
151+
ASSERT_NE(acc.load(), 0) << "The job isn't running in the background.";
152+
}
153+
}
154+
int target_value = total_tasks * (total_tasks - 1) / 2;
155+
if (acc.load() == target_value) {
156+
if (attempt == max_attempts - 1) {
157+
FAIL() << "The job always finishes before a call to wait.";
158+
} else {
159+
std::cerr << "The job shouldn't have completed yet - retrying.\n";
160+
}
161+
job.Wait();
162+
continue;
163+
}
164+
job.Run(tp, true);
165+
EXPECT_EQ(acc.load(), target_value);
166+
break;
167+
}
168+
}
169+
170+
template <typename JobType>
171+
class NewThreadPoolJobTest : public ::testing::Test {};
172+
173+
using JobTypes = ::testing::Types<Job, IncrementalJob>;
174+
TYPED_TEST_SUITE(NewThreadPoolJobTest, JobTypes);
175+
176+
177+
TYPED_TEST(NewThreadPoolJobTest, RethrowMultipleErrors) {
178+
TypeParam job;
90179
ThreadPoolBase tp(4);
91180
job.AddTask([&]() {
92181
throw std::runtime_error("Runtime");
@@ -110,8 +199,8 @@ void SyncPrint(Args&& ...args) {
110199
printf("%s", str.c_str());
111200
}
112201

113-
TEST(NewThreadPool, Reentrant) {
114-
Job job;
202+
TYPED_TEST(NewThreadPoolJobTest, Reentrant) {
203+
TypeParam job;
115204
ThreadPoolBase tp(1); // must not hang with just one thread
116205
std::atomic_int outer{0}, inner{0};
117206
for (int i = 0; i < 10; i++) {
@@ -141,4 +230,46 @@ TEST(NewThreadPool, Reentrant) {
141230
job.Run(tp, true);
142231
}
143232

233+
TYPED_TEST(NewThreadPoolJobTest, JobPerf) {
234+
using JobType = TypeParam;
235+
ThreadPoolBase tp(4);
236+
auto do_test = [&](int jobs, int tasks) {
237+
std::vector<int> v(tasks);
238+
auto start = test::perf_timer::now();
239+
for (int i = 0; i < jobs; i++) {
240+
JobType j;
241+
for (int t = 0; t < tasks; t++) {
242+
j.AddTask([&, t]() {
243+
v[t]++;
244+
});
245+
}
246+
j.Run(tp, true);
247+
}
248+
auto end = test::perf_timer::now();
249+
250+
for (int t = 0; t < tasks; t++)
251+
EXPECT_EQ(v[t], jobs) << "Tasks didn't do their job";
252+
print(
253+
std::cout, "Ran ", jobs, " jobs of ", tasks, " tasks each in ",
254+
test::format_time(end - start), "\n");
255+
256+
return end - start;
257+
};
258+
259+
int total_tasks = 100000;
260+
int jobs0 = 10000, tasks0 = total_tasks / jobs0;
261+
auto time0 = do_test(jobs0, tasks0);
262+
int jobs1 = 100, tasks1 = total_tasks / jobs1;
263+
auto time1 = do_test(jobs1, tasks1);
264+
265+
// time0 = task_time * total_tasks + job_overhead * jobs0
266+
// time1 = task_time * total_tasks + job_overhead * jobs1
267+
// hence
268+
// time0 - time1 = job_overhead * (jobs0 - jobs1)
269+
// job_overhead = (time0 - time1) / (jobs0 - jobs1)
270+
271+
double job_overhead = test::seconds(time0 - time1) / (jobs0 - jobs1);
272+
print(std::cout, "Job overhead ", test::format_time(job_overhead), "\n");
273+
}
274+
144275
} // namespace dali

0 commit comments

Comments
 (0)