Skip to content
This repository was archived by the owner on Dec 8, 2021. It is now read-only.

Commit 3d5fe45

Browse files
authored
fix: fix racy conditions during shutdown (#250)
* bug: fix race condition during shutdown. We were checking whether the completion queue was shut down and then scheduling work, but not atomically.
1 parent 76644c7 commit 3d5fe45

8 files changed

+72
-54
lines changed
205 Bytes
Binary file not shown.
7.14 KB
Binary file not shown.

google/cloud/completion_queue.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,8 @@ google::cloud::future<StatusOr<std::chrono::system_clock::time_point>>
9090
CompletionQueue::MakeDeadlineTimer(
9191
std::chrono::system_clock::time_point deadline) {
9292
auto op = std::make_shared<AsyncTimerFuture>(impl_->CreateAlarm());
93-
void* tag = impl_->RegisterOperation(op);
94-
if (tag != nullptr) {
95-
op->Set(impl_->cq(), deadline, tag);
96-
}
93+
impl_->StartOperation(
94+
op, [&](void* tag) { op->Set(impl_->cq(), deadline, tag); });
9795
return op->GetFuture();
9896
}
9997

google/cloud/completion_queue.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,9 @@ class CompletionQueue {
113113
std::unique_ptr<grpc::ClientContext> context) {
114114
auto op =
115115
std::make_shared<internal::AsyncUnaryRpcFuture<Request, Response>>();
116-
void* tag = impl_->RegisterOperation(op);
117-
if (tag != nullptr) {
116+
impl_->StartOperation(op, [&](void* tag) {
118117
op->Start(async_call, std::move(context), request, &impl_->cq(), tag);
119-
}
118+
});
120119
return op->GetFuture();
121120
}
122121

google/cloud/completion_queue_test.cc

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,11 +333,14 @@ TEST(CompletionQueueTest, RunAsync) {
333333
// Sets up a timer that reschedules itself and verifies we can shut down
334334
// cleanly whether we call `CancelAll()` on the queue first or not.
335335
namespace {
336-
void RunAndReschedule(CompletionQueue& cq, bool ok) {
336+
using TimerFuture = future<StatusOr<std::chrono::system_clock::time_point>>;
337+
338+
void RunAndReschedule(CompletionQueue& cq, bool ok,
339+
std::chrono::seconds duration) {
337340
if (ok) {
338-
cq.MakeRelativeTimer(std::chrono::seconds(1))
339-
.then([&cq](future<StatusOr<std::chrono::system_clock::time_point>>
340-
result) { RunAndReschedule(cq, result.get().ok()); });
341+
cq.MakeRelativeTimer(duration).then([&cq, duration](TimerFuture result) {
342+
RunAndReschedule(cq, result.get().ok(), duration);
343+
});
341344
}
342345
}
343346
} // namespace
@@ -346,17 +349,40 @@ TEST(CompletionQueueTest, ShutdownWithReschedulingTimer) {
346349
CompletionQueue cq;
347350
std::thread t([&cq] { cq.Run(); });
348351

349-
RunAndReschedule(cq, /*ok=*/true);
352+
RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(1));
350353

351354
cq.Shutdown();
352355
t.join();
353356
}
354357

358+
// Verify the queue can be shut down while several threads are running it and
// many zero-duration timers keep rescheduling themselves.
TEST(CompletionQueueTest, ShutdownWithFastReschedulingTimer) {
359+
auto constexpr kThreadCount = 32;
360+
auto constexpr kTimerCount = 100;
361+
CompletionQueue cq;
362+
std::vector<std::thread> threads(kThreadCount);
363+
// Replace each default-constructed thread with one that runs the queue.
std::generate_n(threads.begin(), threads.size(),
364+
[&cq] { return std::thread([&cq] { cq.Run(); }); });
365+
366+
// Start the timers; each one schedules a new zero-second timer for as long
// as the previous one completes with an OK result.
for (int i = 0; i != kTimerCount; ++i) {
367+
RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(0));
368+
}
369+
370+
// Block until a one-millisecond timer has fired before calling Shutdown().
promise<void> wait;
371+
cq.MakeRelativeTimer(std::chrono::milliseconds(1)).then([&wait](TimerFuture) {
372+
wait.set_value();
373+
});
374+
wait.get_future().get();
375+
cq.Shutdown();
376+
// Wait for every thread that was running the queue to exit.
for (auto& t : threads) {
377+
t.join();
378+
}
379+
}
380+
355381
TEST(CompletionQueueTest, CancelAndShutdownWithReschedulingTimer) {
356382
CompletionQueue cq;
357383
std::thread t([&cq] { cq.Run(); });
358384

359-
RunAndReschedule(cq, /*ok=*/true);
385+
RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(1));
360386

361387
cq.CancelAll();
362388
cq.Shutdown();

google/cloud/internal/async_read_stream_impl.h

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,14 @@ class AsyncReadStreamImpl
157157
context_ = std::move(context);
158158
cq_ = std::move(cq);
159159
auto callback = std::make_shared<NotifyStart>(this->shared_from_this());
160-
void* tag = cq_->RegisterOperation(std::move(callback));
161-
// @note If `tag == nullptr` the `CompletionQueue` has been `Shutdown()`.
162-
// We leave `reader_` null in this case; other methods must make the
163-
// same `tag != nullptr` check prior to accessing `reader_`. This is
164-
// safe since `Shutdown()` cannot be undone.
165-
if (tag != nullptr) {
160+
cq_->StartOperation(std::move(callback), [&](void* tag) {
161+
// @note If the `CompletionQueue` has been `Shutdown()` this lambda is
162+
// never called. We leave `reader_` null in this case; other methods
163+
// must likewise go through `StartOperation()` prior to accessing
164+
// `reader_`. This is safe since `Shutdown()` cannot be undone.
166165
reader_ = async_call(context_.get(), request, &cq_->cq());
167166
reader_->StartCall(tag);
168-
}
167+
});
169168
}
170169

