Skip to content

Commit 754841e

Browse files
authored
grpc: reestablish connection and start stream (#31)
1 parent d628004 commit 754841e

File tree

7 files changed

+121
-23
lines changed

7 files changed

+121
-23
lines changed

cpp2sky/internal/async_client.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <grpcpp/grpcpp.h>
1919

2020
#include <memory>
21+
#include <queue>
2122

2223
using google::protobuf::Message;
2324

@@ -59,6 +60,27 @@ class AsyncClient {
5960
* Peer address of current gRPC client..
6061
*/
6162
virtual std::string peerAddress() = 0;
63+
64+
/**
65+
* Drain pending messages
66+
*/
67+
virtual void drainPendingMessages(
68+
std::queue<RequestType>& pending_messages) = 0;
69+
70+
/**
71+
* Reset stream if it is living.
72+
*/
73+
virtual void resetStream() = 0;
74+
75+
/**
76+
* Start stream if there is no living stream.
77+
*/
78+
virtual void startStream() = 0;
79+
80+
/**
81+
* The number of drained pending messages.
82+
*/
83+
virtual size_t numOfMessages() = 0;
6284
};
6385

6486
enum class Operation : uint8_t {
@@ -105,7 +127,8 @@ class AsyncStreamFactory {
105127
* Create async stream entity
106128
*/
107129
virtual AsyncStreamPtr<RequestType> create(
108-
AsyncClient<RequestType, ResponseType>* client) = 0;
130+
AsyncClient<RequestType, ResponseType>* client,
131+
std::queue<RequestType>& drained_messages) = 0;
109132
};
110133

111134
template <class RequestType, class ResponseType>

source/grpc_async_client_impl.cc

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,23 @@ GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient(
4343
AsyncStreamFactory<TracerRequestType, TracerResponseType>& factory,
4444
std::shared_ptr<grpc::ChannelCredentials> cred, std::string address,
4545
std::string token)
46-
: token_(token), address_(address), factory_(factory), cq_(cq) {
47-
stub_ = std::make_unique<TracerStubImpl>(grpc::CreateChannel(address, cred));
48-
stream_ = factory_.create(this);
49-
stream_->startStream();
46+
: token_(token),
47+
factory_(factory),
48+
cq_(cq),
49+
channel_(grpc::CreateChannel(address, cred)) {
50+
stub_ = std::make_unique<TracerStubImpl>(channel_);
51+
startStream();
5052
}
5153

5254
void GrpcAsyncSegmentReporterClient::sendMessage(TracerRequestType message) {
53-
GPR_ASSERT(stream_ != nullptr);
55+
if (!stream_) {
56+
drained_messages_.emplace(message);
57+
gpr_log(GPR_INFO,
58+
"No active stream, inserted message into draining message queue. "
59+
"pending message size: %ld",
60+
drained_messages_.size());
61+
return;
62+
}
5463
stream_->sendMessage(message);
5564
}
5665

@@ -64,16 +73,45 @@ GrpcAsyncSegmentReporterClient::createWriter(grpc::ClientContext* ctx,
6473
return stub_->createWriter(ctx, response, cq_, tag);
6574
}
6675

76+
void GrpcAsyncSegmentReporterClient::startStream() {
77+
resetStream();
78+
79+
// Try to establish connection.
80+
channel_->GetState(true);
81+
stream_ = factory_.create(this, drained_messages_);
82+
stream_->startStream();
83+
}
84+
85+
void GrpcAsyncSegmentReporterClient::drainPendingMessages(
86+
std::queue<TracerRequestType>& pending_messages) {
87+
const auto pending_messages_size = pending_messages.size();
88+
while (!pending_messages.empty()) {
89+
auto msg = pending_messages.front();
90+
pending_messages.pop();
91+
drained_messages_.emplace(msg);
92+
}
93+
gpr_log(GPR_INFO, "%ld pending messages drained.", pending_messages_size);
94+
}
95+
6796
GrpcAsyncSegmentReporterStream::GrpcAsyncSegmentReporterStream(
68-
AsyncClient<TracerRequestType, TracerResponseType>* client)
69-
: client_(client) {}
97+
AsyncClient<TracerRequestType, TracerResponseType>* client,
98+
std::queue<TracerRequestType>& drained_messages)
99+
: client_(client) {
100+
const auto drained_messages_size = drained_messages.size();
101+
while (!drained_messages.empty()) {
102+
auto msg = drained_messages.front();
103+
pending_messages_.emplace(msg);
104+
drained_messages.pop();
105+
}
106+
gpr_log(GPR_INFO, "%ld drained messages inserted into pending messages.",
107+
drained_messages_size);
108+
}
70109

71110
GrpcAsyncSegmentReporterStream::~GrpcAsyncSegmentReporterStream() {
72111
{
73112
std::unique_lock<std::mutex> lck_(mux_);
74113
cond_.wait(lck_, [this] { return pending_messages_.empty(); });
75114
}
76-
77115
ctx_.TryCancel();
78116
request_writer_->Finish(&status_, toTag(&finish_));
79117
}
@@ -132,6 +170,7 @@ bool GrpcAsyncSegmentReporterStream::handleOperation(Operation incoming_op) {
132170
} else if (state_ == Operation::Finished) {
133171
gpr_log(GPR_INFO, "Stream closed with http status: %d",
134172
grpcStatusToGenericHttpStatus(status_.error_code()));
173+
client_->drainPendingMessages(pending_messages_);
135174
if (!status_.ok()) {
136175
gpr_log(GPR_ERROR, "%s", status_.error_message().c_str());
137176
}
@@ -141,11 +180,13 @@ bool GrpcAsyncSegmentReporterStream::handleOperation(Operation incoming_op) {
141180
}
142181

143182
AsyncStreamPtr<TracerRequestType> GrpcAsyncSegmentReporterStreamFactory::create(
144-
AsyncClient<TracerRequestType, TracerResponseType>* client) {
183+
AsyncClient<TracerRequestType, TracerResponseType>* client,
184+
std::queue<TracerRequestType>& drained_messages) {
145185
if (client == nullptr) {
146186
return nullptr;
147187
}
148-
return std::make_shared<GrpcAsyncSegmentReporterStream>(client);
188+
return std::make_shared<GrpcAsyncSegmentReporterStream>(client,
189+
drained_messages);
149190
}
150191

151192
} // namespace cpp2sky

source/grpc_async_client_impl.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class TracerStubImpl final
4444
std::unique_ptr<TraceSegmentReportService::Stub> stub_;
4545
};
4646

