Skip to content

Commit 9fb8a43

Browse files
authored
Merge pull request #48582 from Dr15Jones/improveConcurrencyTest
Fix race condition in SerialTaskQueue
2 parents e833ddb + 60d5546 commit 9fb8a43

File tree

6 files changed

+165
-199
lines changed

6 files changed

+165
-199
lines changed

FWCore/Concurrency/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Exactly how to _dispose_ of the class can be customized via the virtual function
1111
## `edm::WaitingTask`
1212
This class inherits from `edm::TaskBase` and represents a set of code to be run once other activities have completed. This includes the ability to hold a `std::exception_ptr` which can hold an exception which was generated in a dependent task.
1313

14-
A raw pointer to a `edm::WaitingTask` is not supposed to be handled directly. Instead, one should use the helpers `edm::WaitingTaskHolder`, `edm::WaitingTaskWithArenaHolder` or `edm::WaitingTaskList` to properly manage the internal reference count such that when the count drops to 0 the `execute()` method will be run followed by `recycle()`. In addition, these helper classes will handled passing along any `edm::exception_ptr` generated from a dependent task.
14+
A raw pointer to an `edm::WaitingTask` is not supposed to be handled directly. Instead, one should use the helpers `edm::WaitingTaskHolder`, `edm::WaitingTaskWithArenaHolder` or `edm::WaitingTaskList` to properly manage the internal reference count such that when the count drops to 0 the `execute()` method will be run followed by `recycle()`. In addition, these helper classes will handle passing along any `std::exception_ptr` generated from a dependent task.
1515

1616
The easiest way to create an `edm::WaitingTask` is to call `edm::make_waiting_task` and pass in a lambda of the form `void(std::exception_ptr const*)`.
1717
```C++
@@ -34,7 +34,7 @@ In the case where one is doing a synchronous wait on a series of asynchronous ta
3434

3535
Note that the function `wait` will rethrow any exception stored in `finalTask`. There is an alternative function named `waitNoThrow` which will return the `std::exception_ptr`.
3636

37-
WARNING: It important that the finalTask not execute before completion of the construction of all `WaitingTaskHolders` that will be constructed directly from finalTask. The following would be a bug:
37+
WARNING: It is important that the finalTask not be executed before completion of the construction of all `WaitingTaskHolders` that will be constructed directly from finalTask. The following would be a bug:
3838

3939
```C++
4040
oneapi::tbb::task_group group;

