Skip to content

Commit 62fa61f

Browse files
committed
refactor: remove the wallet folder if the restore fails
1 parent abbb7ec commit 62fa61f

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/wallet/wallet.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,14 @@ std::shared_ptr<CWallet> RestoreWallet(WalletContext& context, const std::string
379379
auto wallet_file = wallet_path / "wallet.dat";
380380
fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists);
381381

382-
return LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings);
382+
auto wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings);
383+
384+
if (!wallet) {
385+
fs::remove(wallet_file);
386+
fs::remove(wallet_path);
387+
}
388+
389+
return wallet;
383390
}
384391

385392
/** @defgroup mapWallet

test/functional/wallet_backup.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@ def erase_three(self):
110110
os.remove(os.path.join(self.nodes[1].datadir, self.chain, 'wallets', self.default_wallet_name, self.wallet_data_filename))
111111
os.remove(os.path.join(self.nodes[2].datadir, self.chain, 'wallets', self.default_wallet_name, self.wallet_data_filename))
112112

113+
def restore_invalid_wallet(self):
114+
node = self.nodes[3]
115+
invalid_wallet_file = os.path.join(self.nodes[0].datadir, 'invalid_wallet_file.bak')
116+
open(invalid_wallet_file, 'a', encoding="utf8").write('invald wallet')
117+
wallet_name = "res0"
118+
not_created_wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name)
119+
error_message = "Wallet file verification failed. Failed to load database path '{}'. Data is not in recognized format.".format(not_created_wallet_file)
120+
assert_raises_rpc_error(-18, error_message, node.restorewallet, wallet_name, invalid_wallet_file)
121+
assert not os.path.exists(not_created_wallet_file)
122+
113123
def restore_nonexistent_wallet(self):
114124
node = self.nodes[3]
115125
nonexistent_wallet_file = os.path.join(self.nodes[0].datadir, 'nonexistent_wallet.bak')
@@ -125,6 +135,7 @@ def restore_wallet_existent_name(self):
125135
wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name)
126136
error_message = "Failed to create database path '{}'. Database already exists.".format(wallet_file)
127137
assert_raises_rpc_error(-36, error_message, node.restorewallet, wallet_name, backup_file)
138+
assert os.path.exists(wallet_file)
128139

129140
def init_three(self):
130141
self.init_wallet(node=0)
@@ -181,6 +192,7 @@ def run_test(self):
181192
##
182193
self.log.info("Restoring wallets on node 3 using backup files")
183194

195+
self.restore_invalid_wallet()
184196
self.restore_nonexistent_wallet()
185197

186198
backup_file_0 = os.path.join(self.nodes[0].datadir, 'wallet.bak')
@@ -191,6 +203,10 @@ def run_test(self):
191203
self.nodes[3].restorewallet("res1", backup_file_1)
192204
self.nodes[3].restorewallet("res2", backup_file_2)
193205

206+
assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res0"))
207+
assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res1"))
208+
assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res2"))
209+
194210
res0_rpc = self.nodes[3].get_wallet_rpc("res0")
195211
res1_rpc = self.nodes[3].get_wallet_rpc("res1")
196212
res2_rpc = self.nodes[3].get_wallet_rpc("res2")

0 commit comments

Comments
 (0)