Skip to content

Commit 6af9b31

Browse files
committed
Merge #19107: p2p: Move all header verification into the network layer, extend logging
deb5271 Remove header checks out of net_processing (Troy Giorshev) 52d4ae4 Give V1TransportDeserializer CChainParams& member (Troy Giorshev) 5bceef6 Change CMessageHeader Constructor (Troy Giorshev) 1ca20c1 Add doxygen comment for ReceiveMsgBytes (Troy Giorshev) 890b1d7 Move checksum check from net_processing to net (Troy Giorshev) 2716647 Give V1TransportDeserializer an m_node_id member (Troy Giorshev) Pull request description: Inspired by #15206 and #15197, this PR moves all message header verification from the message processing layer and into the network/transport layer. In the previous PRs there is a change in behavior, where we would disconnect from peers upon a single failed checksum check. In various discussions there was concern over whether this was the right choice, and some expressed a desire to see how this would look if it was made to be a pure refactor. For more context, see https://bitcoincore.reviews/15206.html#l-81. This PR improves the separation between the p2p layers, helping improvements like [BIP324](bitcoin/bitcoin#18242) and #18989. ACKs for top commit: ryanofsky: Code review ACK deb5271 just rebase due to conflict on adjacent line jnewbery: Code review ACK deb5271. Tree-SHA512: 1a3b7ae883b020cfee1bef968813e04df651ffdad9dd961a826bd80654f2c98676ce7f4721038a1b78d8790e4cebe8060419e3d8affc97ce2b9b4e4b72e6fa9f
2 parents e36aa35 + deb5271 commit 6af9b31

File tree

8 files changed

+101
-111
lines changed

8 files changed

+101
-111
lines changed

src/net.cpp

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <net.h>
1111

1212
#include <banman.h>
13-
#include <chainparams.h>
1413
#include <clientversion.h>
1514
#include <consensus/consensus.h>
1615
#include <crypto/sha256.h>
@@ -607,6 +606,16 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap)
607606
}
608607
#undef X
609608

