Skip to content

Commit 52d4ae4

Browse files
committed
Give V1TransportDeserializer CChainParams& member
This adds a CChainParams& member to V1TransportDeserializer member, and use it in place of many Params() calls. In addition to reducing the number of calls to a global, this removes a parameter from GetMessage (and will later allow us to remove one from CMessageHeader::IsValid())
1 parent 5bceef6 commit 52d4ae4

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

src/net.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <net.h>
1111

1212
#include <banman.h>
13-
#include <chainparams.h>
1413
#include <clientversion.h>
1514
#include <consensus/consensus.h>
1615
#include <crypto/sha256.h>
@@ -615,7 +614,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
615614
if (m_deserializer->Complete()) {
616615
// decompose a transport agnostic CNetMessage from the deserializer
617616
uint32_t out_err_raw_size{0};
618-
Optional<CNetMessage> result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)};
617+
Optional<CNetMessage> result{m_deserializer->GetMessage(time, out_err_raw_size)};
619618
if (!result) {
620619
// store the size of the corrupt message
621620
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
@@ -697,22 +696,23 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
697696
return data_hash;
698697
}
699698

700-
Optional<CNetMessage> V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size)
699+
Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, uint32_t& out_err_raw_size)
701700
{
702701
// decompose a single CNetMessage from the TransportDeserializer
703702
Optional<CNetMessage> msg(std::move(vRecv));
704703

705704
// store state about valid header, netmagic and checksum
706-
msg->m_valid_header = hdr.IsValid(message_start);
707-
msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
708-
uint256 hash = GetMessageHash();
705+
msg->m_valid_header = hdr.IsValid(m_chain_params.MessageStart());
706+
msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) == 0);
709707

710708
// store command string, time, and sizes
711709
msg->m_command = hdr.GetCommand();
712710
msg->m_time = time;
713711
msg->m_message_size = hdr.nMessageSize;
714712
msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
715713

714+
uint256 hash = GetMessageHash();
715+
716716
// We just received a message off the wire, harvest entropy from the time (and the message checksum)
717717
RandAddEvent(ReadLE32(hash.begin()));
718718

@@ -2846,7 +2846,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
28462846
LogPrint(BCLog::NET, "Added connection peer=%d\n", id);
28472847
}
28482848

2849-
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(GetId(), SER_NETWORK, INIT_PROTO_VERSION));
2849+
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params(), GetId(), SER_NETWORK, INIT_PROTO_VERSION));
28502850
m_serializer = MakeUnique<V1TransportSerializer>(V1TransportSerializer());
28512851
}
28522852

src/net.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <addrman.h>
1111
#include <amount.h>
1212
#include <bloom.h>
13+
#include <chainparams.h>
1314
#include <compat.h>
1415
#include <crypto/siphash.h>
1516
#include <hash.h>
@@ -732,13 +733,14 @@ class TransportDeserializer {
732733
// read and deserialize data
733734
virtual int Read(const char *data, unsigned int bytes) = 0;
734735
// decomposes a message from the context
735-
virtual Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0;
736+
virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0;
736737
virtual ~TransportDeserializer() {}
737738
};
738739

739740
class V1TransportDeserializer final : public TransportDeserializer
740741
{
741742
private:
743+
const CChainParams& m_chain_params;
742744
const NodeId m_node_id; // Only for logging
743745
mutable CHash256 hasher;
744746
mutable uint256 data_hash;
@@ -765,8 +767,9 @@ class V1TransportDeserializer final : public TransportDeserializer
765767
}
766768

767769
public:
768-
V1TransportDeserializer(const NodeId node_id, int nTypeIn, int nVersionIn)
769-
: m_node_id(node_id),
770+
V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
771+
: m_chain_params(chain_params),
772+
m_node_id(node_id),
770773
hdrbuf(nTypeIn, nVersionIn),
771774
vRecv(nTypeIn, nVersionIn)
772775
{
@@ -789,7 +792,7 @@ class V1TransportDeserializer final : public TransportDeserializer
789792
if (ret < 0) Reset();
790793
return ret;
791794
}
792-
Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
795+
Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
793796
};
794797

795798
/** The TransportSerializer prepares messages for the network transport

src/test/fuzz/p2p_transport_deserializer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ void initialize()
2020
void test_one_input(const std::vector<uint8_t>& buffer)
2121
{
2222
// Construct deserializer, with a dummy NodeId
23-
V1TransportDeserializer deserializer{(NodeId)0, SER_NETWORK, INIT_PROTO_VERSION};
23+
V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION};
2424
const char* pch = (const char*)buffer.data();
2525
size_t n_bytes = buffer.size();
2626
while (n_bytes > 0) {
@@ -33,7 +33,7 @@ void test_one_input(const std::vector<uint8_t>& buffer)
3333
if (deserializer.Complete()) {
3434
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
3535
uint32_t out_err_raw_size{0};
36-
Optional<CNetMessage> result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)};
36+
Optional<CNetMessage> result{deserializer.GetMessage(m_time, out_err_raw_size)};
3737
if (result) {
3838
assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE);
3939
assert(result->m_raw_message_size <= buffer.size());

0 commit comments

Comments
 (0)