|
7 | 7 |
|
8 | 8 | #include <functional>
|
9 | 9 |
|
| 10 | +#include <generated/security/plaintext/protobuf/plaintext.pb.h> |
| 11 | +#include <libp2p/basic/protobuf_message_read_writer.hpp> |
10 | 12 | #include <libp2p/peer/peer_id.hpp>
|
11 | 13 | #include <libp2p/security/error.hpp>
|
12 | 14 | #include <libp2p/security/plaintext/plaintext_connection.hpp>
|
13 | 15 |
|
14 | 16 | #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
|
15 |
| -# pragma GCC diagnostic ignored "-Wparentheses" |
| 17 | +#pragma GCC diagnostic ignored "-Wparentheses" |
16 | 18 | #endif
|
17 | 19 |
|
18 |
| -#define PLAINTEXT_OUTCOME_TRY(name, res, conn, cb) \ |
19 |
| - auto(name) = (res); \ |
20 |
| - if ((name).has_error()) { \ |
21 |
| - closeConnection(conn, (name).error()); \ |
22 |
| - cb((name).error()); \ |
23 |
| - return; \ |
| 20 | +#define PLAINTEXT_OUTCOME_VOID_TRY(res, conn, cb) \ |
| 21 | + if ((res).has_error()) { \ |
| 22 | + closeConnection(conn, (res).error()); \ |
| 23 | + cb((res).error()); \ |
| 24 | + return; \ |
24 | 25 | }
|
25 | 26 |
|
| 27 | +#define PLAINTEXT_OUTCOME_TRY(name, res, conn, cb) \ |
| 28 | + PLAINTEXT_OUTCOME_VOID_TRY((res), (conn), (cb)) \ |
| 29 | + auto(name) = (res).value(); |
| 30 | + |
26 | 31 | OUTCOME_CPP_DEFINE_CATEGORY(libp2p::security, Plaintext::Error, e) {
|
27 | 32 | using E = libp2p::security::Plaintext::Error;
|
28 | 33 | switch (e) {
|
29 | 34 | case E::EXCHANGE_SEND_ERROR:
|
30 | 35 | return "Error occured while sending Exchange message to the peer";
|
31 | 36 | case E::EXCHANGE_RECEIVE_ERROR:
|
32 |
| - return "Error occured while receiving Exchange message to the peer"; |
| 37 | + return "Error occurred while receiving Exchange message to the peer"; |
33 | 38 | case E::INVALID_PEER_ID:
|
34 | 39 | return "Received peer id doesn't match actual peer id";
|
35 | 40 | case E::EMPTY_PEER_ID:
|
@@ -61,93 +66,74 @@ namespace libp2p::security {
|
61 | 66 | std::shared_ptr<connection::RawConnection> inbound,
|
62 | 67 | SecConnCallbackFunc cb) {
|
63 | 68 | log_->debug("securing inbound connection");
|
64 |
| - sendExchangeMsg(inbound, cb); |
65 |
| - receiveExchangeMsg(inbound, boost::none, cb); |
| 69 | + auto rw = std::make_shared<basic::ProtobufMessageReadWriter>(inbound); |
| 70 | + sendExchangeMsg(inbound, rw, cb); |
| 71 | + receiveExchangeMsg(inbound, rw, boost::none, cb); |
66 | 72 | }
|
67 | 73 |
|
68 | 74 | void Plaintext::secureOutbound(
|
69 | 75 | std::shared_ptr<connection::RawConnection> outbound,
|
70 | 76 | const peer::PeerId &p, SecConnCallbackFunc cb) {
|
71 | 77 | log_->debug("securing outbound connection");
|
72 |
| - sendExchangeMsg(outbound, cb); |
73 |
| - receiveExchangeMsg(outbound, p, cb); |
| 78 | + auto rw = std::make_shared<basic::ProtobufMessageReadWriter>(outbound); |
| 79 | + sendExchangeMsg(outbound, rw, cb); |
| 80 | + receiveExchangeMsg(outbound, rw, p, cb); |
74 | 81 | }
|
75 | 82 |
|
76 | 83 | void Plaintext::sendExchangeMsg(
|
77 | 84 | const std::shared_ptr<connection::RawConnection> &conn,
|
| 85 | + const std::shared_ptr<basic::ProtobufMessageReadWriter> &rw, |
78 | 86 | SecConnCallbackFunc cb) const {
|
79 |
| - PLAINTEXT_OUTCOME_TRY(out_msg_res, |
80 |
| - marshaller_->marshal(plaintext::ExchangeMessage{ |
81 |
| - .pubkey = idmgr_->getKeyPair().publicKey, |
82 |
| - .peer_id = idmgr_->getId()}), |
83 |
| - conn, cb) |
84 |
| - |
85 |
| - auto out_msg = out_msg_res.value(); |
86 |
| - auto len = out_msg.size(); |
87 |
| - |
88 |
| - std::vector<uint8_t> len_bytes = { |
89 |
| - static_cast<uint8_t>(len >> 24u), static_cast<uint8_t>(len >> 16u), |
90 |
| - static_cast<uint8_t>(len >> 8u), static_cast<uint8_t>(len)}; |
91 |
| - |
92 |
| - conn->write(len_bytes, 4, |
93 |
| - [self{shared_from_this()}, out_msg, conn, |
94 |
| - cb{std::move(cb)}](auto &&res) mutable { |
95 |
| - if (res.has_error()) { |
96 |
| - self->closeConnection(conn, Error::EXCHANGE_SEND_ERROR); |
97 |
| - return cb(Error::EXCHANGE_SEND_ERROR); |
98 |
| - } |
99 |
| - |
100 |
| - conn->write( |
101 |
| - out_msg, out_msg.size(), |
102 |
| - [self{std::move(self)}, cb{cb}, conn](auto &&res) { |
103 |
| - if (res.has_error()) { |
104 |
| - self->closeConnection(conn, |
105 |
| - Error::EXCHANGE_SEND_ERROR); |
106 |
| - return cb(Error::EXCHANGE_SEND_ERROR); |
107 |
| - } |
108 |
| - }); |
109 |
| - }); |
| 87 | + plaintext::ExchangeMessage exchange_msg{ |
| 88 | + .pubkey = idmgr_->getKeyPair().publicKey, .peer_id = idmgr_->getId()}; |
| 89 | + PLAINTEXT_OUTCOME_TRY(proto_exchange_msg, |
| 90 | + marshaller_->handyToProto(exchange_msg), conn, cb) |
| 91 | + |
| 92 | + rw->write<plaintext::protobuf::Exchange>( |
| 93 | + proto_exchange_msg, |
| 94 | + [self{shared_from_this()}, cb{std::move(cb)}, conn](auto &&res) { |
| 95 | + if (res.has_error()) { |
| 96 | + self->closeConnection(conn, Error::EXCHANGE_SEND_ERROR); |
| 97 | + return cb(Error::EXCHANGE_SEND_ERROR); |
| 98 | + } |
| 99 | + }); |
110 | 100 | }
|
111 | 101 |
|
112 | 102 | void Plaintext::receiveExchangeMsg(
|
113 | 103 | const std::shared_ptr<connection::RawConnection> &conn,
|
| 104 | + const std::shared_ptr<basic::ProtobufMessageReadWriter> &rw, |
114 | 105 | const MaybePeerId &p, SecConnCallbackFunc cb) const {
|
115 |
| - constexpr size_t kMaxMsgSize = 4; // we read uint32_t first |
116 |
| - auto read_bytes = std::make_shared<std::vector<uint8_t>>(kMaxMsgSize); |
117 |
| - |
118 |
| - conn->read( |
119 |
| - *read_bytes, kMaxMsgSize, |
| 106 | + auto remote_peer_exchange_bytes = std::make_shared<std::vector<uint8_t>>(); |
| 107 | + rw->read<plaintext::protobuf::Exchange>( |
120 | 108 | [self{shared_from_this()}, conn, p, cb{std::move(cb)},
|
121 |
| - read_bytes](auto &&r) { |
122 |
| - auto bytes_size = (static_cast<uint32_t>(read_bytes->at(0)) << 24u) |
123 |
| - + (static_cast<uint32_t>(read_bytes->at(1)) << 16u) |
124 |
| - + (static_cast<uint32_t>(read_bytes->at(2)) << 8u) |
125 |
| - + read_bytes->at(3); |
126 |
| - |
127 |
| - auto received_bytes = |
128 |
| - std::make_shared<std::vector<uint8_t>>(bytes_size); |
129 |
| - conn->read(*received_bytes, received_bytes->size(), |
130 |
| - [self, conn, p, cb, received_bytes](auto &&r) { |
131 |
| - self->readCallback(conn, p, cb, received_bytes, r); |
132 |
| - }); |
133 |
| - }); |
| 109 | + remote_peer_exchange_bytes](auto &&res) { |
| 110 | + if (!res) { |
| 111 | + return cb(res.error()); |
| 112 | + } |
| 113 | + self->readCallback(conn, p, cb, remote_peer_exchange_bytes, |
| 114 | + remote_peer_exchange_bytes->size()); |
| 115 | + }, |
| 116 | + remote_peer_exchange_bytes); |
134 | 117 | }
|
135 | 118 |
|
136 | 119 | void Plaintext::readCallback(
|
137 | 120 | const std::shared_ptr<connection::RawConnection> &conn,
|
138 | 121 | const MaybePeerId &p, const SecConnCallbackFunc &cb,
|
139 | 122 | const std::shared_ptr<std::vector<uint8_t>> &read_bytes,
|
140 | 123 | outcome::result<size_t> read_call_res) const {
|
141 |
| - PLAINTEXT_OUTCOME_TRY(r, read_call_res, conn, cb); |
| 124 | + /* |
| 125 | + * The method does redundant unmarshalling of message bytes to proto |
| 126 | + * exchange message. This could be a subject of further improvement. |
| 127 | + */ |
| 128 | + PLAINTEXT_OUTCOME_VOID_TRY(read_call_res, conn, cb); |
142 | 129 | PLAINTEXT_OUTCOME_TRY(in_exchange_msg, marshaller_->unmarshal(*read_bytes),
|
143 | 130 | conn, cb);
|
144 |
| - auto &msg = in_exchange_msg.value().first; |
| 131 | + auto &msg = in_exchange_msg.first; |
145 | 132 | auto received_pid = msg.peer_id;
|
146 | 133 | auto pkey = msg.pubkey;
|
147 | 134 |
|
148 | 135 | // PeerId is derived from the Protobuf-serialized public key, not a raw one
|
149 |
| - auto derived_pid_res = |
150 |
| - peer::PeerId::fromPublicKey(in_exchange_msg.value().second); |
| 136 | + auto derived_pid_res = peer::PeerId::fromPublicKey(in_exchange_msg.second); |
151 | 137 | if (!derived_pid_res) {
|
152 | 138 | log_->error("cannot create a PeerId from the received public key: {}",
|
153 | 139 | derived_pid_res.error().message());
|
|
0 commit comments