Skip to content

Commit badca85

Browse files
committed
Merge #16202: p2p: Refactor network message deserialization
ed2dc5e Add override/final modifiers to V1TransportDeserializer (Pieter Wuille) f342a5e Make resetting implicit in TransportDeserializer::Read() (Pieter Wuille) 6a91499 Remove oversized message detection from log and interface (Pieter Wuille) b0e10ff Force CNetMessage::m_recv to use std::move (Jonas Schnelli) efecb74 Use adapter pattern for the network deserializer (Jonas Schnelli) 1a5c656 Remove transport protocol knowhow from CNetMessage / net processing (Jonas Schnelli) 6294ecd Refactor: split network transport deserializing from message container (Jonas Schnelli) Pull request description: **This refactors the network message deserialization.** * It transforms the `CNetMessage` into a transport protocol agnostic message container. * A new class `TransportDeserializer` (unique pointer of `CNode`) is introduced, handling the network buffer reading and the decomposing to a `CNetMessage` * **No behavioral changes** (in terms of disconnecting, punishing) * Moves the checksum finalizing into the `SocketHandler` thread (finalizing was in `ProcessMessages` before) The **optional last commit** makes the `TransportDeserializer` following an adapter pattern (polymorphic interface) to make it easier to later add a V2 transport protocol deserializer. Intentionally not touching the sending part. Pre-Requirement for BIP324 (v2 message transport protocol). Replacement for #14046 and inspired by a [comment](bitcoin/bitcoin#14046 (comment)) from sipa ACKs for top commit: promag: Code review ACK ed2dc5e. marcinja: Code review ACK ed2dc5e ryanofsky: Code review ACK ed2dc5e. 4 cleanup commits added since last review. Unaddressed comments: ariard: Code review and tested ACK ed2dc5e. Tree-SHA512: bab8d87464e2e8742529e488ddcdc8650f0c2025c9130913df00a0b17ecdb9a525061cbbbd0de0251b76bf75a8edb72e3ad0dbf5b79e26f2ad05d61b4e4ded6d
2 parents f8cc2b9 + ed2dc5e commit badca85

File tree

4 files changed

+138
-76
lines changed

4 files changed

+138
-76
lines changed

src/net.cpp

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -567,42 +567,28 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
567567
nLastRecv = nTimeMicros / 1000000;
568568
nRecvBytes += nBytes;
569569
while (nBytes > 0) {
570-
571-
// get current incomplete message, or create a new one
572-
if (vRecvMsg.empty() ||
573-
vRecvMsg.back().complete())
574-
vRecvMsg.push_back(CNetMessage(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION));
575-
576-
CNetMessage& msg = vRecvMsg.back();
577-
578570
// absorb network data
579-
int handled;
580-
if (!msg.in_data)
581-
handled = msg.readHeader(pch, nBytes);
582-
else
583-
handled = msg.readData(pch, nBytes);
584-
585-
if (handled < 0)
586-
return false;
587-
588-
if (msg.in_data && msg.hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
589-
LogPrint(BCLog::NET, "Oversized message from peer=%i, disconnecting\n", GetId());
590-
return false;
591-
}
571+
int handled = m_deserializer->Read(pch, nBytes);
572+
if (handled < 0) return false;
592573

593574
pch += handled;
594575
nBytes -= handled;
595576

596-
if (msg.complete()) {
577+
if (m_deserializer->Complete()) {
578+
// decompose a transport agnostic CNetMessage from the deserializer
579+
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), nTimeMicros);
580+
597581
//store received bytes per message command
598582
//to prevent a memory DOS, only allow valid commands
599-
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.hdr.pchCommand);
583+
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command);
600584
if (i == mapRecvBytesPerMsgCmd.end())
601585
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
602586
assert(i != mapRecvBytesPerMsgCmd.end());
603-
i->second += msg.hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
587+
i->second += msg.m_raw_message_size;
588+
589+
// push the message to the process queue,
590+
vRecvMsg.push_back(std::move(msg));
604591

605-
msg.nTime = nTimeMicros;
606592
complete = true;
607593
}
608594
}
@@ -636,8 +622,7 @@ int CNode::GetSendVersion() const
636622
return nSendVersion;
637623
}
638624

