1313// limitations under the License.
1414
1515#include < gtest/gtest.h>
16+ #include < iostream>
1617#include " dali/core/exec/thread_pool_base.h"
1718#include " dali/core/format.h"
19+ #include " dali/test/timing.h"
1820
1921namespace dali {
2022
@@ -34,7 +36,15 @@ TEST(NewThreadPool, Scrap) {
3436 });
3537}
3638
37- TEST (NewThreadPool, ErrorNotStarted) {
39+ TEST (NewThreadPool, IncrementalJobScrap) {
40+ EXPECT_NO_THROW ({
41+ IncrementalJob job;
42+ job.AddTask ([]() {});
43+ job.Scrap ();
44+ });
45+ }
46+
47+ TEST (NewThreadPool, ErrorJobNotStarted) {
3848 try {
3949 Job job;
4050 job.AddTask ([]() {});
@@ -45,6 +55,16 @@ TEST(NewThreadPool, ErrorNotStarted) {
4555 GTEST_FAIL () << " Expected a logic error." ;
4656}
4757
58+ TEST (NewThreadPool, ErrorIncrementalJobNotStarted) {
59+ try {
60+ IncrementalJob job;
61+ job.AddTask ([]() {});
62+ } catch (std::logic_error &e) {
63+ EXPECT_NE (nullptr , strstr (e.what (), " The job is not empty" ));
64+ return ;
65+ }
66+ GTEST_FAIL () << " Expected a logic error." ;
67+ }
4868
4969TEST (NewThreadPool, RunJobInSeries) {
5070 Job job;
@@ -84,9 +104,78 @@ TEST(NewThreadPool, RunJobInThreadPool) {
84104 EXPECT_EQ (c, 3 );
85105}
86106
107+ TEST (NewThreadPool, RunIncrementalJobInThreadPool) {
108+ ThreadPoolBase tp (4 );
109+ IncrementalJob job;
110+ std::atomic_int a = 0 , b = 0 , c = 0 ;
111+ job.AddTask ([&]() {
112+ a += 1 ;
113+ });
114+ job.AddTask ([&]() {
115+ b += 2 ;
116+ });
117+ job.Run (tp, false );
87118
88- TEST (NewThreadPool, RethrowMultipleErrors) {
89- Job job;
119+ for (int i = 0 ; (a.load () != 1 || b.load () != 2 ) && i < 100000 ; i++)
120+ std::this_thread::sleep_for (std::chrono::microseconds (10 ));
121+ ASSERT_TRUE (a.load () == 1 && b.load () == 2 ) << " The job didn't start." ;
122+
123+ job.AddTask ([&]() {
124+ c += 3 ;
125+ });
126+ job.Run (tp, true );
127+ EXPECT_EQ (a.load (), 1 );
128+ EXPECT_EQ (b.load (), 2 );
129+ EXPECT_EQ (c.load (), 3 );
130+ }
131+
132+
133+ TEST (NewThreadPool, RunLargeIncrementalJobInThreadPool) {
134+ ThreadPoolBase tp (4 );
135+ const int max_attempts = 10 ;
136+ for (int attempt = 0 ; attempt < max_attempts; attempt++) {
137+ IncrementalJob job;
138+ std::atomic_int acc = 0 ;
139+ const int total_tasks = 40000 ;
140+ const int batch_size = 100 ;
141+ for (int i = 0 ; i < total_tasks; i += batch_size) {
142+ for (int j = i; j < i + batch_size; j++) {
143+ job.AddTask ([&, j] {
144+ acc += j;
145+ });
146+ }
147+ job.Run (tp, false );
148+ if (i == 0 ) {
149+ for (int spin = 0 ; acc.load () == 0 && spin < 100000 ; spin++)
150+ std::this_thread::sleep_for (std::chrono::microseconds (10 ));
151+ ASSERT_NE (acc.load (), 0 ) << " The job isn't running in the background." ;
152+ }
153+ }
154+ int target_value = total_tasks * (total_tasks - 1 ) / 2 ;
155+ if (acc.load () == target_value) {
156+ if (attempt == max_attempts - 1 ) {
157+ FAIL () << " The job always finishes before a call to wait." ;
158+ } else {
159+ std::cerr << " The job shouldn't have completed yet - retrying.\n " ;
160+ }
161+ job.Wait ();
162+ continue ;
163+ }
164+ job.Run (tp, true );
165+ EXPECT_EQ (acc.load (), target_value);
166+ break ;
167+ }
168+ }
169+
170+ template <typename JobType>
171+ class NewThreadPoolJobTest : public ::testing::Test {};
172+
173+ using JobTypes = ::testing::Types<Job, IncrementalJob>;
174+ TYPED_TEST_SUITE (NewThreadPoolJobTest, JobTypes);
175+
176+
177+ TYPED_TEST (NewThreadPoolJobTest, RethrowMultipleErrors) {
178+ TypeParam job;
90179 ThreadPoolBase tp (4 );
91180 job.AddTask ([&]() {
92181 throw std::runtime_error (" Runtime" );
@@ -110,8 +199,8 @@ void SyncPrint(Args&& ...args) {
110199 printf (" %s" , str.c_str ());
111200}
112201
113- TEST (NewThreadPool , Reentrant) {
114- Job job;
202+ TYPED_TEST (NewThreadPoolJobTest , Reentrant) {
203+ TypeParam job;
115204 ThreadPoolBase tp (1 ); // must not hang with just one thread
116205 std::atomic_int outer{0 }, inner{0 };
117206 for (int i = 0 ; i < 10 ; i++) {
@@ -141,4 +230,46 @@ TEST(NewThreadPool, Reentrant) {
141230 job.Run (tp, true );
142231}
143232
233+ TYPED_TEST (NewThreadPoolJobTest, JobPerf) {
234+ using JobType = TypeParam;
235+ ThreadPoolBase tp (4 );
236+ auto do_test = [&](int jobs, int tasks) {
237+ std::vector<int > v (tasks);
238+ auto start = test::perf_timer::now ();
239+ for (int i = 0 ; i < jobs; i++) {
240+ JobType j;
241+ for (int t = 0 ; t < tasks; t++) {
242+ j.AddTask ([&, t]() {
243+ v[t]++;
244+ });
245+ }
246+ j.Run (tp, true );
247+ }
248+ auto end = test::perf_timer::now ();
249+
250+ for (int t = 0 ; t < tasks; t++)
251+ EXPECT_EQ (v[t], jobs) << " Tasks didn't do their job" ;
252+ print (
253+ std::cout, " Ran " , jobs, " jobs of " , tasks, " tasks each in " ,
254+ test::format_time (end - start), " \n " );
255+
256+ return end - start;
257+ };
258+
259+ int total_tasks = 100000 ;
260+ int jobs0 = 10000 , tasks0 = total_tasks / jobs0;
261+ auto time0 = do_test (jobs0, tasks0);
262+ int jobs1 = 100 , tasks1 = total_tasks / jobs1;
263+ auto time1 = do_test (jobs1, tasks1);
264+
265+ // time0 = task_time * total_tasks + job_overhead * jobs0
266+ // time1 = task_time * total_tasks + job_overhead * jobs1
267+ // hence
268+ // time0 - time1 = job_overhead * (jobs0 - jobs1)
269+ // job_overhead = (time0 - time1) / (jobs0 - jobs1)
270+
271+ double job_overhead = test::seconds (time0 - time1) / (jobs0 - jobs1);
272+ print (std::cout, " Job overhead " , test::format_time (job_overhead), " \n " );
273+ }
274+
144275} // namespace dali
0 commit comments