609+
/**
610+
* Receive bytes from the buffer and deserialize them into messages.
611+
*
612+
* @param[in] pch A pointer to the raw data
613+
* @param[in] nBytes Size of the data
614+
* @param[out] complete Set True if at least one message has been
615+
* deserialized and is ready to be processed
616+
* @return True if the peer should stay connected,
617+
* False if the peer should be disconnected from.
618+
*/
610619
bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete)
611620
{
612621
complete = false;
@@ -617,25 +626,35 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
617626
while (nBytes > 0) {
618627
// absorb network data
619628
int handled = m_deserializer->Read(pch, nBytes);
620-
if (handled < 0) return false;
629+
if (handled < 0) {
630+
// Serious header problem, disconnect from the peer.
631+
return false;
632+
}
621633

622634
pch += handled;
623635
nBytes -= handled;
624636

625637
if (m_deserializer->Complete()) {
626638
// decompose a transport agnostic CNetMessage from the deserializer
627-
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time);
639+
uint32_t out_err_raw_size{0};
640+
Optional<CNetMessage> result{m_deserializer->GetMessage(time, out_err_raw_size)};
641+
if (!result) {
642+
// Message deserialization failed. Drop the message but don't disconnect the peer.
643+
// store the size of the corrupt message
644+
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
645+
continue;
646+
}
628647

629648
//store received bytes per message command
630649
//to prevent a memory DOS, only allow valid commands
631-
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command);
650+
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(result->m_command);
632651
if (i == mapRecvBytesPerMsgCmd.end())
633652
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
634653
assert(i != mapRecvBytesPerMsgCmd.end());
635-
i->second += msg.m_raw_message_size;
654+
i->second += result->m_raw_message_size;
636655

637656
// push the message to the process queue,
638-
vRecvMsg.push_back(std::move(msg));
657+
vRecvMsg.push_back(std::move(*result));
639658

640659
complete = true;
641660
}
@@ -662,11 +681,19 @@ int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
662681
hdrbuf >> hdr;
663682
}
664683
catch (const std::exception&) {
684+
LogPrint(BCLog::NET, "HEADER ERROR - UNABLE TO DESERIALIZE, peer=%d\n", m_node_id);
685+
return -1;
686+
}
687+
688+
// Check start string, network magic
689+
if (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) {
690+
LogPrint(BCLog::NET, "HEADER ERROR - MESSAGESTART (%s, %u bytes), received %s, peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, HexStr(hdr.pchMessageStart), m_node_id);
665691
return -1;
666692
}
667693

668694
// reject messages larger than MAX_SIZE or MAX_PROTOCOL_MESSAGE_LENGTH
669695
if (hdr.nMessageSize > MAX_SIZE || hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
696+
LogPrint(BCLog::NET, "HEADER ERROR - SIZE (%s, %u bytes), peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, m_node_id);
670697
return -1;
671698
}
672699

@@ -701,36 +728,39 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
701728
return data_hash;
702729
}
703730

704-
CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time)
731+
Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, uint32_t& out_err_raw_size)
705732
{
706733
// decompose a single CNetMessage from the TransportDeserializer
707-
CNetMessage msg(std::move(vRecv));
734+
Optional<CNetMessage> msg(std::move(vRecv));
708735

709-
// store state about valid header, netmagic and checksum
710-
msg.m_valid_header = hdr.IsValid(message_start);
711-
msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
712-
uint256 hash = GetMessageHash();
736+
// store command string, time, and sizes
737+
msg->m_command = hdr.GetCommand();
738+
msg->m_time = time;
739+
msg->m_message_size = hdr.nMessageSize;
740+
msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
713741

714-
// store command string, payload size
715-
msg.m_command = hdr.GetCommand();
716-
msg.m_message_size = hdr.nMessageSize;
717-
msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
742+
uint256 hash = GetMessageHash();
718743

719744
// We just received a message off the wire, harvest entropy from the time (and the message checksum)
720745
RandAddEvent(ReadLE32(hash.begin()));
721746

722-
msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0);
723-
if (!msg.m_valid_checksum) {
724-
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n",
725-
SanitizeString(msg.m_command), msg.m_message_size,
747+
// Check checksum and header command string
748+
if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) {
749+
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
750+
SanitizeString(msg->m_command), msg->m_message_size,
726751
HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
727-
HexStr(hdr.pchChecksum));
728-
}
729-
730-
// store receive time
731-
msg.m_time = time;
732-
733-
// reset the network deserializer (prepare for the next message)
752+
HexStr(hdr.pchChecksum),
753+
m_node_id);
754+
out_err_raw_size = msg->m_raw_message_size;
755+
msg = nullopt;
756+
} else if (!hdr.IsCommandValid()) {
757+
LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n",
758+
hdr.GetCommand(), msg->m_message_size, m_node_id);
759+
out_err_raw_size = msg->m_raw_message_size;
760+
msg = nullopt;
761+
}
762+
763+
// Always reset the network deserializer (prepare for the next message)
734764
Reset();
735765
return msg;
736766
}
@@ -2850,7 +2880,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
28502880
LogPrint(BCLog::NET, "Added connection peer=%d\n", id);
28512881
}
28522882

2853-
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION));
2883+
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params(), GetId(), SER_NETWORK, INIT_PROTO_VERSION));
28542884
m_serializer = MakeUnique<V1TransportSerializer>(V1TransportSerializer());
28552885
}
28562886

