Skip to content

Commit 890b1d7

Browse files
committed
Move checksum check from net_processing to net
This removes the m_valid_checksum member from CNetMessage. Instead, GetMessage() returns an Optional. Additionally, GetMessage() has been given an out parameter to be used to hold error information. For now it is specifically a uint32_t used to hold the raw size of the corrupt message. The checksum check is now done in GetMessage.
1 parent 2716647 commit 890b1d7

File tree

4 files changed

+45
-44
lines changed

4 files changed

+45
-44
lines changed

src/net.cpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -595,25 +595,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
595595
while (nBytes > 0) {
596596
// absorb network data
597597
int handled = m_deserializer->Read(pch, nBytes);
598-
if (handled < 0) return false;
598+
if (handled < 0) {
599+
return false;
600+
}
599601

600602
pch += handled;
601603
nBytes -= handled;
602604

603605
if (m_deserializer->Complete()) {
604606
// decompose a transport agnostic CNetMessage from the deserializer
605-
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time);
607+
uint32_t out_err_raw_size{0};
608+
Optional<CNetMessage> result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)};
609+
if (!result) {
610+
// store the size of the corrupt message
611+
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
612+
continue;
613+
}
606614

607615
//store received bytes per message command
608616
//to prevent a memory DOS, only allow valid commands
609-
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command);
617+
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(result->m_command);
610618
if (i == mapRecvBytesPerMsgCmd.end())
611619
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
612620
assert(i != mapRecvBytesPerMsgCmd.end());
613-
i->second += msg.m_raw_message_size;
621+
i->second += result->m_raw_message_size;
614622

615623
// push the message to the process queue,
616-
vRecvMsg.push_back(std::move(msg));
624+
vRecvMsg.push_back(std::move(*result));
617625

618626
complete = true;
619627
}
@@ -679,37 +687,36 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
679687
return data_hash;
680688
}
681689

682-
CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time)
690+
Optional<CNetMessage> V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size)
683691
{
684692
// decompose a single CNetMessage from the TransportDeserializer
685-
CNetMessage msg(std::move(vRecv));
693+
Optional<CNetMessage> msg(std::move(vRecv));
686694

687695
// store state about valid header, netmagic and checksum
688-
msg.m_valid_header = hdr.IsValid(message_start);
689-
msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
696+
msg->m_valid_header = hdr.IsValid(message_start);
697+
msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
690698
uint256 hash = GetMessageHash();
691699

692-
// store command string, payload size
693-
msg.m_command = hdr.GetCommand();
694-
msg.m_message_size = hdr.nMessageSize;
695-
msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
700+
// store command string, time, and sizes
701+
msg->m_command = hdr.GetCommand();
702+
msg->m_time = time;
703+
msg->m_message_size = hdr.nMessageSize;
704+
msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
696705

697706
// We just received a message off the wire, harvest entropy from the time (and the message checksum)
698707
RandAddEvent(ReadLE32(hash.begin()));
699708

700-
msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0);
701-
if (!msg.m_valid_checksum) {
709+
if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) {
702710
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
703-
SanitizeString(msg.m_command), msg.m_message_size,
711+
SanitizeString(msg->m_command), msg->m_message_size,
704712
HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
705713
HexStr(hdr.pchChecksum),
706714
m_node_id);
715+
out_err_raw_size = msg->m_raw_message_size;
716+
msg = nullopt;
707717
}
708718

709-
// store receive time
710-
msg.m_time = time;
711-
712-
// reset the network deserializer (prepare for the next message)
719+
// Always reset the network deserializer (prepare for the next message)
713720
Reset();
714721
return msg;
715722
}

src/net.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
#include <crypto/siphash.h>
1515
#include <hash.h>
1616
#include <limitedmap.h>
17-
#include <netaddress.h>
1817
#include <net_permissions.h>
18+
#include <netaddress.h>
19+
#include <optional.h>
1920
#include <policy/feerate.h>
2021
#include <protocol.h>
2122
#include <random.h>
@@ -706,7 +707,6 @@ class CNetMessage {
706707
std::chrono::microseconds m_time{0}; //!< time of message receipt
707708
bool m_valid_netmagic = false;
708709
bool m_valid_header = false;
709-
bool m_valid_checksum = false;
710710
uint32_t m_message_size{0}; //!< size of the payload
711711
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
712712
std::string m_command;
@@ -732,7 +732,7 @@ class TransportDeserializer {
732732
// read and deserialize data
733733
virtual int Read(const char *data, unsigned int bytes) = 0;
734734
// decomposes a message from the context
735-
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0;
735+
virtual Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0;
736736
virtual ~TransportDeserializer() {}
737737
};
738738

@@ -790,7 +790,7 @@ class V1TransportDeserializer final : public TransportDeserializer
790790
if (ret < 0) Reset();
791791
return ret;
792792
}
793-
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override;
793+
Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
794794
};
795795

796796
/** The TransportSerializer prepares messages for the network transport

src/net_processing.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3886,17 +3886,8 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP
38863886
// Message size
38873887
unsigned int nMessageSize = msg.m_message_size;
38883888

3889-
// Checksum
3890-
CDataStream& vRecv = msg.m_recv;
3891-
if (!msg.m_valid_checksum)
3892-
{
3893-
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__,
3894-
SanitizeString(msg_type), nMessageSize, pfrom->GetId());
3895-
return fMoreWork;
3896-
}
3897-
38983889
try {
3899-
ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, interruptMsgProc);
3890+
ProcessMessage(*pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc);
39003891
if (interruptMsgProc)
39013892
return false;
39023893
if (!pfrom->vRecvGetData.empty())

src/test/fuzz/p2p_transport_deserializer.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ void test_one_input(const std::vector<uint8_t>& buffer)
3232
n_bytes -= handled;
3333
if (deserializer.Complete()) {
3434
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
35-
const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time);
36-
assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE);
37-
assert(msg.m_raw_message_size <= buffer.size());
38-
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
39-
assert(msg.m_time == m_time);
40-
if (msg.m_valid_header) {
41-
assert(msg.m_valid_netmagic);
42-
}
43-
if (!msg.m_valid_netmagic) {
44-
assert(!msg.m_valid_header);
35+
uint32_t out_err_raw_size{0};
36+
Optional<CNetMessage> result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)};
37+
if (result) {
38+
assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE);
39+
assert(result->m_raw_message_size <= buffer.size());
40+
assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size);
41+
assert(result->m_time == m_time);
42+
if (result->m_valid_header) {
43+
assert(result->m_valid_netmagic);
44+
}
45+
if (!result->m_valid_netmagic) {
46+
assert(!result->m_valid_header);
47+
}
4548
}
4649
}
4750
}

0 commit comments

Comments
 (0)