639-
640-
int CNetMessage::readHeader(const char *pch, unsigned int nBytes)
625+
int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
641626
{
642627
// copy data to temporary parsing buffer
643628
unsigned int nRemaining = 24 - nHdrPos;
@@ -658,17 +643,18 @@ int CNetMessage::readHeader(const char *pch, unsigned int nBytes)
658643
return -1;
659644
}
660645

661-
// reject messages larger than MAX_SIZE
662-
if (hdr.nMessageSize > MAX_SIZE)
646+
// reject messages larger than MAX_SIZE or MAX_PROTOCOL_MESSAGE_LENGTH
647+
if (hdr.nMessageSize > MAX_SIZE || hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
663648
return -1;
649+
}
664650

665651
// switch state to reading message data
666652
in_data = true;
667653

668654
return nCopy;
669655
}
670656

671-
int CNetMessage::readData(const char *pch, unsigned int nBytes)
657+
int V1TransportDeserializer::readData(const char *pch, unsigned int nBytes)
672658
{
673659
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
674660
unsigned int nCopy = std::min(nRemaining, nBytes);
@@ -685,14 +671,44 @@ int CNetMessage::readData(const char *pch, unsigned int nBytes)
685671
return nCopy;
686672
}
687673

688-
const uint256& CNetMessage::GetMessageHash() const
674+
const uint256& V1TransportDeserializer::GetMessageHash() const
689675
{
690-
assert(complete());
676+
assert(Complete());
691677
if (data_hash.IsNull())
692678
hasher.Finalize(data_hash.begin());
693679
return data_hash;
694680
}
695681

682+
CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) {
683+
// decompose a single CNetMessage from the TransportDeserializer
684+
CNetMessage msg(std::move(vRecv));
685+
686+
// store state about valid header, netmagic and checksum
687+
msg.m_valid_header = hdr.IsValid(message_start);
688+
msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
689+
uint256 hash = GetMessageHash();
690+
691+
// store command string, payload size
692+
msg.m_command = hdr.GetCommand();
693+
msg.m_message_size = hdr.nMessageSize;
694+
msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
695+
696+
msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0);
697+
if (!msg.m_valid_checksum) {
698+
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n",
699+
SanitizeString(msg.m_command), msg.m_message_size,
700+
HexStr(hash.begin(), hash.begin()+CMessageHeader::CHECKSUM_SIZE),
701+
HexStr(hdr.pchChecksum, hdr.pchChecksum+CMessageHeader::CHECKSUM_SIZE));
702+
}
703+
704+
// store receive time
705+
msg.m_time = time;
706+
707+
// reset the network deserializer (prepare for the next message)
708+
Reset();
709+
return msg;
710+
}
711+
696712
size_t CConnman::SocketSendData(CNode *pnode) const EXCLUSIVE_LOCKS_REQUIRED(pnode->cs_vSend)
697713
{
698714
auto it = pnode->vSendMsg.begin();
@@ -1344,9 +1360,9 @@ void CConnman::SocketHandler()
13441360
size_t nSizeAdded = 0;
13451361
auto it(pnode->vRecvMsg.begin());
13461362
for (; it != pnode->vRecvMsg.end(); ++it) {
1347-
if (!it->complete())
1348-
break;
1349-
nSizeAdded += it->vRecv.size() + CMessageHeader::HEADER_SIZE;
1363+
// vRecvMsg contains only completed CNetMessage
1364+
// the single possible partially deserialized message are held by TransportDeserializer
1365+
nSizeAdded += it->m_raw_message_size;
13501366
}
13511367
{
13521368
LOCK(pnode->cs_vProcessMsg);
@@ -2676,6 +2692,8 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
26762692
} else {
26772693
LogPrint(BCLog::NET, "Added connection peer=%d\n", id);
26782694
}
2695+
2696+
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION));
26792697
}
26802698

26812699
CNode::~CNode()

src/net.h

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -609,56 +609,105 @@ class CNodeStats
609609

610610

611611

