Skip to content

Commit ab31b9d

Browse files
ryanofskypromag
authored andcommitted
Fix wallet unload race condition
Currently it's possible for ReleaseWallet to delete the CWallet pointer while it is processing BlockConnected, etc chain notifications. To fix this, unregister from notifications earlier in UnloadWallet instead of ReleaseWallet, and use a new RegisterSharedValidationInterface function to prevent the CValidationInterface shared_ptr from being deleted until the last notification is actually finished.
1 parent 3e50fdb commit ab31b9d

File tree

9 files changed

+62
-42
lines changed

9 files changed

+62
-42
lines changed

src/bench/wallet_balance.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ static void WalletBalance(benchmark::State& state, const bool set_dirty, const b
2323
wallet.SetupLegacyScriptPubKeyMan();
2424
bool first_run;
2525
if (wallet.LoadWallet(first_run) != DBErrors::LOAD_OK) assert(false);
26-
wallet.handleNotifications();
2726
}
28-
27+
auto handler = chain->handleNotifications({ &wallet, [](CWallet*) {} });
2928

3029
const Optional<std::string> address_mine{add_mine ? Optional<std::string>{getnewaddress(wallet)} : nullopt};
3130
if (add_watchonly) importaddress(wallet, ADDRESS_WATCHONLY);

src/interfaces/chain.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,12 @@ class LockImpl : public Chain::Lock, public UniqueLock<RecursiveMutex>
148148
using UniqueLock::UniqueLock;
149149
};
150150

151-
class NotificationsHandlerImpl : public Handler, CValidationInterface
151+
class NotificationsProxy : public CValidationInterface
152152
{
153153
public:
154-
explicit NotificationsHandlerImpl(Chain& chain, Chain::Notifications& notifications)
155-
: m_chain(chain), m_notifications(&notifications)
156-
{
157-
RegisterValidationInterface(this);
158-
}
159-
~NotificationsHandlerImpl() override { disconnect(); }
160-
void disconnect() override
161-
{
162-
if (m_notifications) {
163-
m_notifications = nullptr;
164-
UnregisterValidationInterface(this);
165-
}
166-
}
154+
explicit NotificationsProxy(std::shared_ptr<Chain::Notifications> notifications)
155+
: m_notifications(std::move(notifications)) {}
156+
virtual ~NotificationsProxy() = default;
167157
void TransactionAddedToMempool(const CTransactionRef& tx) override
168158
{
169159
m_notifications->transactionAddedToMempool(tx);
@@ -185,8 +175,26 @@ class NotificationsHandlerImpl : public Handler, CValidationInterface
185175
m_notifications->updatedBlockTip();
186176
}
187177
void ChainStateFlushed(const CBlockLocator& locator) override { m_notifications->chainStateFlushed(locator); }
188-
Chain& m_chain;
189-
Chain::Notifications* m_notifications;
178+
std::shared_ptr<Chain::Notifications> m_notifications;
179+
};
180+
181+
class NotificationsHandlerImpl : public Handler
182+
{
183+
public:
184+
explicit NotificationsHandlerImpl(std::shared_ptr<Chain::Notifications> notifications)
185+
: m_proxy(std::make_shared<NotificationsProxy>(std::move(notifications)))
186+
{
187+
RegisterSharedValidationInterface(m_proxy);
188+
}
189+
~NotificationsHandlerImpl() override { disconnect(); }
190+
void disconnect() override
191+
{
192+
if (m_proxy) {
193+
UnregisterSharedValidationInterface(m_proxy);
194+
m_proxy.reset();
195+
}
196+
}
197+
std::shared_ptr<NotificationsProxy> m_proxy;
190198
};
191199

192200
class RpcHandlerImpl : public Handler
@@ -343,9 +351,9 @@ class ChainImpl : public Chain
343351
{
344352
::uiInterface.ShowProgress(title, progress, resume_possible);
345353
}
346-
std::unique_ptr<Handler> handleNotifications(Notifications& notifications) override
354+
std::unique_ptr<Handler> handleNotifications(std::shared_ptr<Notifications> notifications) override
347355
{
348-
return MakeUnique<NotificationsHandlerImpl>(*this, notifications);
356+
return MakeUnique<NotificationsHandlerImpl>(std::move(notifications));
349357
}
350358
void waitForNotificationsIfTipChanged(const uint256& old_tip) override
351359
{

src/interfaces/chain.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ class Chain
229229
};
230230

231231
//! Register handler for notifications.
232-
virtual std::unique_ptr<Handler> handleNotifications(Notifications& notifications) = 0;
232+
virtual std::unique_ptr<Handler> handleNotifications(std::shared_ptr<Notifications> notifications) = 0;
233233