src/net.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
#include <addrman.h>
1111
#include <amount.h>
1212
#include <bloom.h>
13+
#include <chainparams.h>
1314
#include <compat.h>
1415
#include <crypto/siphash.h>
1516
#include <hash.h>
1617
#include <limitedmap.h>
17-
#include <netaddress.h>
1818
#include <net_permissions.h>
19+
#include <netaddress.h>
20+
#include <optional.h>
1921
#include <policy/feerate.h>
2022
#include <protocol.h>
2123
#include <random.h>
@@ -712,11 +714,8 @@ class CNetMessage {
712714
public:
713715
CDataStream m_recv; //!< received message data
714716
std::chrono::microseconds m_time{0}; //!< time of message receipt
715-
bool m_valid_netmagic = false;
716-
bool m_valid_header = false;
717-
bool m_valid_checksum = false;
718-
uint32_t m_message_size{0}; //!< size of the payload
719-
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
717+
uint32_t m_message_size{0}; //!< size of the payload
718+
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
720719
std::string m_command;
721720

722721
CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {}
@@ -740,13 +739,15 @@ class TransportDeserializer {
740739
// read and deserialize data
741740
virtual int Read(const char *data, unsigned int bytes) = 0;
742741
// decomposes a message from the context
743-
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0;
742+
virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0;
744743
virtual ~TransportDeserializer() {}
745744
};
746745

747746
class V1TransportDeserializer final : public TransportDeserializer
748747
{
749748
private:
749+
const CChainParams& m_chain_params;
750+
const NodeId m_node_id; // Only for logging
750751
mutable CHash256 hasher;
751752
mutable uint256 data_hash;
752753
bool in_data; // parsing header (false) or data (true)
@@ -772,8 +773,12 @@ class V1TransportDeserializer final : public TransportDeserializer
772773
}
773774

774775
public:
775-
776-
V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) {
776+
V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
777+
: m_chain_params(chain_params),
778+
m_node_id(node_id),
779+
hdrbuf(nTypeIn, nVersionIn),
780+
vRecv(nTypeIn, nVersionIn)
781+
{
777782
Reset();
778783
}
779784

@@ -793,7 +798,7 @@ class V1TransportDeserializer final : public TransportDeserializer
793798
if (ret < 0) Reset();
794799
return ret;
795800
}
796-
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override;
801+
Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
797802
};
798803

799804
/** The TransportSerializer prepares messages for the network transport

src/net_processing.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3815,14 +3815,6 @@ bool PeerManager::MaybeDiscourageAndDisconnect(CNode& pnode)
38153815

38163816
bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgProc)
38173817
{
3818-
//
3819-
// Message format
3820-
// (4) message start
3821-
// (12) command
3822-
// (4) size
3823-
// (4) checksum
3824-
// (x) data
3825-
//
38263818
bool fMoreWork = false;
38273819

38283820
if (!pfrom->vRecvGetData.empty())
@@ -3863,35 +3855,13 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP
38633855
CNetMessage& msg(msgs.front());
38643856

38653857
msg.SetVersion(pfrom->GetCommonVersion());
3866-
// Check network magic
3867-
if (!msg.m_valid_netmagic) {
3868-
LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
3869-
pfrom->fDisconnect = true;
3870-
return false;
3871-
}
3872-
3873-
// Check header
3874-
if (!msg.m_valid_header)
3875-
{
3876-
LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
3877-
return fMoreWork;
3878-
}
38793858
const std::string& msg_type = msg.m_command;
38803859

38813860
// Message size
38823861
unsigned int nMessageSize = msg.m_message_size;
38833862

3884-
// Checksum
3885-
CDataStream& vRecv = msg.m_recv;
3886-
if (!msg.m_valid_checksum)
3887-
{
3888-
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__,
3889-
SanitizeString(msg_type), nMessageSize, pfrom->GetId());
3890-
return fMoreWork;
3891-
}
3892-
38933863
try {
3894-
ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, interruptMsgProc);
3864+
ProcessMessage(*pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc);
38953865
if (interruptMsgProc)
38963866
return false;
38973867
if (!pfrom->vRecvGetData.empty())

src/protocol.cpp

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ const static std::string allNetMessageTypes[] = {
8484
};
8585
const static std::vector<std::string> allNetMessageTypesVec(allNetMessageTypes, allNetMessageTypes+ARRAYLEN(allNetMessageTypes));
8686

87-
CMessageHeader::CMessageHeader(const MessageStartChars& pchMessageStartIn)
87+
CMessageHeader::CMessageHeader()
8888
{
89-
memcpy(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE);
89+
memset(pchMessageStart, 0, MESSAGE_START_SIZE);
9090
memset(pchCommand, 0, sizeof(pchCommand));
9191
nMessageSize = -1;
9292
memset(pchChecksum, 0, CHECKSUM_SIZE);
@@ -111,31 +111,20 @@ std::string CMessageHeader::GetCommand() const
111111
return std::string(pchCommand, pchCommand + strnlen(pchCommand, COMMAND_SIZE));
112112
}
113113

114-
bool CMessageHeader::IsValid(const MessageStartChars& pchMessageStartIn) const
114+
bool CMessageHeader::IsCommandValid() const
115115
{
116-
// Check start string
117-
if (memcmp(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE) != 0)
118-
return false;
119-
120116
// Check the command string for errors
121-
for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; p1++)
122-
{
123-
if (*p1 == 0)
124-
{
117+
for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; ++p1) {
118+
if (*p1 == 0) {
125119
// Must be all zeros after the first zero
126-
for (; p1 < pchCommand + COMMAND_SIZE; p1++)
127-
if (*p1 != 0)
120+
for (; p1 < pchCommand + COMMAND_SIZE; ++p1) {
121+
if (*p1 != 0) {
128122
return false;
129-
}
130-
else if (*p1 < ' ' || *p1 > 0x7E)
123+
}
124+
}
125+
} else if (*p1 < ' ' || *p1 > 0x7E) {
131126
return false;
132-
}
133-
134-
// Message size
135-
if (nMessageSize > MAX_SIZE)
136-
{
137-
LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) nMessageSize > MAX_SIZE\n", GetCommand(), nMessageSize);
138-
return false;
127+
}
139128
}
140129

141130
return true;

src/protocol.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ class CMessageHeader
3737
static constexpr size_t HEADER_SIZE = MESSAGE_START_SIZE + COMMAND_SIZE + MESSAGE_SIZE_SIZE + CHECKSUM_SIZE;
3838
typedef unsigned char MessageStartChars[MESSAGE_START_SIZE];
3939

40-
explicit CMessageHeader(const MessageStartChars& pchMessageStartIn);
40+
explicit CMessageHeader();
4141

4242
/** Construct a P2P message header from message-start characters, a command and the size of the message.
4343
* @note Passing in a `pszCommand` longer than COMMAND_SIZE will result in a run-time assertion error.
4444
*/
4545
CMessageHeader(const MessageStartChars& pchMessageStartIn, const char* pszCommand, unsigned int nMessageSizeIn);
4646