47+
class GrpcAsyncSegmentReporterStream;
48+
4749
class GrpcAsyncSegmentReporterClient final
4850
: public AsyncClient<TracerRequestType, TracerResponseType> {
4951
public:
@@ -59,18 +61,28 @@ class GrpcAsyncSegmentReporterClient final
5961
std::unique_ptr<grpc::ClientAsyncWriter<TracerRequestType>> createWriter(
6062
grpc::ClientContext* ctx, TracerResponseType* response,
6163
void* tag) override;
64+
void drainPendingMessages(
65+
std::queue<TracerRequestType>& pending_messages) override;
66+
void resetStream() override {
67+
if (stream_) {
68+
stream_.reset();
69+
stream_ = nullptr;
70+
}
71+
}
72+
void startStream() override;
73+
size_t numOfMessages() override { return drained_messages_.size(); }
6274

6375
private:
6476
std::string token_;
6577
std::string address_;
6678
AsyncStreamFactory<TracerRequestType, TracerResponseType>& factory_;
6779
TracerStubPtr<TracerRequestType, TracerResponseType> stub_;
6880
grpc::CompletionQueue* cq_;
81+
std::shared_ptr<grpc::Channel> channel_;
6982
AsyncStreamPtr<TracerRequestType> stream_;
83+
std::queue<TracerRequestType> drained_messages_;
7084
};
7185

72-
class GrpcAsyncSegmentReporterStream;
73-
7486
struct TaggedStream {
7587
Operation operation;
7688
GrpcAsyncSegmentReporterStream* stream;
@@ -83,7 +95,8 @@ class GrpcAsyncSegmentReporterStream final
8395
: public AsyncStream<TracerRequestType> {
8496
public:
8597
GrpcAsyncSegmentReporterStream(
86-
AsyncClient<TracerRequestType, TracerResponseType>* client);
98+
AsyncClient<TracerRequestType, TracerResponseType>* client,
99+
std::queue<TracerRequestType>& drained_messages);
87100
~GrpcAsyncSegmentReporterStream() override;
88101

89102
// AsyncStream
@@ -115,7 +128,8 @@ class GrpcAsyncSegmentReporterStreamFactory final
115128
public:
116129
// AsyncStreamFactory
117130
AsyncStreamPtr<TracerRequestType> create(
118-
AsyncClient<TracerRequestType, TracerResponseType>* client) override;
131+
AsyncClient<TracerRequestType, TracerResponseType>* client,
132+
std::queue<TracerRequestType>& drained_messages) override;
119133
};
120134

121135
} // namespace cpp2sky