234234
//! Wait for pending notifications to be processed unless block hash points to the current
235235
//! chain tip.

src/validationinterface.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ CMainSignals& GetMainSignals()
7575
return g_signals;
7676
}
7777

78-
void RegisterValidationInterface(CValidationInterface* pwalletIn) {
79-
ValidationInterfaceConnections& conns = g_signals.m_internals->m_connMainSignals[pwalletIn];
78+
void RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface> pwalletIn) {
79+
// Each connection captures pwalletIn to ensure that each callback is
80+
// executed before pwalletIn is destroyed. For more details see #18338.
81+
ValidationInterfaceConnections& conns = g_signals.m_internals->m_connMainSignals[pwalletIn.get()];
8082
conns.UpdatedBlockTip = g_signals.m_internals->UpdatedBlockTip.connect(std::bind(&CValidationInterface::UpdatedBlockTip, pwalletIn, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3));
8183
conns.TransactionAddedToMempool = g_signals.m_internals->TransactionAddedToMempool.connect(std::bind(&CValidationInterface::TransactionAddedToMempool, pwalletIn, std::placeholders::_1));
8284
conns.BlockConnected = g_signals.m_internals->BlockConnected.connect(std::bind(&CValidationInterface::BlockConnected, pwalletIn, std::placeholders::_1, std::placeholders::_2));
@@ -87,6 +89,18 @@ void RegisterValidationInterface(CValidationInterface* pwalletIn) {
8789
conns.NewPoWValidBlock = g_signals.m_internals->NewPoWValidBlock.connect(std::bind(&CValidationInterface::NewPoWValidBlock, pwalletIn, std::placeholders::_1, std::placeholders::_2));
8890
}
8991

92+
void RegisterValidationInterface(CValidationInterface* callbacks)
93+
{
94+
// Create a shared_ptr with a no-op deleter - CValidationInterface lifecycle
95+
// is managed by the caller.
96+
RegisterSharedValidationInterface({callbacks, [](CValidationInterface*){}});
97+
}
98+
99+
void UnregisterSharedValidationInterface(std::shared_ptr<CValidationInterface> callbacks)
100+
{
101+
UnregisterValidationInterface(callbacks.get());
102+
}
103+
90104
void UnregisterValidationInterface(CValidationInterface* pwalletIn) {
91105
if (g_signals.m_internals) {
92106
g_signals.m_internals->m_connMainSignals.erase(pwalletIn);

src/validationinterface.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ void RegisterValidationInterface(CValidationInterface* pwalletIn);
3030
void UnregisterValidationInterface(CValidationInterface* pwalletIn);
3131
/** Unregister all wallets from core */
3232
void UnregisterAllValidationInterfaces();
33+
34+
// Alternate registration functions that release a shared_ptr after the last
35+
// notification is sent. These are useful for race-free cleanup, since
36+
// unregistration is nonblocking and can return before the last notification is
37+
// processed.
38+
void RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface> callbacks);
39+
void UnregisterSharedValidationInterface(std::shared_ptr<CValidationInterface> callbacks);
40+
3341
/**
3442
* Pushes a function to callback onto the notification queue, guaranteeing any
3543
* callbacks generated prior to now are finished when the function is called.
@@ -163,7 +171,7 @@ class CValidationInterface {
163171
* Notifies listeners that a block which builds directly on our current tip
164172
* has been received and connected to the headers tree, though not validated yet */
165173
virtual void NewPoWValidBlock(const CBlockIndex *pindex, const std::shared_ptr<const CBlock>& block) {};
166-
friend void ::RegisterValidationInterface(CValidationInterface*);
174+
friend void ::RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface>);
167175
friend void ::UnregisterValidationInterface(CValidationInterface*);
168176
friend void ::UnregisterAllValidationInterfaces();
169177
};
@@ -173,7 +181,7 @@ class CMainSignals {
173181
private:
174182
std::unique_ptr<MainSignalsInstance> m_internals;
175183

176-
friend void ::RegisterValidationInterface(CValidationInterface*);
184+
friend void ::RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface>);
177185
friend void ::UnregisterValidationInterface(CValidationInterface*);
178186
friend void ::UnregisterAllValidationInterfaces();
179187
friend void ::CallFunctionInValidationInterfaceQueue(std::function<void ()> func);

