66//
77
88#include < iostream>
9+ #include < sstream>
910#include < catch.hpp>
1011#include < chrono>
1112#include < memory>
1213#include < atomic>
13- #include < thread>
1414#include < iostream>
1515#include " oneapi/tbb/task.h"
16+ #include " oneapi/tbb/global_control.h"
17+ #include " oneapi/tbb/task_arena.h"
1618#include " FWCore/Concurrency/interface/SerialTaskQueueChain.h"
1719#include " FWCore/Concurrency/interface/FinalWaitingTask.h"
1820#include " FWCore/Concurrency/interface/WaitingTaskHolder.h"
1921
2022using namespace std ::chrono_literals;
2123
22- namespace {
23- void join_thread (std::thread* iThread) {
24- if (iThread->joinable ()) {
25- iThread->join ();
26- }
27- }
28- } // namespace
29-
3024TEST_CASE (" SerialTaskQueueChain" , " [SerialTaskQueueChain]" ) {
3125 SECTION (" push" ) {
3226 std::atomic<unsigned int > count{0 };
@@ -92,34 +86,56 @@ TEST_CASE("SerialTaskQueueChain", "[SerialTaskQueueChain]") {
9286 SECTION (" stress test" ) {
9387 std::vector<std::shared_ptr<edm::SerialTaskQueue>> queues = {std::make_shared<edm::SerialTaskQueue>(),
9488 std::make_shared<edm::SerialTaskQueue>()};
95- edm::SerialTaskQueueChain chain (queues);
96- unsigned int index = 100 ;
97- const unsigned int nTasks = 1000 ;
98- while (0 != --index) {
99- oneapi::tbb::task_group group;
100- edm::FinalWaitingTask lastTask (group);
101- std::atomic<unsigned int > count{0 };
102- std::atomic<bool > waitToStart{true };
103- {
104- edm::WaitingTaskHolder lastHolder (group, &lastTask);
105- std::thread pushThread ([&chain, &waitToStart, &group, &count, lastHolder] {
106- while (waitToStart.load ()) {
107- };
108- for (unsigned int i = 0 ; i < nTasks; ++i) {
109- chain.push (group, [&count, lastHolder] { ++count; });
110- }
111- });
112- waitToStart = false ;
113- for (unsigned int i = 0 ; i < nTasks; ++i) {
114- chain.push (group, [&count, lastHolder] { ++count; });
89+ REQUIRE (2 <= oneapi::tbb::this_task_arena::max_concurrency ());
90+ oneapi::tbb::task_arena arena (2 );
91+ arena.execute ([&]() {
92+ edm::SerialTaskQueueChain chain (queues);
93+ unsigned int index = 100 ;
94+ const unsigned int nTasks = 1000 ;
95+ while (0 != --index) {
96+ oneapi::tbb::task_group group;
97+ edm::FinalWaitingTask lastTask (group);
98+ std::atomic<unsigned int > count{0 };
99+ std::atomic<unsigned int > waitToStart{2 };
100+ {
101+ edm::WaitingTaskHolder lastHolder (group, &lastTask);
102+
103+ group.run ([&chain, &waitToStart, &group, &count, lastHolder, index] {
104+ --waitToStart;
105+ while (waitToStart.load () != 0 )
106+ ;
107+ std::ostringstream ss;
108+ ss << " start task 1, index: " << index << " \n " ;
109+ std::cout << ss.str () << std::flush;
110+ for (unsigned int i = 0 ; i < nTasks; ++i) {
111+ chain.push (group, [&count, lastHolder] { ++count; });
112+ }
113+ ss.str (std::string ());
114+ ss << " stop task 1, index: " << index << " \n " ;
115+ std::cout << ss.str () << std::flush;
116+ });
117+ group.run ([&chain, &waitToStart, &group, &count, lastHolder, index] {
118+ --waitToStart;
119+ while (waitToStart.load () != 0 )
120+ ;
121+ std::ostringstream ss;
122+ ss << " start task 2, index: " << index << " \n " ;
123+ std::cout << ss.str () << std::flush;
124+ for (unsigned int i = 0 ; i < nTasks; ++i) {
125+ chain.push (group, [&count, lastHolder] { ++count; });
126+ }
127+ ss.str (std::string ());
128+ ss << " stop task 2, index: " << index << " \n " ;
129+ std::cout << ss.str () << std::flush;
130+ });
115131 }
116- lastHolder.doneWaiting (std::exception_ptr ());
117- std::shared_ptr<std::thread>(&pushThread, join_thread);
132+ std::cout << " Waiting for tasks to finish, index: " << index << " \n " << std::flush;
133+ lastTask.wait ();
134+ REQUIRE (2 * nTasks == count);
118135 }
119- lastTask.wait ();
120- REQUIRE (2 * nTasks == count);
121- }
122- while (chain.outstandingTasks () != 0 )
123- ;
136+ CHECK (0 == chain.outstandingTasks ());
137+ while (chain.outstandingTasks () != 0 )
138+ ;
139+ });
124140 }
125- }
141+ }
0 commit comments