Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 8601ad9

Browse files
authored
Merge pull request #1218 from ChTimTsubasa/SSL_handshake_refactor
SSL handshake refactor
2 parents 8b8f68d + e174f9f commit 8601ad9

14 files changed

+631
-672
lines changed

src/include/common/internal_types.h

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ extern int TEST_TUPLES_PER_TILEGROUP;
7474
enum class CmpBool {
7575
FALSE = 0,
7676
TRUE = 1,
77-
NULL_ = 2 // Note the underscore suffix
77+
NULL_ = 2 // Note the underscore suffix
7878
};
7979

8080
//===--------------------------------------------------------------------===//
@@ -115,7 +115,6 @@ std::string PostgresValueTypeToString(PostgresValueType type);
115115
PostgresValueType StringToPostgresValueType(const std::string &str);
116116
std::ostream &operator<<(std::ostream &os, const PostgresValueType &type);
117117

118-
119118
//===--------------------------------------------------------------------===//
120119
// Predicate Expression Operation Types
121120
//===--------------------------------------------------------------------===//
@@ -631,10 +630,12 @@ std::string DropTypeToString(DropType type);
631630
DropType StringToDropType(const std::string &str);
632631
std::ostream &operator<<(std::ostream &os, const DropType &type);
633632

634-
template<class E> class EnumHash {
633+
template <class E>
634+
class EnumHash {
635635
public:
636-
size_t operator()(const E&e) const {
637-
return std::hash<typename std::underlying_type<E>::type>()(static_cast<typename std::underlying_type<E>::type>(e));
636+
size_t operator()(const E &e) const {
637+
return std::hash<typename std::underlying_type<E>::type>()(
638+
static_cast<typename std::underlying_type<E>::type>(e));
638639
}
639640
};
640641

@@ -665,9 +666,9 @@ enum class StatementType {
665666
ALTER = 12, // alter statement type
666667
TRANSACTION = 13, // transaction statement type,
667668
COPY = 14, // copy type
668-
ANALYZE = 15, // analyze type
669+
ANALYZE = 15, // analyze type
669670
VARIABLE_SET = 16, // variable set statement type
670-
CREATE_FUNC = 17, // create func statement type
671+
CREATE_FUNC = 17, // create func statement type
671672
};
672673
std::string StatementTypeToString(StatementType type);
673674
StatementType StringToStatementType(const std::string &str);
@@ -678,24 +679,24 @@ std::ostream &operator<<(std::ostream &os, const StatementType &type);
678679
//===--------------------------------------------------------------------===//
679680

680681
enum class QueryType {
681-
QUERY_BEGIN = 0, // begin query
682-
QUERY_COMMIT = 1, // commit query
683-
QUERY_ROLLBACK = 2, // rollback query
684-
QUERY_CREATE_TABLE = 3, // create query
682+
QUERY_BEGIN = 0, // begin query
683+
QUERY_COMMIT = 1, // commit query
684+
QUERY_ROLLBACK = 2, // rollback query
685+
QUERY_CREATE_TABLE = 3, // create query
685686
QUERY_CREATE_DB = 4,
686687
QUERY_CREATE_INDEX = 5,
687-
QUERY_DROP = 6, // other queries
688-
QUERY_INSERT = 7, // insert query
689-
QUERY_PREPARE = 8, // prepare query
690-
QUERY_EXECUTE = 9, // execute query
688+
QUERY_DROP = 6, // other queries
689+
QUERY_INSERT = 7, // insert query
690+
QUERY_PREPARE = 8, // prepare query
691+
QUERY_EXECUTE = 9, // execute query
691692
QUERY_UPDATE = 10,
692693
QUERY_DELETE = 11,
693694
QUERY_RENAME = 12,
694695
QUERY_ALTER = 13,
695696
QUERY_COPY = 14,
696697
QUERY_ANALYZE = 15,
697-
QUERY_SET = 16, // set query
698-
QUERY_SHOW = 17, // show query
698+
QUERY_SET = 16, // set query
699+
QUERY_SHOW = 17, // show query
699700
QUERY_SELECT = 18,
700701
QUERY_OTHER = 19,
701702
QUERY_INVALID = 20,
@@ -705,8 +706,11 @@ enum class QueryType {
705706
};
706707
std::string QueryTypeToString(QueryType query_type);
707708
QueryType StringToQueryType(std::string str);
708-
namespace parser{ class SQLStatement;}
709-
QueryType StatementTypeToQueryType(StatementType stmt_type, const parser::SQLStatement* sql_stmt);
709+
namespace parser {
710+
class SQLStatement;
711+
}
712+
QueryType StatementTypeToQueryType(StatementType stmt_type,
713+
const parser::SQLStatement *sql_stmt);
710714
//===--------------------------------------------------------------------===//
711715
// Scan Direction Types
712716
//===--------------------------------------------------------------------===//
@@ -1014,7 +1018,8 @@ const int TRIGGER_TYPE_MAX = TRIGGER_TYPE_ROW | TRIGGER_TYPE_STATEMENT |
10141018

10151019
// Statistics Collection Type
10161020
// Disable or enable
1017-
// TODO: This should probably be a collection level and not a boolean (enable/disable)
1021+
// TODO: This should probably be a collection level and not a boolean
1022+
// (enable/disable)
10181023
enum class StatsType {
10191024
// Disable statistics collection
10201025
INVALID = INVALID_TYPE_ID,
@@ -1093,10 +1098,7 @@ static const int INVALID_FILE_DESCRIPTOR = -1;
10931098
// Tuple serialization formats
10941099
// ------------------------------------------------------------------
10951100

1096-
enum class TupleSerializationFormat {
1097-
NATIVE = 0,
1098-
DR = 1
1099-
};
1101+
enum class TupleSerializationFormat { NATIVE = 0, DR = 1 };
11001102

11011103
// ------------------------------------------------------------------
11021104
// Entity types
@@ -1201,7 +1203,7 @@ std::ostream &operator<<(std::ostream &os, const RWType &type);
12011203

12021204
// ItemPointer -> type
12031205
typedef CuckooMap<ItemPointer, RWType, ItemPointerHasher, ItemPointerComparator>
1204-
ReadWriteSet;
1206+
ReadWriteSet;
12051207

12061208
// this enum is to identify why the version should be GC'd.
12071209
enum class GCVersionType {
@@ -1397,15 +1399,15 @@ enum class ProcessResult {
13971399
COMPLETE,
13981400
TERMINATE,
13991401
PROCESSING,
1400-
MORE_DATA_REQUIRED
1402+
MORE_DATA_REQUIRED,
1403+
NEED_SSL_HANDSHAKE,
14011404
};
14021405

14031406
enum class NetworkProtocolType {
14041407
POSTGRES_JDBC,
14051408
POSTGRES_PSQL,
14061409
};
14071410

1408-
14091411
enum class SSLLevel {
14101412
SSL_DISABLE = 0,
14111413
SSL_PREFER = 1,

src/include/network/connection_handle.h

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class ConnectionHandle {
8787
Transition ProcessWrite();
8888
Transition GetResult();
8989
Transition CloseSocket();
90+
/**
91+
* Flush out all the responses and do real SSL handshake
92+
*/
93+
Transition ProcessWrite_SSLHandshake();
9094

9195
private:
9296
/**
@@ -144,16 +148,7 @@ class ConnectionHandle {
144148
friend class ConnectionHandleFactory;
145149

146150
ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler,
147-
std::shared_ptr<Buffer> rbuf, std::shared_ptr<Buffer> wbuf,
148-
bool ssl_able);
149-
150-
ProcessResult ProcessInitial();
151-
152-
/**
153-
* Extracts the header of a Postgres start up packet from the read socket
154-
* buffer
155-
*/
156-
static bool ReadStartupPacketHeader(Buffer &rbuf, InputPacket &rpkt);
151+
std::shared_ptr<Buffer> rbuf, std::shared_ptr<Buffer> wbuf);
157152

158153
/**
159154
* Writes a packet's header (type, size) into the write buffer
@@ -171,6 +166,16 @@ class ConnectionHandle {
171166
*/
172167
WriteState FlushWriteBuffer();
173168

169+
/**
170+
* @brief: process SSL handshake to generate valid SSL
171+
* connection context for further communications
172+
* @return FINISH when the SSL handshake failed
173+
* PROCEED when the SSL handshake success
174+
* NEED_DATA when the SSL handshake is partially done due to network
175+
* latency
176+
*/
177+
Transition SSLHandshake();
178+
174179
/**
175180
* Set the socket to non-blocking mode
176181
*/
@@ -190,6 +195,16 @@ class ConnectionHandle {
190195
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof one);
191196
}
192197

198+
/**
199+
* @brief: Determine if there is still responses in the buffer
200+
* @return true if there is still responses to flush out in either wbuf or
201+
* responses
202+
*/
203+
inline bool HasResponse() {
204+
return (protocol_handler_->responses_.size() != 0) ||
205+
(wbuf_->buf_size != 0);
206+
}
207+
193208
int sock_fd_; // socket file descriptor
194209
struct event *network_event = nullptr; // something to read from network
195210
struct event *workpool_event = nullptr; // worker thread done the job
@@ -204,16 +219,8 @@ class ConnectionHandle {
204219
std::shared_ptr<Buffer> rbuf_; // Socket's read buffer
205220
std::shared_ptr<Buffer> wbuf_; // Socket's write buffer
206221
unsigned int next_response_ = 0; // The next response in the response buffer
207-
Client client_;
208-
StateMachine state_machine_;
209222

210-
// TODO(Tianyi) Can we encapsulate these flags?
211-
bool ssl_handshake_ = false;
212-
bool finish_startup_packet_ = false;
213-
bool ssl_able_;
214-
215-
// TODO(Tianyi) hide this in protocol handler
216-
InputPacket initial_packet_;
223+
StateMachine state_machine_;
217224

218225
short curr_event_flag_; // current libevent event flag
219226
};

src/include/network/connection_handle_factory.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,21 @@ class ConnectionHandleFactory {
4040
// (probably also to-do: beat up the person who wrote this)
4141
PelotonServer::recent_connfd = conn_fd;
4242
auto it = reusable_handles_.find(conn_fd);
43-
bool ssl_able = (PelotonServer::GetSSLLevel() != SSLLevel::SSL_DISABLE);
4443
if (it == reusable_handles_.end()) {
4544
// We are not using std::make_shared here because we want to keep
4645
// ConnectionHandle constructor
4746
// private to avoid unintentional use.
4847
auto handle = std::shared_ptr<ConnectionHandle>(
4948
new ConnectionHandle(conn_fd, handler, std::make_shared<Buffer>(),
50-
std::make_shared<Buffer>(), ssl_able));
49+
std::make_shared<Buffer>()));
5150
reusable_handles_[conn_fd] = handle;
5251
return handle;
5352
}
5453

5554
it->second->rbuf_->Reset();
5655
it->second->wbuf_->Reset();
5756
std::shared_ptr<ConnectionHandle> new_handle(new ConnectionHandle(
58-
conn_fd, handler, it->second->rbuf_, it->second->wbuf_, ssl_able));
57+
conn_fd, handler, it->second->rbuf_, it->second->wbuf_));
5958
reusable_handles_[conn_fd] = new_handle;
6059
return new_handle;
6160
}

src/include/network/marshal.h

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
#include <string>
1616
#include <vector>
1717

18+
#include "common/internal_types.h"
1819
#include "common/logger.h"
1920
#include "common/macros.h"
20-
#include "common/internal_types.h"
2121

2222
#define BUFFER_INIT_SIZE 100
2323

@@ -138,6 +138,8 @@ struct OutputPacket {
138138
size_t ptr; // ByteBuf cursor, which is used for get and put
139139
NetworkMessageType msg_type; // header
140140

141+
bool single_type_pkt; // there would be only a pkt type being written to the
142+
// buffer when this flag is true
141143
bool skip_header_write; // whether we should write header to socket wbuf
142144
size_t write_ptr; // cursor used to write packet content to socket wbuf
143145

@@ -146,28 +148,16 @@ struct OutputPacket {
146148
buf.resize(BUFFER_INIT_SIZE);
147149
buf.shrink_to_fit();
148150
buf.clear();
151+
single_type_pkt = false;
149152
len = ptr = write_ptr = 0;
150153
msg_type = NetworkMessageType::NULL_COMMAND;
151154
skip_header_write = true;
152155
}
153156
};
154157

155-
struct Client {
156-
// Authentication details
157-
std::string dbname;
158-
std::string user;
159-
std::unordered_map<std::string, std::string> cmdline_options;
160-
161-
inline void Reset() {
162-
dbname.clear();
163-
user.clear();
164-
cmdline_options.clear();
165-
}
166-
};
167-
168158
/*
169-
* Marshallers
170-
*/
159+
* Marshallers
160+
*/
171161

172162
/* packet_put_byte - used to write a single byte into a packet */
173163
extern void PacketPutByte(OutputPacket *pkt, const uchar c);
@@ -186,22 +176,22 @@ extern void PacketPutCbytes(OutputPacket *pkt, const uchar *b, int len);
186176
extern void PacketPutString(OutputPacket *pkt, const std::string &data);
187177

188178
/*
189-
* Unmarshallers
190-
*/
179+
* Unmarshallers
180+
*/
191181

192182
/* Copy len bytes from the position indicated by begin to an array */
193183
extern uchar *PacketCopyBytes(ByteBuf::const_iterator begin, int len);
194184
/*
195-
* packet_get_int - Parse an int out of the head of the
196-
* packet. "base" bytes determine the number of bytes of integer
197-
* we are parsing out.
198-
*/
185+
* packet_get_int - Parse an int out of the head of the
186+
* packet. "base" bytes determine the number of bytes of integer
187+
* we are parsing out.
188+
*/
199189
extern int PacketGetInt(InputPacket *pkt, uchar base);
200190

201191
/*
202-
* packet_get_string - parse out a string of size len.
203-
* if len=0? parse till the end of the string
204-
*/
192+
* packet_get_string - parse out a string of size len.
193+
* if len=0? parse till the end of the string
194+
*/
205195
extern void PacketGetString(InputPacket *pkt, size_t len, std::string &result);
206196

207197
/* packet_get_bytes - Parse out "len" bytes of pkt as raw bytes */
@@ -211,9 +201,9 @@ extern void PacketGetBytes(InputPacket *pkt, size_t len, ByteBuf &result);
211201
extern void PacketGetByte(InputPacket *rpkt, uchar &result);
212202

213203
/*
214-
* get_string_token - used to extract a string token
215-
* from an unsigned char vector
216-
*/
204+
* get_string_token - used to extract a string token
205+
* from an unsigned char vector
206+
*/
217207
extern void GetStringToken(InputPacket *pkt, std::string &result);
218208

219209
} // namespace network

src/include/network/network_state.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ namespace network {
2121
enum class ConnState {
2222
READ, // State that reads data from the network
2323
WRITE, // State the writes data to the network
24-
WAIT, // State for waiting for some event to happen
2524
PROCESS, // State that runs the network protocol on received data
2625
CLOSING, // State for closing the client connection
2726
GET_RESULT, // State when triggered by worker thread that completes the task.
27+
PROCESS_WRITE_SSL_HANDSHAKE, // State to flush out responses and doing (Real)
28+
// SSL handshake
2829
};
2930

3031
// TODO(tianyu): Convert use cases of this to just return Transition
@@ -46,7 +47,8 @@ enum class Transition {
4647
// TODO(tianyu) generalize this symbol, this is currently only used in process
4748
GET_RESULT,
4849
FINISH,
49-
RETRY
50+
RETRY,
51+
NEED_SSL_HANDSHAKE,
5052
};
51-
}
52-
}
53+
} // namespace network
54+
} // namespace peloton

0 commit comments

Comments
 (0)