@@ -114,53 +114,74 @@ ConnServer::ConnServer(int port, ConnServerConfig config) :
114114
115115ConnServer::~ConnServer () = default ;
116116
117- std::unique_ptr<Connection> ConnServer::accept ()
117+ // c contains a TCPConnection, unencrypted, connection to a client.
118+ // The ConnServer will implement the protocol negotiation.
119+ // This right now is a simple handshake, design for ApertureDB Server use-case.
120+ // This protocol can be a virtual method in the future to support arbitrary protocols.
121+ std::unique_ptr<Connection> ConnServer::negotiate_protocol (std::shared_ptr<Connection> conn)
118122{
119- auto connected_socket = TCPSocket::accept (_listening_socket);
120-
121- auto tcp_connection = std::unique_ptr<TCPConnection>(
122- new TCPConnection (std::move (connected_socket), _config.metrics ));
123+ auto tcp_connection = std::static_pointer_cast<TCPConnection>(conn);
123124
124125 auto response = tcp_connection->recv_message ();
125126
126- if (response.length () != sizeof (HelloMessage)) {
127+ if (response.length () != sizeof (HelloMessage))
128+ {
127129 THROW_EXCEPTION (ProtocolError);
128130 }
129131
130132 auto client_hello_message = reinterpret_cast <const HelloMessage*>(response.data ());
131133
132134 HelloMessage server_hello_message;
133135
134- if (client_hello_message->version != PROTOCOL_VERSION) {
136+ if (client_hello_message->version != PROTOCOL_VERSION)
137+ {
135138 server_hello_message.version = 0 ;
136139 server_hello_message.protocol = Protocol::None;
137140 }
138- else {
141+ else
142+ {
139143 server_hello_message.version = PROTOCOL_VERSION;
140144 server_hello_message.protocol = client_hello_message->protocol & _config.allowed_protocols ;
141145 }
142146
143- tcp_connection->send_message (reinterpret_cast <uint8_t *>(&server_hello_message), sizeof (server_hello_message));
147+ tcp_connection->send_message (
148+ reinterpret_cast <uint8_t *>(&server_hello_message),
149+ sizeof (server_hello_message));
144150
145- if (server_hello_message.version == 0 ) {
151+ if (server_hello_message.version == 0 )
152+ {
146153 THROW_EXCEPTION (ProtocolError, " Protocol version mismatch" );
147154 }
148155
149- if ((server_hello_message.protocol & Protocol::TLS) == Protocol::TLS) {
156+ if ((server_hello_message.protocol & Protocol::TLS) == Protocol::TLS)
157+ {
150158 auto tcp_socket = tcp_connection->release_socket ();
151159
152160 auto tls_socket = TLSSocket::create (std::move (tcp_socket), _ssl_ctx);
153161
154162 tls_socket->accept ();
155163
156- return std::unique_ptr<TLSConnection>(
157- new TLSConnection (std::move (tls_socket), _config.metrics ));
164+ return std::make_unique<TLSConnection>(std::move (tls_socket), _config.metrics );
158165 }
159- else if ((server_hello_message.protocol & Protocol::TCP) == Protocol::TCP) {
166+ else if ((server_hello_message.protocol & Protocol::TCP) == Protocol::TCP)
167+ {
168+ auto tcp_socket = tcp_connection->release_socket ();
160169 // Nothing to do, already using TCP
161- return tcp_connection;
170+ // return tcp_connection;
171+ return std::make_unique<TCPConnection>(std::move (tcp_socket), _config.metrics );
162172 }
163- else {
173+ else
174+ {
164175 THROW_EXCEPTION (ProtocolError, " Protocol negotiation failed" );
165176 }
166177}
178+
179+ std::unique_ptr<Connection> ConnServer::accept ()
180+ {
181+ auto connected_socket = TCPSocket::accept (_listening_socket);
182+
183+ auto tcp_connection = std::make_unique<TCPConnection>(
184+ std::move (connected_socket), _config.metrics );
185+
186+ return tcp_connection;
187+ }
0 commit comments