Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions src/blockfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// file COPYING or http://www.opensource.org/licenses/mit-license.php.

#include <mutex>
#include <sstream>
#include <set>

#include <blockfilter.h>
Expand All @@ -13,6 +12,7 @@
#include <script/script.h>
#include <streams.h>
#include <util/golombrice.h>
#include <util/string.h>

/// SerType used to serialize parameters in GCS filter encoding.
static constexpr int GCS_SER_TYPE = SER_NETWORK;
Expand Down Expand Up @@ -179,19 +179,7 @@ const std::set<BlockFilterType>& AllBlockFilterTypes()

const std::string& ListBlockFilterTypes()
{
static std::string type_list;

static std::once_flag flag;
std::call_once(flag, []() {
std::stringstream ret;
bool first = true;
for (auto entry : g_filter_types) {
if (!first) ret << ", ";
ret << entry.second;
first = false;
}
type_list = ret.str();
});
static std::string type_list{Join(g_filter_types, ", ", [](const auto& entry) { return entry.second; })};

return type_list;
}
Expand Down
12 changes: 6 additions & 6 deletions src/test/util_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,15 @@ BOOST_AUTO_TEST_CASE(span_write_bytes)
BOOST_AUTO_TEST_CASE(util_Join)
{
// Normal version
BOOST_CHECK_EQUAL(Join({}, ", "), "");
BOOST_CHECK_EQUAL(Join({"foo"}, ", "), "foo");
BOOST_CHECK_EQUAL(Join({"foo", "bar"}, ", "), "foo, bar");
BOOST_CHECK_EQUAL(Join(std::vector<std::string>{}, ", "), "");
BOOST_CHECK_EQUAL(Join(std::vector<std::string>{"foo"}, ", "), "foo");
BOOST_CHECK_EQUAL(Join(std::vector<std::string>{"foo", "bar"}, ", "), "foo, bar");

// Version with unary operator
const auto op_upper = [](const std::string& s) { return ToUpper(s); };
BOOST_CHECK_EQUAL(Join<std::string>({}, ", ", op_upper), "");
BOOST_CHECK_EQUAL(Join<std::string>({"foo"}, ", ", op_upper), "FOO");
BOOST_CHECK_EQUAL(Join<std::string>({"foo", "bar"}, ", ", op_upper), "FOO, BAR");
BOOST_CHECK_EQUAL(Join(std::list<std::string>{}, ", ", op_upper), "");
BOOST_CHECK_EQUAL(Join(std::list<std::string>{"foo"}, ", ", op_upper), "FOO");
BOOST_CHECK_EQUAL(Join(std::list<std::string>{"foo", "bar"}, ", ", op_upper), "FOO, BAR");
}

BOOST_AUTO_TEST_CASE(util_ReplaceAll)
Expand Down
36 changes: 16 additions & 20 deletions src/util/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,34 +58,30 @@ void ReplaceAll(std::string& in_out, const std::string& search, const std::strin
}

