1111#include < memory>
1212#include < atomic>
1313#include < thread>
14+ #include < mutex>
1415#include " oneapi/tbb/task_arena.h"
1516#include " FWCore/Concurrency/interface/WaitingTask.h"
17+ #include " FWCore/Concurrency/interface/FinalWaitingTask.h"
18+ #include " FWCore/Concurrency/interface/WaitingTaskHolder.h"
1619#include " FWCore/Concurrency/interface/LimitedTaskQueue.h"
1720#include " FWCore/Concurrency/interface/FunctorTask.h"
1821
1922using namespace std ::chrono_literals;
2023
24+ namespace {
25+ std::mutex g_requiresMutex;
26+
27+ }
28+ // catch2 REQUIRE is not thread safe
29+ #define SAFE_REQUIRE (__var__ ) \
30+ { \
31+ std::lock_guard g{g_requiresMutex}; \
32+ REQUIRE (__var__); \
33+ }
34+
2135TEST_CASE (" LimitedTaskQueue" , " [LimitedTaskQueue]" ) {
2236 SECTION (" push" ) {
2337 {
2438 std::atomic<unsigned int > count{0 };
2539 edm::LimitedTaskQueue queue{1 };
2640 {
27- std::atomic<int > waitingTasks{3 };
2841 oneapi::tbb::task_group group;
29- queue.push (group, [&count, &waitingTasks] {
42+ edm::FinalWaitingTask lastTask (group);
43+ edm::WaitingTaskHolder waitingTask (group, &lastTask);
44+ queue.push (group, [&count, waitingTask] {
3045 REQUIRE (count++ == 0u );
3146 std::this_thread::sleep_for (10us);
32- --waitingTasks;
3347 });
34- queue.push (group, [&count, &waitingTasks ] {
48+ queue.push (group, [&count, waitingTask ] {
3549 REQUIRE (count++ == 1u );
3650 std::this_thread::sleep_for (10us);
37- --waitingTasks;
3851 });
39- queue.push (group, [&count, &waitingTasks ] {
52+ queue.push (group, [&count, lastTask = std::move (waitingTask) ] {
4053 REQUIRE (count++ == 2u );
4154 std::this_thread::sleep_for (10us);
42- --waitingTasks;
4355 });
44- do {
45- group.wait ();
46- } while (0 != waitingTasks.load ());
56+ lastTask.wait ();
4757 REQUIRE (count == 3u );
4858 }
4959 }
@@ -52,29 +62,25 @@ TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
5262 constexpr unsigned int kMax = 2 ;
5363 edm::LimitedTaskQueue queue{kMax };
5464 {
55- std::atomic<int > waitingTasks{3 };
5665 oneapi::tbb::task_group group;
57- queue.push (group, [&count, &waitingTasks, kMax ] {
58- REQUIRE (count++ < kMax );
66+ edm::FinalWaitingTask lastTask (group);
67+ edm::WaitingTaskHolder waitingTask (group, &lastTask);
68+ queue.push (group, [&count, waitingTask, kMax ] {
69+ SAFE_REQUIRE (count++ < kMax );
5970 std::this_thread::sleep_for (10us);
6071 --count;
61- --waitingTasks;
6272 });
63- queue.push (group, [&count, &waitingTasks , kMax ] {
64- REQUIRE (count++ < kMax );
73+ queue.push (group, [&count, waitingTask , kMax ] {
74+ SAFE_REQUIRE (count++ < kMax );
6575 std::this_thread::sleep_for (10us);
6676 --count;
67- --waitingTasks;
6877 });
69- queue.push (group, [&count, &waitingTasks , kMax ] {
70- REQUIRE (count++ < kMax );
78+ queue.push (group, [&count, lastTask = std::move (waitingTask) , kMax ] {
79+ SAFE_REQUIRE (count++ < kMax );
7180 std::this_thread::sleep_for (10us);
7281 --count;
73- --waitingTasks;
7482 });
75- do {
76- group.wait ();
77- } while (0 != waitingTasks);
83+ lastTask.wait ();
7884 REQUIRE (count == 0u );
7985 }
8086 }
@@ -85,48 +91,45 @@ TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
8591 edm::LimitedTaskQueue queue{1 };
8692 {
8793 {
88- std::atomic<int > waitingTasks{3 };
8994 oneapi::tbb::task_group group;
95+ edm::FinalWaitingTask lastTask (group);
96+ edm::WaitingTaskHolder waitingTask (group, &lastTask);
97+
9098 edm::LimitedTaskQueue::Resumer resumer;
9199 std::atomic<bool > resumerSet{false };
92100 std::exception_ptr e1 ;
93- queue.pushAndPause (
94- group, [&resumer, &resumerSet, &count, &waitingTasks, &e1 ](edm::LimitedTaskQueue::Resumer iResumer) {
95- resumer = std::move (iResumer);
96- resumerSet = true ;
97- try {
98- REQUIRE (++count == 1u );
99- } catch (...) {
100- e1 = std::current_exception ();
101- }
102- --waitingTasks;
103- });
101+ queue.pushAndPause (group,
102+ [&resumer, &resumerSet, &count, waitingTask, &e1 ](edm::LimitedTaskQueue::Resumer iResumer) {
103+ resumer = std::move (iResumer);
104+ resumerSet = true ;
105+ try {
106+ SAFE_REQUIRE (++count == 1u );
107+ } catch (...) {
108+ e1 = std::current_exception ();
109+ }
110+ });
104111 std::exception_ptr e2 ;
105- queue.push (group, [&count, &waitingTasks , &e2 ] {
112+ queue.push (group, [&count, waitingTask , &e2 ] {
106113 try {
107- REQUIRE (++count == 2u );
114+ SAFE_REQUIRE (++count == 2u );
108115 } catch (...) {
109116 e2 = std::current_exception ();
110117 }
111- --waitingTasks;
112118 });
113119 std::exception_ptr e3 ;
114- queue.push (group, [&count, &waitingTasks , &e3 ] {
120+ queue.push (group, [&count, lastTask = std::move (waitingTask) , &e3 ] {
115121 try {
116- REQUIRE (++count == 3u );
122+ SAFE_REQUIRE (++count == 3u );
117123 } catch (...) {
118124 e3 = std::current_exception ();
119125 }
120- --waitingTasks;
121126 });
122127 std::this_thread::sleep_for (100us);
123128 REQUIRE (2u >= count);
124129 while (not resumerSet) {
125130 }
126- REQUIRE (resumer.resume ());
127- do {
128- group.wait ();
129- } while (0 != waitingTasks.load ());
131+ SAFE_REQUIRE (resumer.resume ());
132+ lastTask.wait ();
130133 REQUIRE (count == 3u );
131134 if (e1 ) {
132135 std::rethrow_exception (e1 );
@@ -147,54 +150,46 @@ TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
147150 edm::LimitedTaskQueue queue{kMax };
148151 unsigned int index = 100 ;
149152 const unsigned int nTasks = 1000 ;
150- // catch2 REQUIRE is not thread safe
151- std::mutex mutex;
152153 while (0 != --index) {
153- std::atomic<int > waiting{1 };
154+ edm::FinalWaitingTask lastTask (group);
155+
154156 std::atomic<unsigned int > count{0 };
155157 std::atomic<unsigned int > nRunningTasks{0 };
156158 std::atomic<bool > waitToStart{true };
157159 {
158- group.run ([&queue, &waitToStart, &group, &waiting, &count, &nRunningTasks, &mutex, kMax ] {
160+ edm::WaitingTaskHolder waitingTask (group, &lastTask);
161+
162+ group.run ([&queue, &waitToStart, &group, waitingTask, &count, &nRunningTasks, kMax ] {
159163 while (waitToStart) {
160164 }
161165 for (unsigned int i = 0 ; i < nTasks; ++i) {
162- ++waiting;
163- queue.push (group, [&count, &waiting, &nRunningTasks, &mutex, kMax ] {
164- std::shared_ptr<std::atomic<int >> guardAgain{&waiting, [](auto * v) { --(*v); }};
166+ queue.push (group, [&count, waitingTask, &nRunningTasks, kMax ] {
165167 auto nrt = nRunningTasks++;
166168 if (nrt >= kMax ) {
167169 std::cout << " ERROR " << nRunningTasks << " >= " << kMax << std::endl;
168- std::lock_guard lock{mutex};
169- REQUIRE (nrt < kMax );
170+ SAFE_REQUIRE (nrt < kMax );
170171 }
171172 ++count;
172173 --nRunningTasks;
173174 });
174175 }
175176 });
176- group.run ([&queue, &waitToStart, &group, &waiting , &count, &nRunningTasks, &mutex , kMax ] {
177+ group.run ([&queue, &waitToStart, &group, waitingTask , &count, &nRunningTasks, kMax ] {
177178 waitToStart = false ;
178179 for (unsigned int i = 0 ; i < nTasks; ++i) {
179- ++waiting;
180- queue.push (group, [&count, &waiting, &nRunningTasks, &mutex, kMax ] {
181- std::shared_ptr<std::atomic<int >> guardAgain{&waiting, [](auto * v) { --(*v); }};
180+ queue.push (group, [&count, waitingTask, &nRunningTasks, kMax ] {
182181 auto nrt = nRunningTasks++;
183182 if (nrt >= kMax ) {
184183 std::cout << " ERROR " << nRunningTasks << " >= " << kMax << std::endl;
185- std::lock_guard lock{mutex};
186- REQUIRE (nrt < kMax );
184+ SAFE_REQUIRE (nrt < kMax );
187185 }
188186 ++count;
189187 --nRunningTasks;
190188 });
191189 }
192- --waiting;
193190 });
194191 }
195- do {
196- group.wait ();
197- } while (0 != waiting.load ());
192+ lastTask.wait ();
198193 REQUIRE (nRunningTasks == 0u );
199194 REQUIRE (2 * nTasks == count);
200195 }
0 commit comments