src/wallet/test/wallet_test_fixture.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ WalletTestingSetup::WalletTestingSetup(const std::string& chainName)
1010
{
1111
bool fFirstRun;
1212
m_wallet.LoadWallet(fFirstRun);
13-
m_wallet.handleNotifications();
14-
13+
m_chain_notifications_handler = m_chain->handleNotifications({ &m_wallet, [](CWallet*) {} });
1514
m_chain_client->registerRpcs();
1615
}

src/wallet/test/wallet_test_fixture.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct WalletTestingSetup: public TestingSetup {
2323
std::unique_ptr<interfaces::Chain> m_chain = interfaces::MakeChain(m_node);
2424
std::unique_ptr<interfaces::ChainClient> m_chain_client = interfaces::MakeWalletClient(*m_chain, {});
2525
CWallet m_wallet;
26+
std::unique_ptr<interfaces::Handler> m_chain_notifications_handler;
2627
};
2728

2829
#endif // BITCOIN_WALLET_TEST_WALLET_TEST_FIXTURE_H

src/wallet/wallet.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ bool AddWallet(const std::shared_ptr<CWallet>& wallet)
6262

6363
bool RemoveWallet(const std::shared_ptr<CWallet>& wallet)
6464
{
65-
LOCK(cs_wallets);
6665
assert(wallet);
66+
// Unregister with the validation interface which also drops shared ponters.
67+
wallet->m_chain_notifications_handler.reset();
68+
LOCK(cs_wallets);
6769
std::vector<std::shared_ptr<CWallet>>::iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet);
6870
if (i == vpwallets.end()) return false;
6971
vpwallets.erase(i);
@@ -105,13 +107,9 @@ static std::set<std::string> g_unloading_wallet_set;
105107
// Custom deleter for shared_ptr<CWallet>.
106108
static void ReleaseWallet(CWallet* wallet)
107109
{
108-
// Unregister and delete the wallet right after BlockUntilSyncedToCurrentChain
109-
// so that it's in sync with the current chainstate.
110110
const std::string name = wallet->GetName();
111111
wallet->WalletLogPrintf("Releasing wallet\n");
112-
wallet->BlockUntilSyncedToCurrentChain();
113112
wallet->Flush();
114-
wallet->m_chain_notifications_handler.reset();
115113
delete wallet;
116114
// Wallet is now released, notify UnloadWallet, if any.
117115
{
@@ -137,6 +135,7 @@ void UnloadWallet(std::shared_ptr<CWallet>&& wallet)
137135
// Notify the unload intent so that all remaining shared pointers are
138136
// released.
139137
wallet->NotifyUnload();
138+
140139
// Time to ditch our shared_ptr and wait for ReleaseWallet call.
141140
wallet.reset();
142141
{
@@ -4092,7 +4091,7 @@ std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain,
40924091
}
40934092

40944093
// Register with the validation interface. It's ok to do this after rescan since we're still holding locked_chain.
4095-
walletInstance->handleNotifications();
4094+
walletInstance->m_chain_notifications_handler = walletInstance->chain().handleNotifications(walletInstance);
40964095

40974096
walletInstance->SetBroadcastTransactions(gArgs.GetBoolArg("-walletbroadcast", DEFAULT_WALLETBROADCAST));
40984097

@@ -4105,11 +4104,6 @@ std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain,
41054104
return walletInstance;
41064105
}
41074106

4108-
void CWallet::handleNotifications()
4109-
{
4110-
m_chain_notifications_handler = m_chain->handleNotifications(*this);
4111-
}
4112-
41134107
void CWallet::postInitProcess()
41144108
{
41154109
auto locked_chain = chain().lock();

src/wallet/wallet.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ class WalletRescanReserver; //forward declarations for ScanForWalletTransactions
605605
/**
606606
* A CWallet maintains a set of transactions and balances, and provides the ability to create new transactions.
607607
*/
608-
class CWallet final : public WalletStorage, private interfaces::Chain::Notifications
608+
class CWallet final : public WalletStorage, public interfaces::Chain::Notifications
609609
{
610610
private:
611611
CKeyingMaterial vMasterKey GUARDED_BY(cs_wallet);
@@ -781,9 +781,6 @@ class CWallet final : public WalletStorage, private interfaces::Chain::Notificat
781781
/** Registered interfaces::Chain::Notifications handler. */
782782
std::unique_ptr<interfaces::Handler> m_chain_notifications_handler;
783783

784-
/** Register the wallet for chain notifications */
785-
void handleNotifications();
786-
787784
/** Interface for accessing chain state. */
788785
interfaces::Chain& chain() const { assert(m_chain); return *m_chain; }
789786

0 commit comments

Comments
 (0)