/**
* Join a list of items
* Join all container items. Typically used to concatenate strings but accepts
* containers with elements of any type.
*
* @param list The list to join
* @param separator The separator
* @param unary_op Apply this operator to each item in the list
* @param container The items to join
* @param separator The separator
* @param unary_op Apply this operator to each item
*/
template <typename T, typename BaseType, typename UnaryOp>
auto Join(const std::vector<T>& list, const BaseType& separator, UnaryOp unary_op)
-> decltype(unary_op(list.at(0)))
template <typename C, typename S, typename UnaryOp>
auto Join(const C& container, const S& separator, UnaryOp unary_op)
{
decltype(unary_op(list.at(0))) ret;
for (size_t i = 0; i < list.size(); ++i) {
if (i > 0) ret += separator;
ret += unary_op(list.at(i));
decltype(unary_op(*container.begin())) ret;
bool first{true};
for (const auto& item : container) {
if (!first) ret += separator;
ret += unary_op(item);
first = false;
}
return ret;
}

template <typename T, typename T2>
T Join(const std::vector<T>& list, const T2& separator)
template <typename C, typename S>
auto Join(const C& container, const S& separator)
{
return Join(list, separator, [](const T& i) { return i; });
}

// Explicit overload needed for c_str arguments, which would otherwise cause a substitution failure in the template above.
inline std::string Join(const std::vector<std::string>& list, std::string_view separator)
{
return Join<std::string>(list, separator);
return Join(container, separator, [](const auto& i) { return i; });
}

/**
Expand Down
10 changes: 7 additions & 3 deletions src/wallet/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1767,16 +1767,20 @@ bool CWallet::LoadWalletFlags(uint64_t flags)
return true;
}

bool CWallet::AddWalletFlags(uint64_t flags)
void CWallet::InitWalletFlags(uint64_t flags)
{
LOCK(cs_wallet);

// We should never be writing unknown non-tolerable wallet flags
assert(((flags & KNOWN_WALLET_FLAGS) >> 32) == (flags >> 32));
// This should only be used once, when creating a new wallet - so current flags are expected to be blank
assert(m_wallet_flags == 0);

if (!WalletBatch(GetDatabase()).WriteWalletFlags(flags)) {
throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed");
}

return LoadWalletFlags(flags);
if (!LoadWalletFlags(flags)) assert(false);
}

int64_t CWalletTx::GetTxTime() const
Expand Down Expand Up @@ -4774,7 +4778,7 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain* chain, interfaces::C
{
walletInstance->SetMinVersion(FEATURE_LATEST);

walletInstance->AddWalletFlags(wallet_creation_flags);
walletInstance->InitWalletFlags(wallet_creation_flags);

// Only create LegacyScriptPubKeyMan when not descriptor wallet
if (!walletInstance->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) {
Expand Down
5 changes: 3 additions & 2 deletions src/wallet/wallet.h
Original file line number Diff line number Diff line change
Expand Up @@ -1427,8 +1427,9 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati
bool IsWalletFlagSet(uint64_t flag) const override;

/** overwrite all flags by the given uint64_t
returns false if unknown, non-tolerable flags are present */
bool AddWalletFlags(uint64_t flags);
flags must be uninitialised (or 0)
only known flags may be present */
void InitWalletFlags(uint64_t flags);
Comment on lines 1429 to 1432
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add validation against KNOWN_WALLET_FLAGS.

Consider adding validation in the implementation to ensure that:

  1. The input flags are a subset of KNOWN_WALLET_FLAGS
  2. The flags are uninitialized (0) before initialization

Example validation:

void CWallet::InitWalletFlags(uint64_t flags)
{
    if (m_wallet_flags != 0) {
        throw std::logic_error("InitWalletFlags called with initialized flags");
    }
    if ((flags & ~KNOWN_WALLET_FLAGS) != 0) {
        throw std::logic_error("InitWalletFlags called with unknown flags");
    }
    m_wallet_flags = flags;
}

/** Loads the flags into the wallet. (used by LoadWallet) */
bool LoadWalletFlags(uint64_t flags);

Expand Down
2 changes: 1 addition & 1 deletion src/wallet/wallettool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ static void WalletCreate(CWallet* wallet_instance, uint64_t wallet_creation_flag
} else {
wallet_instance->SetMinVersion(FEATURE_COMPRPUBKEY);
}
wallet_instance->SetWalletFlag(wallet_creation_flags);
wallet_instance->InitWalletFlags(wallet_creation_flags);

if (!wallet_instance->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) {
// TODO: use here SetupGeneration instead, such as: spk_man->SetupGeneration(false);
Expand Down
3 changes: 3 additions & 0 deletions test/functional/wallet_abandonconflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class AbandonConflictTest(BitcoinTestFramework):
def set_test_params(self):
self.num_nodes = 2
self.extra_args = [["-minrelaytxfee=0.00001"], []]
# whitelist peers to speed up tx relay / mempool sync
for args in self.extra_args:
args.append("-whitelist=noban@127.0.0.1")

def skip_test_if_missing_module(self):
self.skip_if_no_wallet()
Expand Down
3 changes: 3 additions & 0 deletions test/functional/wallet_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def set_test_params(self):
['-limitdescendantcount=3'], # Limit mempool descendants as a hack to have wallet txs rejected from the mempool
[],
]
# whitelist peers to speed up tx relay / mempool sync
for args in self.extra_args:
args.append("-whitelist=noban@127.0.0.1")

def skip_test_if_missing_module(self):
self.skip_if_no_wallet()
Expand Down
4 changes: 2 additions & 2 deletions test/functional/wallet_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def set_test_params(self):
self.num_nodes = 4
if self.options.descriptors:
self.extra_args = [[
"-acceptnonstdtxn=1"
"-acceptnonstdtxn=1", "-whitelist=noban@127.0.0.1"
] for i in range(self.num_nodes)]
else:
self.extra_args = [[
"-acceptnonstdtxn=1",
"-acceptnonstdtxn=1", "-whitelist=noban@127.0.0.1",
'-usehd={:d}'.format(i%2==0)
] for i in range(self.num_nodes)]
self.setup_clean_chain = True
Expand Down
20 changes: 7 additions & 13 deletions test/functional/wallet_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from test_framework.test_framework import BitcoinTestFramework
from test_framework.util import (
assert_raises_rpc_error,
assert_greater_than,
assert_greater_than_or_equal,
assert_equal,
)


Expand Down Expand Up @@ -76,23 +75,18 @@ def run_test(self):

self.log.info('Check a timeout less than the limit')
MAX_VALUE = 100000000
expected_time = self.mocktime + MAX_VALUE - 600
now = int(time.time())
self.nodes[0].setmocktime(now)
expected_time = now + MAX_VALUE - 600
self.nodes[0].walletpassphrase(passphrase2, MAX_VALUE - 600)
self.bump_mocktime(1)
# give buffer for walletpassphrase, since it iterates over all encrypted keys
expected_time_with_buffer = self.mocktime + MAX_VALUE - 600
actual_time = self.nodes[0].getwalletinfo()['unlocked_until']
assert_greater_than_or_equal(actual_time, expected_time)
assert_greater_than(expected_time_with_buffer, actual_time)
assert_equal(actual_time, expected_time)

self.log.info('Check a timeout greater than the limit')
expected_time = self.mocktime + MAX_VALUE - 1
expected_time = now + MAX_VALUE
self.nodes[0].walletpassphrase(passphrase2, MAX_VALUE + 1000)
self.bump_mocktime(1)
expected_time_with_buffer = self.mocktime + MAX_VALUE
actual_time = self.nodes[0].getwalletinfo()['unlocked_until']
assert_greater_than_or_equal(actual_time, expected_time)
assert_greater_than(expected_time_with_buffer, actual_time)
assert_equal(actual_time, expected_time)


if __name__ == '__main__':
Expand Down
4 changes: 4 additions & 0 deletions test/functional/wallet_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def set_test_params(self):
["-maxapsfee=0.00000293"],
["-maxapsfee=0.00000294"],
]
# whitelist peers to speed up tx relay / mempool sync
for args in self.extra_args:
args.append("-whitelist=noban@127.0.0.1")

