Skip to content

Commit efecb74

Browse files
committed
Use adapter pattern for the network deserializer
1 parent 1a5c656 commit efecb74

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

src/net.cpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -571,18 +571,13 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
571571
nRecvBytes += nBytes;
572572
while (nBytes > 0) {
573573
// absorb network data
574-
int handled;
575-
if (!m_deserializer->in_data)
576-
handled = m_deserializer->readHeader(pch, nBytes);
577-
else
578-
handled = m_deserializer->readData(pch, nBytes);
579-
574+
int handled = m_deserializer->Read(pch, nBytes);
580575
if (handled < 0) {
581576
m_deserializer->Reset();
582577
return false;
583578
}
584579

585-
if (m_deserializer->in_data && m_deserializer->hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
580+
if (m_deserializer->OversizedMessageDetected()) {
586581
LogPrint(BCLog::NET, "Oversized message from peer=%i, disconnecting\n", GetId());
587582
m_deserializer->Reset();
588583
return false;
@@ -591,13 +586,13 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
591586
pch += handled;
592587
nBytes -= handled;
593588

594-
if (m_deserializer->complete()) {
589+
if (m_deserializer->Complete()) {
595590
// decompose a transport agnostic CNetMessage from the deserializer
596591
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), nTimeMicros);
597592

598593
//store received bytes per message command
599594
//to prevent a memory DOS, only allow valid commands
600-
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(m_deserializer->hdr.pchCommand);
595+
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command);
601596
if (i == mapRecvBytesPerMsgCmd.end())
602597
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
603598
assert(i != mapRecvBytesPerMsgCmd.end());
@@ -639,7 +634,7 @@ int CNode::GetSendVersion() const
639634
return nSendVersion;
640635
}
641636

642-
int TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
637+
int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
643638
{
644639
// copy data to temporary parsing buffer
645640
unsigned int nRemaining = 24 - nHdrPos;
@@ -670,7 +665,7 @@ int TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
670665
return nCopy;
671666
}
672667

673-
int TransportDeserializer::readData(const char *pch, unsigned int nBytes)
668+
int V1TransportDeserializer::readData(const char *pch, unsigned int nBytes)
674669
{
675670
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
676671
unsigned int nCopy = std::min(nRemaining, nBytes);
@@ -687,15 +682,15 @@ int TransportDeserializer::readData(const char *pch, unsigned int nBytes)
687682
return nCopy;
688683
}
689684

690-
const uint256& TransportDeserializer::GetMessageHash() const
685+
const uint256& V1TransportDeserializer::GetMessageHash() const
691686
{
692-
assert(complete());
687+
assert(Complete());
693688
if (data_hash.IsNull())
694689
hasher.Finalize(data_hash.begin());
695690
return data_hash;
696691
}
697692

698-
CNetMessage TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) {
693+
CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) {
699694
// decompose a single CNetMessage from the TransportDeserializer
700695
CNetMessage msg(std::move(vRecv));
701696

@@ -2708,7 +2703,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
27082703
LogPrint(BCLog::NET, "Added connection peer=%d\n", id);
27092704
}
27102705

2711-
m_deserializer = MakeUnique<TransportDeserializer>(TransportDeserializer(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION));
2706+
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION));
27122707
}
27132708

27142709
CNode::~CNode()

src/net.h

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -637,20 +637,40 @@ class CNetMessage {
637637
* transport protocol agnostic CNetMessage (command & payload)
638638
*/
639639
class TransportDeserializer {
640+
public:
641+
// prepare for next message
642+
virtual void Reset() = 0;
643+
// returns true if the current deserialization is complete
644+
virtual bool Complete() const = 0;
645+
// checks if the potential message in deserialization is oversized
646+
virtual bool OversizedMessageDetected() const = 0;
647+
// set the serialization context version
648+
virtual void SetVersion(int version) = 0;
649+
// read and deserialize data
650+
virtual int Read(const char *data, unsigned int bytes) = 0;
651+
// decomposes a message from the context
652+
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) = 0;
653+
virtual ~TransportDeserializer() {}
654+
};
655+
656+
class V1TransportDeserializer : public TransportDeserializer
657+
{
640658
private:
641659
mutable CHash256 hasher;
642660
mutable uint256 data_hash;
643-
public:
644661
bool in_data; // parsing header (false) or data (true)
645-
646662
CDataStream hdrbuf; // partially received header
647663
CMessageHeader hdr; // complete header
648-
unsigned int nHdrPos;
649-
650664
CDataStream vRecv; // received message data
665+
unsigned int nHdrPos;
651666
unsigned int nDataPos;
652667

653-
TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) {
668+
const uint256& GetMessageHash() const;
669+
int readHeader(const char *pch, unsigned int nBytes);
670+
int readData(const char *pch, unsigned int nBytes);
671+
public:
672+
673+
V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) {
654674
Reset();
655675
}
656676

@@ -664,25 +684,23 @@ class TransportDeserializer {
664684
data_hash.SetNull();
665685
hasher.Reset();
666686
}
667-
668-
bool complete() const
687+
bool Complete() const
669688
{
670689
if (!in_data)
671690
return false;
672691
return (hdr.nMessageSize == nDataPos);
673692
}
674-
675-
const uint256& GetMessageHash() const;
676-
677693
void SetVersion(int nVersionIn)
678694
{
679695
hdrbuf.SetVersion(nVersionIn);
680696
vRecv.SetVersion(nVersionIn);
681697
}
682-
683-
int readHeader(const char *pch, unsigned int nBytes);
684-
int readData(const char *pch, unsigned int nBytes);
685-
698+
bool OversizedMessageDetected() const {
699+
return (in_data && hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH);
700+
}
701+
int Read(const char *pch, unsigned int nBytes) {
702+
return in_data ? readData(pch, nBytes) : readHeader(pch, nBytes);
703+
}
686704
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time);
687705
};
688706

0 commit comments

Comments
 (0)