FWCore/Concurrency/interface/SerialTaskQueue.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,14 @@
6666
namespace edm {
6767
class SerialTaskQueue {
6868
public:
69-
SerialTaskQueue() : m_taskChosen(false), m_pauseCount{0} {}
69+
SerialTaskQueue() : m_pauseCount{0}, m_taskChosen{false}, m_pickingNextTask{false} {}
7070

7171
SerialTaskQueue(SerialTaskQueue&& iOther)
7272
: m_tasks(std::move(iOther.m_tasks)),
73+
m_pauseCount(iOther.m_pauseCount.exchange(0)),
7374
m_taskChosen(iOther.m_taskChosen.exchange(false)),
74-
m_pauseCount(iOther.m_pauseCount.exchange(0)) {
75-
assert(m_tasks.empty() and m_taskChosen == false);
75+
m_pickingNextTask(false) {
76+
assert(m_tasks.empty() and m_taskChosen == false and iOther.m_pickingNextTask == false);
7677
}
7778
SerialTaskQueue(const SerialTaskQueue&) = delete;
7879
const SerialTaskQueue& operator=(const SerialTaskQueue&) = delete;
@@ -159,8 +160,9 @@ namespace edm {
159160

160161
// ---------- member data --------------------------------
161162
oneapi::tbb::concurrent_queue<TaskBase*> m_tasks;
162-
std::atomic<bool> m_taskChosen;
163163
std::atomic<unsigned long> m_pauseCount;
164+
std::atomic<bool> m_taskChosen;
165+
std::atomic<bool> m_pickingNextTask;
164166
};
165167

166168
template <typename T>

FWCore/Concurrency/src/SerialTaskQueue.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "FWCore/Concurrency/interface/SerialTaskQueue.h"
1919

2020
#include "FWCore/Utilities/interface/Likely.h"
21+
#include "FWCore/Utilities/interface/make_sentry.h"
2122

2223
using namespace edm;
2324

@@ -86,25 +87,24 @@ SerialTaskQueue::TaskBase* SerialTaskQueue::finishedTask() {
8687
}
8788

8889
SerialTaskQueue::TaskBase* SerialTaskQueue::pickNextTask() {
90+
if UNLIKELY (0 != m_pauseCount)
91+
return nullptr;
8992
bool expect = false;
90-
if LIKELY (0 == m_pauseCount and m_taskChosen.compare_exchange_strong(expect, true)) {
93+
//popping a task and setting m_taskChosen need to be atomic to avoid
94+
// the case where a thread pauses just after try_pop failed but then
95+
// a task is added and that call fails the check on m_taskChosen
96+
while (not m_pickingNextTask.compare_exchange_strong(expect, true)) {
97+
expect = false;
98+
}
99+
auto sentry = edm::make_sentry(&m_pickingNextTask, [](auto* v) { v->store(false); });
100+
101+
if LIKELY (m_taskChosen.compare_exchange_strong(expect, true)) {
91102
TaskBase* t = nullptr;
92103
if LIKELY (m_tasks.try_pop(t)) {
93104
return t;
94105
}
95106
//no task was actually pulled
96107
m_taskChosen.store(false);
97-
98-
//was a new entry added after we called 'try_pop' but before we did the clear?
99-
expect = false;
100-
if (not m_tasks.empty() and m_taskChosen.compare_exchange_strong(expect, true)) {
101-
t = nullptr;
102-
if (m_tasks.try_pop(t)) {
103-
return t;
104-
}
105-
//no task was still pulled since a different thread beat us to it
106-
m_taskChosen.store(false);
107-
}
108108
}
109109
return nullptr;
110110
}

FWCore/Concurrency/test/test2_catch2_limitedtaskqueue.cc

Lines changed: 59 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,49 @@
1111
#include <memory>
1212
#include <atomic>
1313
#include <thread>
14+
#include <mutex>
1415
#include "oneapi/tbb/task_arena.h"
1516
#include "FWCore/Concurrency/interface/WaitingTask.h"
17+
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
18+
#include "FWCore/Concurrency/interface/WaitingTaskHolder.h"
1619
#include "FWCore/Concurrency/interface/LimitedTaskQueue.h"
1720
#include "FWCore/Concurrency/interface/FunctorTask.h"
1821

1922
using namespace std::chrono_literals;
2023

24+
namespace {
25+
std::mutex g_requiresMutex;
26+
27+
}
28+
//catch2 REQUIRE is not thread safe
29+
#define SAFE_REQUIRE(__var__) \
30+
{ \
31+
std::lock_guard g{g_requiresMutex}; \
32+
REQUIRE(__var__); \
33+
}
34+
2135
TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
2236
SECTION("push") {
2337
{
2438
std::atomic<unsigned int> count{0};
2539
edm::LimitedTaskQueue queue{1};
2640
{
27-
std::atomic<int> waitingTasks{3};
2841
oneapi::tbb::task_group group;
29-
queue.push(group, [&count, &waitingTasks] {
42+
edm::FinalWaitingTask lastTask(group);
43+
edm::WaitingTaskHolder waitingTask(group, &lastTask);
44+
queue.push(group, [&count, waitingTask] {
3045
REQUIRE(count++ == 0u);
3146
std::this_thread::sleep_for(10us);
32-
--waitingTasks;
3347
});
34-
queue.push(group, [&count, &waitingTasks] {
48+
queue.push(group, [&count, waitingTask] {
3549
REQUIRE(count++ == 1u);
3650
std::this_thread::sleep_for(10us);
37-
--waitingTasks;
3851
});
39-
queue.push(group, [&count, &waitingTasks] {
52+
queue.push(group, [&count, lastTask = std::move(waitingTask)] {
4053
REQUIRE(count++ == 2u);
4154
std::this_thread::sleep_for(10us);
42-
--waitingTasks;
4355
});
44-
do {
45-
group.wait();
46-
} while (0 != waitingTasks.load());
56+
lastTask.wait();
4757
REQUIRE(count == 3u);
4858
}
4959
}
@@ -52,29 +62,25 @@ TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
5262
constexpr unsigned int kMax = 2;
5363
edm::LimitedTaskQueue queue{kMax};
5464
{
55-
std::atomic<int> waitingTasks{3};
5665
oneapi::tbb::task_group group;
57-
queue.push(group, [&count, &waitingTasks, kMax] {
58-
REQUIRE(count++ < kMax);
66+
edm::FinalWaitingTask lastTask(group);
67+
edm::WaitingTaskHolder waitingTask(group, &lastTask);
68+
queue.push(group, [&count, waitingTask, kMax] {
69+
SAFE_REQUIRE(count++ < kMax);
5970
std::this_thread::sleep_for(10us);
6071
--count;
61-
--waitingTasks;
6272
});
63-
queue.push(group, [&count, &waitingTasks, kMax] {
64-
REQUIRE(count++ < kMax);
73+
queue.push(group, [&count, waitingTask, kMax] {
74+
SAFE_REQUIRE(count++ < kMax);
6575
std::this_thread::sleep_for(10us);
6676
--count;
67-
--waitingTasks;
6877
});
69-
queue.push(group, [&count, &waitingTasks, kMax] {
70-
REQUIRE(count++ < kMax);
78+
queue.push(group, [&count, lastTask = std::move(waitingTask), kMax] {
79+
SAFE_REQUIRE(count++ < kMax);
7180
std::this_thread::sleep_for(10us);
7281
--count;
73-
--waitingTasks;
7482
});
75-
do {
76-
group.wait();
77-
} while (0 != waitingTasks);
83+
lastTask.wait();
7884
REQUIRE(count == 0u);
7985
}
8086
}
@@ -85,48 +91,45 @@ TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
8591
edm::LimitedTaskQueue queue{1};
8692
{
8793
{
88-
std::atomic<int> waitingTasks{3};
8994
oneapi::tbb::task_group group;
95+
edm::FinalWaitingTask lastTask(group);
96+
edm::WaitingTaskHolder waitingTask(group, &lastTask);
97+
9098
edm::LimitedTaskQueue::Resumer resumer;
9199
std::atomic<bool> resumerSet{false};
92100
std::exception_ptr e1;
93-
queue.pushAndPause(
94-
group, [&resumer, &resumerSet, &count, &waitingTasks, &e1](edm::LimitedTaskQueue::Resumer iResumer) {
95-
resumer = std::move(iResumer);
96-
resumerSet = true;
97-
try {
98-
REQUIRE(++count == 1u);
99-
} catch (...) {
100-
e1 = std::current_exception();
101-
}
102-
--waitingTasks;
103-
});
101+
queue.pushAndPause(group,
102+
[&resumer, &resumerSet, &count, waitingTask, &e1](edm::LimitedTaskQueue::Resumer iResumer) {
103+
resumer = std::move(iResumer);
104+
resumerSet = true;
105+
try {
106+
SAFE_REQUIRE(++count == 1u);
107+
} catch (...) {
108+
e1 = std::current_exception();
109+
}
110+
});
104111
std::exception_ptr e2;
105-
queue.push(group, [&count, &waitingTasks, &e2] {
112+
queue.push(group, [&count, waitingTask, &e2] {
106113
try {
107-
REQUIRE(++count == 2u);
114+
SAFE_REQUIRE(++count == 2u);
108115
} catch (...) {
109116
e2 = std::current_exception();
110117
}
111-
--waitingTasks;
112118
});
113119
std::exception_ptr e3;
114-
queue.push(group, [&count, &waitingTasks, &e3] {
120+
queue.push(group, [&count, lastTask = std::move(waitingTask), &e3] {
115121
try {
116-
REQUIRE(++count == 3u);
122+
SAFE_REQUIRE(++count == 3u);
117123
} catch (...) {
118124
e3 = std::current_exception();
119125
}
120-
--waitingTasks;
121126
});
122127
std::this_thread::sleep_for(100us);
123128
REQUIRE(2u >= count);
124129
while (not resumerSet) {
125130
}
126-
REQUIRE(resumer.resume());
127-
do {
128-
group.wait();
129-
} while (0 != waitingTasks.load());
131+
SAFE_REQUIRE(resumer.resume());
132+
lastTask.wait();
130133
REQUIRE(count == 3u);
131134
if (e1) {
132135
std::rethrow_exception(e1);
@@ -147,54 +150,46 @@ TEST_CASE("LimitedTaskQueue", "[LimitedTaskQueue]") {
147150
edm::LimitedTaskQueue queue{kMax};
148151
unsigned int index = 100;
149152
const unsigned int nTasks = 1000;
150-
//catch2 REQUIRE is not thread safe
151-
std::mutex mutex;
152153
while (0 != --index) {
153-
std::atomic<int> waiting{1};
154+
edm::FinalWaitingTask lastTask(group);
155+
154156
std::atomic<unsigned int> count{0};
155157
std::atomic<unsigned int> nRunningTasks{0};
156158
std::atomic<bool> waitToStart{true};
157159
{
158-
group.run([&queue, &waitToStart, &group, &waiting, &count, &nRunningTasks, &mutex, kMax] {
160+
edm::WaitingTaskHolder waitingTask(group, &lastTask);
161+
162+
group.run([&queue, &waitToStart, &group, waitingTask, &count, &nRunningTasks, kMax] {
159163
while (waitToStart) {
160164
}
161165
for (unsigned int i = 0; i < nTasks; ++i) {
162-
++waiting;
163-
queue.push(group, [&count, &waiting, &nRunningTasks, &mutex, kMax] {
164-
std::shared_ptr<std::atomic<int>> guardAgain{&waiting, [](auto* v) { --(*v); }};
166+
queue.push(group, [&count, waitingTask, &nRunningTasks, kMax] {
165167
auto nrt = nRunningTasks++;
166168
if (nrt >= kMax) {
167169
std::cout << "ERROR " << nRunningTasks << " >= " << kMax << std::endl;
168-
std::lock_guard lock{mutex};
169-
REQUIRE(nrt < kMax);
170+
SAFE_REQUIRE(nrt < kMax);
170171
}
171172
++count;
172173
--nRunningTasks;
173174
});
174175
}
175176
});
176-
group.run([&queue, &waitToStart, &group, &waiting, &count, &nRunningTasks, &mutex, kMax] {
177+
group.run([&queue, &waitToStart, &group, waitingTask, &count, &nRunningTasks, kMax] {
177178
waitToStart = false;
178179
for (unsigned int i = 0; i < nTasks; ++i) {
179-
++waiting;
180-
queue.push(group, [&count, &waiting, &nRunningTasks, &mutex, kMax] {
181-
std::shared_ptr<std::atomic<int>> guardAgain{&waiting, [](auto* v) { --(*v); }};
180+
queue.push(group, [&count, waitingTask, &nRunningTasks, kMax] {
182181
auto nrt = nRunningTasks++;
183182
if (nrt >= kMax) {
184183
std::cout << "ERROR " << nRunningTasks << " >= " << kMax << std::endl;
185-
std::lock_guard lock{mutex};
186-
REQUIRE(nrt < kMax);
184+
SAFE_REQUIRE(nrt < kMax);
187185
}
188186
++count;
189187
--nRunningTasks;
190188
});
191189
}
192-
--waiting;
193190
});
194191
}
195-
do {
196-
group.wait();
197-
} while (0 != waiting.load());
192+
lastTask.wait();
198193
REQUIRE(nRunningTasks == 0u);
199194
REQUIRE(2 * nTasks == count);
200195
}

0 commit comments

Comments
 (0)