Skip to content

Commit 92241a6

Browse files
authored
Fix ConnServer stall
1 parent 950e923 commit 92241a6

File tree

13 files changed

+86
-30
lines changed

13 files changed

+86
-30
lines changed

include/comm/ConnServer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ namespace comm {
8888

8989
std::unique_ptr<Connection> accept();
9090

91+
std::unique_ptr<Connection> negotiate_protocol(std::shared_ptr<Connection> conn);
92+
9193
private:
9294

9395
ConnServerConfig _config;

include/comm/Protocol.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88

99
namespace comm {
1010

11-
ENUM_FLAGS(Protocol, uint8_t)
11+
// We are setting Protocol to be uint32_t here because
12+
// otherwise there will be padding, and we
13+
// observed that sometimes that padding is causing issues
14+
// on the client side when decoding the protocol.
15+
// TODO: Make HelloMessage and Protocol a Protobuf message.
16+
// to prevent this kind of issues, and to prevent
17+
// potential issues with endianness.
18+
ENUM_FLAGS(Protocol, uint32_t)
1219
{
1320
None = 0,
1421
TCP = 1 << 0,

include/util/Utils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
*
3+
* @copyright Copyright (c) 2022 ApertureData Inc.
4+
*
5+
*/
6+
7+
#pragma once
8+
9+
#include <util/TypeName.h>
10+
#include <iostream>
11+
12+
template<typename T>
13+
void print_binary_data(const T& data)
14+
{
15+
std::cout << "sizeof(" << type_name(data) << "): " << sizeof(T) << std::endl;
16+
for (int i = 0; i < sizeof(T); ++i)
17+
{
18+
printf("%02X ", reinterpret_cast<uint8_t *>(&data)[i] & 0xff);
19+
}
20+
std::cout << std::endl;
21+
}

src/comm/ConnServer.cc

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,53 +114,74 @@ ConnServer::ConnServer(int port, ConnServerConfig config) :
114114

115115
ConnServer::~ConnServer() = default;
116116

117-
std::unique_ptr<Connection> ConnServer::accept()
117+
// c contains a TCPConnection, unencrypted, connection to a client.
118+
// The ConnServer will implement the protocol negotiation.
119+
// This right now is a simple handshake, design for ApertureDB Server use-case.
120+
// This protocol can be a virtual method in the future to support arbitrary protocols.
121+
std::unique_ptr<Connection> ConnServer::negotiate_protocol(std::shared_ptr<Connection> conn)
118122
{
119-
auto connected_socket = TCPSocket::accept(_listening_socket);
120-
121-
auto tcp_connection = std::unique_ptr<TCPConnection>(
122-
new TCPConnection(std::move(connected_socket), _config.metrics));
123+
auto tcp_connection = std::static_pointer_cast<TCPConnection>(conn);
123124

124125
auto response = tcp_connection->recv_message();
125126

126-
if (response.length() != sizeof(HelloMessage)) {
127+
if (response.length() != sizeof(HelloMessage))
128+
{
127129
THROW_EXCEPTION(ProtocolError);
128130
}
129131

130132
auto client_hello_message = reinterpret_cast<const HelloMessage*>(response.data());
131133

132134
HelloMessage server_hello_message;
133135

134-
if (client_hello_message->version != PROTOCOL_VERSION) {
136+
if (client_hello_message->version != PROTOCOL_VERSION)
137+
{
135138
server_hello_message.version = 0;
136139
server_hello_message.protocol = Protocol::None;
137140
}
138-
else {
141+
else
142+
{
139143
server_hello_message.version = PROTOCOL_VERSION;
140144
server_hello_message.protocol = client_hello_message->protocol & _config.allowed_protocols;
141145
}
142146

143-
tcp_connection->send_message(reinterpret_cast<uint8_t*>(&server_hello_message), sizeof(server_hello_message));
147+
tcp_connection->send_message(
148+
reinterpret_cast<uint8_t *>(&server_hello_message),
149+
sizeof(server_hello_message));
144150

145-
if (server_hello_message.version == 0) {
151+
if (server_hello_message.version == 0)
152+
{
146153
THROW_EXCEPTION(ProtocolError, "Protocol version mismatch");
147154
}
148155

149-
if ((server_hello_message.protocol & Protocol::TLS) == Protocol::TLS) {
156+
if ((server_hello_message.protocol & Protocol::TLS) == Protocol::TLS)
157+
{
150158
auto tcp_socket = tcp_connection->release_socket();
151159

152160
auto tls_socket = TLSSocket::create(std::move(tcp_socket), _ssl_ctx);
153161

154162
tls_socket->accept();
155163

156-
return std::unique_ptr<TLSConnection>(
157-
new TLSConnection(std::move(tls_socket), _config.metrics));
164+
return std::make_unique<TLSConnection>(std::move(tls_socket), _config.metrics);
158165
}
159-
else if ((server_hello_message.protocol & Protocol::TCP) == Protocol::TCP) {
166+
else if ((server_hello_message.protocol & Protocol::TCP) == Protocol::TCP)
167+
{
168+
auto tcp_socket = tcp_connection->release_socket();
160169
// Nothing to do, already using TCP
161-
return tcp_connection;
170+
// return tcp_connection;
171+
return std::make_unique<TCPConnection>(std::move(tcp_socket), _config.metrics);
162172
}
163-
else {
173+
else
174+
{
164175
THROW_EXCEPTION(ProtocolError, "Protocol negotiation failed");
165176
}
166177
}
178+
179+
std::unique_ptr<Connection> ConnServer::accept()
180+
{
181+
auto connected_socket = TCPSocket::accept(_listening_socket);
182+
183+
auto tcp_connection = std::make_unique<TCPConnection>(
184+
std::move(connected_socket), _config.metrics);
185+
186+
return tcp_connection;
187+
}

src/comm/Connection.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ Connection::Connection(ConnMetrics* metrics)
3939
: _max_buffer_size(DEFAULT_BUFFER_SIZE)
4040
, _metrics(metrics)
4141
{
42-
4342
}
4443

4544
Connection::~Connection() = default;

src/comm/TCPSocket.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ std::unique_ptr<TCPSocket> TCPSocket::accept(const std::unique_ptr<TCPSocket>& l
5757

5858
errno = 0;
5959
int connected_socket = ::accept(listening_socket->_socket_fd, reinterpret_cast<sockaddr*>(&clnt_addr), &len);
60+
6061
int errno_r = errno;
6162
if (connected_socket < 0) {
6263
if (errno_r == EAGAIN) {

src/comm/TLSSocket.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*/
44

55
#include "comm/TLSSocket.h"
6+
#include <openssl/err.h>
67

78
#include <cstring>
89
#include <netdb.h>
@@ -38,6 +39,7 @@ void TLSSocket::accept()
3839
errno = 0;
3940
auto result = ::SSL_accept(_ssl);
4041
int errno_r = errno;
42+
ERR_print_errors_fp(stdout);
4143
if (result < 1) {
4244
THROW_EXCEPTION(TLSError, errno_r, "SSL_accept()", result);
4345
}

src/comm/TLSSocket.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ namespace comm {
3434
explicit TLSSocket(std::unique_ptr<TCPSocket> tcp_socket, SSL* ssl);
3535

3636
SSL* _ssl{nullptr};
37+
38+
// Even if this member is not used, it is necessary to keep it alive
39+
// until the destructor is called.
3740
std::unique_ptr<TCPSocket> _tcp_socket;
3841
};
3942

test/AuthEnabledVDMSServer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ AuthEnabledVDMSServer::AuthEnabledVDMSServer(int port, AuthEnabledVDMSServerConf
3838
{
3939
auto thread_function = [&]()
4040
{
41-
std::shared_ptr<comm::Connection> server_conn = _server.accept();
41+
std::shared_ptr<comm::Connection> server_conn = _server.negotiate_protocol(_server.accept());
4242

4343
while (!_stop_signal) {
4444
protobufs::queryMessage protobuf_request;

test/TCPConnectionTests.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(TCPConnectionTests, SyncMessages)
3333

3434
barrier.wait();
3535

36-
auto server_conn = server.accept();
36+
auto server_conn = server.negotiate_protocol(server.accept());
3737

3838
for (int i = 0; i < NUMBER_OF_MESSAGES; ++i) {
3939
//Recieve something
@@ -81,7 +81,7 @@ TEST(TCPConnectionTests, AsyncMessages)
8181

8282
barrier.wait();
8383

84-
auto server_conn = server.accept();
84+
auto server_conn = server.negotiate_protocol(server.accept());
8585

8686
for (int i = 0; i < NUMBER_OF_MESSAGES; ++i) {
8787
//Send something
@@ -130,7 +130,7 @@ TEST(TCPConnectionTests, ServerShutdownRecv)
130130

131131
barrier.wait();
132132

133-
server.accept();
133+
auto server_conn = server.negotiate_protocol(server.accept());
134134
});
135135

136136
comm::ConnClient conn_client({"localhost", SERVER_PORT_INTERCHANGE});
@@ -159,7 +159,7 @@ TEST(TCPConnectionTests, SendArrayInts)
159159

160160
barrier.wait();
161161

162-
auto server_conn = server.accept();
162+
auto server_conn = server.negotiate_protocol(server.accept());
163163

164164
server_conn->send_message(reinterpret_cast<const uint8_t*>(arr), sizeof(arr));
165165
});

0 commit comments

Comments
 (0)