612-
612+
/** Transport protocol agnostic message container.
613+
* Ideally it should only contain receive time, payload,
614+
* command and size.
615+
*/
613616
class CNetMessage {
617+
public:
618+
CDataStream m_recv; // received message data
619+
int64_t m_time = 0; // time (in microseconds) of message receipt.
620+
bool m_valid_netmagic = false;
621+
bool m_valid_header = false;
622+
bool m_valid_checksum = false;
623+
uint32_t m_message_size = 0; // size of the payload
624+
uint32_t m_raw_message_size = 0; // used wire size of the message (including header/checksum)
625+
std::string m_command;
626+
627+
CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {}
628+
629+
void SetVersion(int nVersionIn)
630+
{
631+
m_recv.SetVersion(nVersionIn);
632+
}
633+
};
634+
635+
/** The TransportDeserializer takes care of holding and deserializing the
636+
* network receive buffer. It can deserialize the network buffer into a
637+
* transport protocol agnostic CNetMessage (command & payload)
638+
*/
639+
class TransportDeserializer {
640+
public:
641+
// returns true if the current deserialization is complete
642+
virtual bool Complete() const = 0;
643+
// set the serialization context version
644+
virtual void SetVersion(int version) = 0;
645+
// read and deserialize data
646+
virtual int Read(const char *data, unsigned int bytes) = 0;
647+
// decomposes a message from the context
648+
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) = 0;
649+
virtual ~TransportDeserializer() {}
650+
};
651+
652+
class V1TransportDeserializer final : public TransportDeserializer
653+
{
614654
private:
615655
mutable CHash256 hasher;
616656
mutable uint256 data_hash;
617-
public:
618657
bool in_data; // parsing header (false) or data (true)
619-
620658
CDataStream hdrbuf; // partially received header
621659
CMessageHeader hdr; // complete header
622-
unsigned int nHdrPos;
623-
624660
CDataStream vRecv; // received message data
661+
unsigned int nHdrPos;
625662
unsigned int nDataPos;
626663

627-
int64_t nTime; // time (in microseconds) of message receipt.
664+
const uint256& GetMessageHash() const;
665+
int readHeader(const char *pch, unsigned int nBytes);
666+
int readData(const char *pch, unsigned int nBytes);
628667

629-
CNetMessage(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) {
668+
void Reset() {
669+
vRecv.clear();
670+
hdrbuf.clear();
630671
hdrbuf.resize(24);
631672
in_data = false;
632673
nHdrPos = 0;
633674
nDataPos = 0;
634-
nTime = 0;
675+
data_hash.SetNull();
676+
hasher.Reset();
635677
}
636678

637-
bool complete() const
679+
public:
680+
681+
V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) {
682+
Reset();
683+
}
684+
685+
bool Complete() const override
638686
{
639687
if (!in_data)
640688
return false;
641689
return (hdr.nMessageSize == nDataPos);
642690
}
643-
644-
const uint256& GetMessageHash() const;
645-
646-
void SetVersion(int nVersionIn)
691+
void SetVersion(int nVersionIn) override
647692
{
648693
hdrbuf.SetVersion(nVersionIn);
649694
vRecv.SetVersion(nVersionIn);
650695
}
651-
652-
int readHeader(const char *pch, unsigned int nBytes);
653-
int readData(const char *pch, unsigned int nBytes);
696+
int Read(const char *pch, unsigned int nBytes) override {
697+
int ret = in_data ? readData(pch, nBytes) : readHeader(pch, nBytes);
698+
if (ret < 0) Reset();
699+
return ret;
700+
}
701+
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) override;
654702
};
655703