source/tracer_impl.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ TracerImpl::TracerImpl(TracerConfig& config,
2222
std::shared_ptr<grpc::ChannelCredentials> cred,
2323
GrpcAsyncSegmentReporterStreamFactory& factory)
2424
: th_([this] { this->run(); }) {
25-
client_ = new GrpcAsyncSegmentReporterClient(
25+
client_ = std::make_unique<GrpcAsyncSegmentReporterClient>(
2626
&cq_, factory, cred, config.address(), config.token());
2727
}
2828

2929
TracerImpl::~TracerImpl() {
30-
delete client_;
30+
client_.reset();
3131
th_.join();
3232
cq_.Shutdown();
3333
}
@@ -50,10 +50,11 @@ void TracerImpl::run() {
5050
}
5151
TaggedStream* t_stream = deTag(got_tag);
5252
if (!ok) {
53+
client_->resetStream();
5354
continue;
5455
}
5556
if (!t_stream->stream->handleOperation(t_stream->operation)) {
56-
return;
57+
client_->startStream();
5758
}
5859
}
5960
}

source/tracer_impl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
namespace cpp2sky {
2323

24+
using TracerRequestType = SegmentObject;
25+
using TracerResponseType = Commands;
26+
2427
class TracerImpl : public Tracer {
2528
public:
2629
TracerImpl(TracerConfig& config,
@@ -33,7 +36,7 @@ class TracerImpl : public Tracer {
3336
private:
3437
void run();
3538

36-
GrpcAsyncSegmentReporterClient* client_;
39+
AsyncClientPtr<TracerRequestType, TracerResponseType> client_;
3740
grpc::CompletionQueue cq_;
3841
std::thread th_;
3942
};

test/grpc_async_client_test.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ using testing::_;
2828
class GrpcAsyncSegmentReporterClientTest : public testing::Test {
2929
public:
3030
GrpcAsyncSegmentReporterClientTest() {
31-
EXPECT_CALL(factory_, create(_));
31+
EXPECT_CALL(factory_, create(_, _));
3232
EXPECT_CALL(*stream_, startStream());
3333
client_ = std::make_unique<GrpcAsyncSegmentReporterClient>(
3434
&cq_, factory_, grpc::InsecureChannelCredentials(), address_, token_);
@@ -51,4 +51,14 @@ TEST_F(GrpcAsyncSegmentReporterClientTest, SendMessageTest) {
5151
client_->sendMessage(fake_message);
5252
}
5353

54+
TEST_F(GrpcAsyncSegmentReporterClientTest, MessageDrainTest) {
55+
std::queue<TracerRequestType> fake_pending_messages;
56+
for (int i = 0; i < 3; ++i) {
57+
fake_pending_messages.emplace(SegmentObject());
58+
}
59+
client_->drainPendingMessages(fake_pending_messages);
60+
EXPECT_EQ(fake_pending_messages.size(), 0);
61+
EXPECT_EQ(client_->numOfMessages(), 3);
62+
}
63+
5464
} // namespace cpp2sky

test/mocks.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,20 @@ class MockAsyncStream : public AsyncStream<RequestType> {
3838
MOCK_METHOD(void, sendMessage, (RequestType));
3939
MOCK_METHOD(std::string, peerAddress, ());
4040
MOCK_METHOD(bool, handleOperation, (Operation));
41+
MOCK_METHOD(size_t, numOfMessages, ());
4142
};
4243

4344
template <class RequestType, class ResponseType>
4445
class MockAsyncClient : public AsyncClient<RequestType, ResponseType> {
4546
public:
46-
MOCK_METHOD(void, sendMessage, (Message&));
47+
MOCK_METHOD(void, sendMessage, (RequestType));
4748
MOCK_METHOD(std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>,
4849
createWriter, (grpc::ClientContext*, ResponseType*, void*));
4950
MOCK_METHOD(std::string, peerAddress, ());
51+
MOCK_METHOD(void, drainPendingMessages, (std::queue<RequestType>&));
52+
MOCK_METHOD(void, resetStream, ());
53+
MOCK_METHOD(void, startStream, ());
54+
MOCK_METHOD(size_t, numOfMessages, ());
5055
};
5156

5257
template <class RequestType, class ResponseType>
@@ -55,9 +60,10 @@ class MockAsyncStreamFactory
5560
public:
5661
using AsyncClientType = AsyncClient<RequestType, ResponseType>;
5762
MockAsyncStreamFactory(AsyncStreamPtr<RequestType> stream) : stream_(stream) {
58-
ON_CALL(*this, create(_)).WillByDefault(Return(stream_));
63+
ON_CALL(*this, create(_, _)).WillByDefault(Return(stream_));
5964
}
60-
MOCK_METHOD(AsyncStreamPtr<RequestType>, create, (AsyncClientType*));
65+
MOCK_METHOD(AsyncStreamPtr<RequestType>, create,
66+
(AsyncClientType*, std::queue<RequestType>&));
6167

6268
private:
6369
AsyncStreamPtr<RequestType> stream_;

0 commit comments

Comments
 (0)