self.rpc_timeout = 480
self.supports_cli = False

Expand Down
3 changes: 3 additions & 0 deletions test/functional/wallet_hd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def set_test_params(self):
self.setup_clean_chain = True
self.num_nodes = 2
self.extra_args = [['-usehd=0'], ['-usehd=1', '-keypool=0']]
# whitelist peers to speed up tx relay / mempool sync
for args in self.extra_args:
args.append("-whitelist=noban@127.0.0.1")

def setup_network(self):
self.add_nodes(self.num_nodes, self.extra_args)
Expand Down
3 changes: 3 additions & 0 deletions test/functional/wallet_importdescriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def set_test_params(self):
self.extra_args = [[],
["-keypool=5"]
]
# whitelist peers to speed up tx relay / mempool sync
for args in self.extra_args:
args.append("-whitelist=noban@127.0.0.1")
self.setup_clean_chain = True
self.wallet_names = []

Expand Down
2 changes: 2 additions & 0 deletions test/functional/wallet_listreceivedby.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
class ReceivedByTest(BitcoinTestFramework):
def set_test_params(self):
self.num_nodes = 2
# whitelist peers to speed up tx relay / mempool sync
self.extra_args = [["-whitelist=noban@127.0.0.1"]] * self.num_nodes

def skip_test_if_missing_module(self):
self.skip_if_no_wallet()
Expand Down
2 changes: 2 additions & 0 deletions test/functional/wallet_listsinceblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class ListSinceBlockTest(BitcoinTestFramework):
def set_test_params(self):
self.num_nodes = 4
self.setup_clean_chain = True
# whitelist peers to speed up tx relay / mempool sync
self.extra_args = [["-whitelist=noban@127.0.0.1"]] * self.num_nodes

def skip_test_if_missing_module(self):
self.skip_if_no_wallet()
Expand Down
Loading