4747
std::string GetCommand() const;
48-
bool IsValid(const MessageStartChars& messageStart) const;
48+
bool IsCommandValid() const;
4949

5050
SERIALIZE_METHODS(CMessageHeader, obj) { READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize, obj.pchChecksum); }
5151

src/test/fuzz/deserialize.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,9 @@ void test_one_input(const std::vector<uint8_t>& buffer)
189189
DeserializeFromFuzzingInput(buffer, s);
190190
AssertEqualAfterSerializeDeserialize(s);
191191
#elif MESSAGEHEADER_DESERIALIZE
192-
const CMessageHeader::MessageStartChars pchMessageStart = {0x00, 0x00, 0x00, 0x00};
193-
CMessageHeader mh(pchMessageStart);
192+
CMessageHeader mh;
194193
DeserializeFromFuzzingInput(buffer, mh);
195-
(void)mh.IsValid(pchMessageStart);
194+
(void)mh.IsCommandValid();
196195
#elif ADDRESS_DESERIALIZE
197196
CAddress a;
198197
DeserializeFromFuzzingInput(buffer, a);

src/test/fuzz/p2p_transport_deserializer.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ void initialize()
1919

2020
void test_one_input(const std::vector<uint8_t>& buffer)
2121
{
22-
V1TransportDeserializer deserializer{Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION};
22+
// Construct deserializer, with a dummy NodeId
23+
V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION};
2324
const char* pch = (const char*)buffer.data();
2425
size_t n_bytes = buffer.size();
2526
while (n_bytes > 0) {
@@ -31,16 +32,13 @@ void test_one_input(const std::vector<uint8_t>& buffer)
3132
n_bytes -= handled;
3233
if (deserializer.Complete()) {
3334
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
34-
const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time);
35-
assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE);
36-
assert(msg.m_raw_message_size <= buffer.size());
37-
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
38-
assert(msg.m_time == m_time);
39-
if (msg.m_valid_header) {
40-
assert(msg.m_valid_netmagic);
41-
}
42-
if (!msg.m_valid_netmagic) {
43-
assert(!msg.m_valid_header);
35+
uint32_t out_err_raw_size{0};
36+
Optional<CNetMessage> result{deserializer.GetMessage(m_time, out_err_raw_size)};
37+
if (result) {
38+
assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE);
39+
assert(result->m_raw_message_size <= buffer.size());
40+
assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size);
41+
assert(result->m_time == m_time);
4442
}
4543
}
4644
}

0 commit comments

Comments
 (0)