171170
/// Cancel the current streaming read RPC.
@@ -202,10 +201,8 @@ class AsyncReadStreamImpl
202201

203202
auto callback = std::make_shared<NotifyRead>(this->shared_from_this());
204203
auto response = &callback->response;
205-
void* tag = cq_->RegisterOperation(std::move(callback));
206-
if (tag != nullptr) {
207-
reader_->Read(response, tag);
208-
}
204+
cq_->StartOperation(std::move(callback),
205+
[&](void* tag) { reader_->Read(response, tag); });
209206
}
210207

211208
/// Handle the result of a `Read()` call.
@@ -252,10 +249,8 @@ class AsyncReadStreamImpl
252249

253250
auto callback = std::make_shared<NotifyFinish>(this->shared_from_this());
254251
auto status = &callback->status;
255-
void* tag = cq_->RegisterOperation(std::move(callback));
256-
if (tag != nullptr) {
257-
reader_->Finish(status, tag);
258-
}
252+
cq_->StartOperation(std::move(callback),
253+
[&](void* tag) { reader_->Finish(status, tag); });
259254
}
260255

261256
/// Handle the result of a Finish() request.
@@ -292,10 +287,8 @@ class AsyncReadStreamImpl
292287

293288
auto callback = std::make_shared<NotifyDiscard>(this->shared_from_this());
294289
auto response = &callback->response;
295-
void* tag = cq_->RegisterOperation(std::move(callback));
296-
if (tag != nullptr) {
297-
reader_->Read(response, tag);
298-
}
290+
cq_->StartOperation(std::move(callback),
291+
[&](void* tag) { reader_->Read(response, tag); });
299292
}
300293

301294
/// Handle the result of a Discard() call.

google/cloud/internal/completion_queue_impl.cc

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,6 @@ std::unique_ptr<grpc::Alarm> CompletionQueueImpl::CreateAlarm() const {
7373
return google::cloud::internal::make_unique<grpc::Alarm>();
7474
}
7575

76-
void* CompletionQueueImpl::RegisterOperation(
77-
std::shared_ptr<AsyncGrpcOperation> op) {
78-
void* tag = op.get();
79-
std::unique_lock<std::mutex> lk(mu_);
80-
if (shutdown_) {
81-
lk.unlock();
82-
op->Notify(/*ok=*/false);
83-
return nullptr;
84-
}
85-
auto ins =
86-
pending_ops_.emplace(reinterpret_cast<std::intptr_t>(tag), std::move(op));
87-
// After this point we no longer need the lock, so release it.
88-
lk.unlock();
89-
if (ins.second) {
90-
return tag;
91-
}
92-
google::cloud::internal::ThrowRuntimeError(
93-
"assertion failure: insertion should succeed");
94-
}
95-
9676
std::shared_ptr<AsyncGrpcOperation> CompletionQueueImpl::FindOperation(
9777
void* tag) {
9878
std::lock_guard<std::mutex> lk(mu_);

google/cloud/internal/completion_queue_impl.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,30 @@ class CompletionQueueImpl {
247247
/// The underlying gRPC completion queue.
248248
grpc::CompletionQueue& cq() { return cq_; }
249249

250-
/// Add a new asynchronous operation to the completion queue.
251-
void* RegisterOperation(std::shared_ptr<AsyncGrpcOperation> op);
250+
/// Atomically add a new operation to the completion queue and start it.
///
/// The shutdown check, the insertion into `pending_ops_`, and the call to
/// @p start are all performed while holding `mu_`.
///
/// @param op the operation to register; its address is used as the tag.
/// @param start a callable invoked as `start(tag)` to begin the operation. It
///     is not invoked when `shutdown_` is set; in that case `op->Notify()` is
///     called with `ok == false` instead.
///
/// If the tag is already present in `pending_ops_` this function calls
/// `ThrowRuntimeError()` with an assertion-failure message.
251+
template <typename Callable,
252+
typename std::enable_if<
253+
google::cloud::internal::is_invocable<Callable, void*>::value,
254+
int>::type = 0>
255+
void StartOperation(std::shared_ptr<AsyncGrpcOperation> op,
256+
Callable&& start) {
257+
// Capture the tag before `op` is moved into `pending_ops_`.
void* tag = op.get();
258+
std::unique_lock<std::mutex> lk(mu_);
259+
if (shutdown_) {
260+
// Release the lock before notifying the operation that it did not start.
lk.unlock();
261+
op->Notify(/*ok=*/false);
262+
return;
263+
}
264+
auto ins = pending_ops_.emplace(reinterpret_cast<std::intptr_t>(tag),
265+
std::move(op));
266+
if (ins.second) {
267+
// Start the operation while still holding the lock, then release it.
start(tag);
268+
lk.unlock();
269+
return;
270+
}
271+
// The insertion failed: an operation with this tag was already registered.
google::cloud::internal::ThrowRuntimeError(
272+
"assertion failure: insertion should succeed");
273+
}
252274

253275
protected:
254276
/// Return the asynchronous operation associated with @p tag.

0 commit comments

Comments
 (0)