656-
657704
/** Information about a peer */
658705
class CNode
659706
{
660707
friend class CConnman;
661708
public:
709+
std::unique_ptr<TransportDeserializer> m_deserializer;
710+
662711
// socket
663712
std::atomic<ServiceFlags> nServices{NODE_NONE};
664713
SOCKET hSocket GUARDED_BY(cs_hSocket);

src/net_processing.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3272,49 +3272,45 @@ bool PeerLogicValidation::ProcessMessages(CNode* pfrom, std::atomic<bool>& inter
32723272
return false;
32733273
// Just take one message
32743274
msgs.splice(msgs.begin(), pfrom->vProcessMsg, pfrom->vProcessMsg.begin());
3275-
pfrom->nProcessQueueSize -= msgs.front().vRecv.size() + CMessageHeader::HEADER_SIZE;
3275+
pfrom->nProcessQueueSize -= msgs.front().m_raw_message_size;
32763276
pfrom->fPauseRecv = pfrom->nProcessQueueSize > connman->GetReceiveFloodSize();
32773277
fMoreWork = !pfrom->vProcessMsg.empty();
32783278
}
32793279
CNetMessage& msg(msgs.front());
32803280

32813281
msg.SetVersion(pfrom->GetRecvVersion());
3282-
// Scan for message start
3283-
if (memcmp(msg.hdr.pchMessageStart, chainparams.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) {
3284-
LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.hdr.GetCommand()), pfrom->GetId());
3282+
// Check network magic
3283+
if (!msg.m_valid_netmagic) {
3284+
LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
32853285
pfrom->fDisconnect = true;
32863286
return false;
32873287
}
32883288

3289-
// Read header
3290-
CMessageHeader& hdr = msg.hdr;
3291-
if (!hdr.IsValid(chainparams.MessageStart()))
3289+
// Check header
3290+
if (!msg.m_valid_header)
32923291
{
3293-
LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(hdr.GetCommand()), pfrom->GetId());
3292+
LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
32943293
return fMoreWork;
32953294
}
3296-
std::string strCommand = hdr.GetCommand();
3295+
const std::string& strCommand = msg.m_command;
32973296

32983297
// Message size
3299-
unsigned int nMessageSize = hdr.nMessageSize;
3298+
unsigned int nMessageSize = msg.m_message_size;
33003299

33013300
// Checksum
3302-
CDataStream& vRecv = msg.vRecv;
3303-
const uint256& hash = msg.GetMessageHash();
3304-
if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0)
3301+
CDataStream& vRecv = msg.m_recv;
3302+
if (!msg.m_valid_checksum)
33053303
{
3306-
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR expected %s was %s\n", __func__,
3307-
SanitizeString(strCommand), nMessageSize,
3308-
HexStr(hash.begin(), hash.begin()+CMessageHeader::CHECKSUM_SIZE),
3309-
HexStr(hdr.pchChecksum, hdr.pchChecksum+CMessageHeader::CHECKSUM_SIZE));
3304+
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__,
3305+
SanitizeString(strCommand), nMessageSize, pfrom->GetId());
33103306
return fMoreWork;
33113307
}
33123308

33133309
// Process message
33143310
bool fRet = false;
33153311
try
33163312
{
3317-
fRet = ProcessMessage(pfrom, strCommand, vRecv, msg.nTime, chainparams, connman, interruptMsgProc);
3313+
fRet = ProcessMessage(pfrom, strCommand, vRecv, msg.m_time, chainparams, connman, interruptMsgProc);
33183314
if (interruptMsgProc)
33193315
return false;
33203316
if (!pfrom->vRecvGetData.empty())

test/functional/p2p_invalid_messages.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,10 @@ def run_test(self):
101101
msg_over_size = msg_unrecognized(str_data="b" * (valid_data_limit + 1))
102102
assert len(msg_over_size.serialize()) == (msg_limit + 1)
103103

104-
with node.assert_debug_log(["Oversized message from peer=4, disconnecting"]):
105-
# An unknown message type (or *any* message type) over
106-
# MAX_PROTOCOL_MESSAGE_LENGTH should result in a disconnect.
107-
node.p2p.send_message(msg_over_size)
108-
node.p2p.wait_for_disconnect(timeout=4)
104+
# An unknown message type (or *any* message type) over
105+
# MAX_PROTOCOL_MESSAGE_LENGTH should result in a disconnect.
106+
node.p2p.send_message(msg_over_size)
107+
node.p2p.wait_for_disconnect(timeout=4)
109108

110109
node.disconnect_p2ps()
111110
conn = node.add_p2p_connection(P2PDataStore())
@@ -168,7 +167,7 @@ def swap_magic_bytes():
168167

169168
def test_checksum(self):
170169
conn = self.nodes[0].add_p2p_connection(P2PDataStore())
171-
with self.nodes[0].assert_debug_log(['ProcessMessages(badmsg, 2 bytes): CHECKSUM ERROR expected 78df0a04 was ffffffff']):
170+
with self.nodes[0].assert_debug_log(['CHECKSUM ERROR (badmsg, 2 bytes), expected 78df0a04 was ffffffff']):
172171
msg = conn.build_message(msg_unrecognized(str_data="d"))
173172
cut_len = (
174173
4 + # magic

0 commit comments

Comments
 (0)