@@ -87,6 +87,7 @@ def check(self, txid=None, amount=None, confirmation_height=None):
8787 assert_equal (len (txs ), self .expected_txs )
8888
8989 addresses = self .node .listreceivedbyaddress (minconf = 0 , include_watchonly = True , address_filter = self .address ['address' ])
90+
9091 if self .expected_txs :
9192 assert_equal (len (addresses [0 ]["txids" ]), self .expected_txs )
9293
@@ -98,13 +99,18 @@ def check(self, txid=None, amount=None, confirmation_height=None):
9899 assert_equal (tx ["category" ], "receive" )
99100 assert_equal (tx ["label" ], self .label )
100101 assert_equal (tx ["txid" ], txid )
101- assert_equal (tx ["confirmations" ], 1 + current_height - confirmation_height )
102- assert "trusted" not in tx
102+
103+ # If no confirmation height is given, the tx is still in the
104+ # mempool.
105+ confirmations = (1 + current_height - confirmation_height ) if confirmation_height else 0
106+ assert_equal (tx ["confirmations" ], confirmations )
107+ if confirmations :
108+ assert "trusted" not in tx
103109
104110 address , = [ad for ad in addresses if txid in ad ["txids" ]]
105111 assert_equal (address ["address" ], self .address ["address" ])
106112 assert_equal (address ["amount" ], self .expected_balance )
107- assert_equal (address ["confirmations" ], 1 + current_height - confirmation_height )
113+ assert_equal (address ["confirmations" ], confirmations )
108114 # Verify the transaction is correctly marked watchonly depending on
109115 # whether the transaction pays to an imported public key or
110116 # imported private key. The test setup ensures that transaction
@@ -162,11 +168,12 @@ def setup_network(self):
162168 self .import_deterministic_coinbase_privkeys ()
163169 self .stop_nodes ()
164170
165- self .start_nodes ()
171+ self .
start_nodes (
extra_args = [[ "[email protected] " ]] * self . num_nodes )
166172 for i in range (1 , self .num_nodes ):
167173 self .connect_nodes (i , 0 )
168174
169175 def run_test (self ):
176+
170177 # Create one transaction on node 0 with a unique amount for
171178 # each possible type of wallet import RPC.
172179 for i , variant in enumerate (IMPORT_VARIANTS ):
@@ -207,7 +214,7 @@ def run_test(self):
207214 variant .check ()
208215
209216 # Create new transactions sending to each address.
210- for i , variant in enumerate ( IMPORT_VARIANTS ) :
217+ for variant in IMPORT_VARIANTS :
211218 variant .sent_amount = get_rand_amount ()
212219 variant .sent_txid = self .nodes [0 ].sendtoaddress (variant .address ["address" ], variant .sent_amount )
213220 self .generate (self .nodes [0 ], 1 ) # Generate one block for each send
@@ -223,6 +230,46 @@ def run_test(self):
223230 variant .expected_txs += 1
224231 variant .check (variant .sent_txid , variant .sent_amount , variant .confirmation_height )
225232
233+ self .log .info ('Test that the mempool is rescanned as well if the rescan parameter is set to true' )
234+
235+ # The late timestamp and pruned variants are not necessary when testing mempool rescan
236+ mempool_variants = [variant for variant in IMPORT_VARIANTS if variant .rescan != Rescan .late_timestamp and not variant .prune ]
237+ # No further blocks are mined so the timestamp will stay the same
238+ timestamp = self .nodes [0 ].getblockheader (self .nodes [0 ].getbestblockhash ())["time" ]
239+
240+ # Create one transaction on node 0 with a unique amount for
241+ # each possible type of wallet import RPC.
242+ for i , variant in enumerate (mempool_variants ):
243+ variant .label = "mempool label {} {}" .format (i , variant )
244+ variant .address = self .nodes [1 ].getaddressinfo (self .nodes [1 ].getnewaddress (
245+ label = variant .label ,
246+ address_type = variant .address_type .value ,
247+ ))
248+ variant .key = self .nodes [1 ].dumpprivkey (variant .address ["address" ])
249+ variant .initial_amount = get_rand_amount ()
250+ variant .initial_txid = self .nodes [0 ].sendtoaddress (variant .address ["address" ], variant .initial_amount )
251+ variant .confirmation_height = 0
252+ variant .timestamp = timestamp
253+
254+ assert_equal (len (self .nodes [0 ].getrawmempool ()), len (mempool_variants ))
255+ self .sync_mempools ()
256+
257+ # For each variation of wallet key import, invoke the import RPC and
258+ # check the results from getbalance and listtransactions.
259+ for variant in mempool_variants :
260+ self .log .info ('Run import for mempool variant {}' .format (variant ))
261+ expect_rescan = variant .rescan == Rescan .yes
262+ variant .node = self .nodes [2 + IMPORT_NODES .index (ImportNode (variant .prune , expect_rescan ))]
263+ variant .do_import (variant .timestamp )
264+ if expect_rescan :
265+ variant .expected_balance = variant .initial_amount
266+ variant .expected_txs = 1
267+ variant .check (variant .initial_txid , variant .initial_amount )
268+ else :
269+ variant .expected_balance = 0
270+ variant .expected_txs = 0
271+ variant .check ()
272+
226273
227274if __name__ == "__main__" :
228275 ImportRescanTest ().main ()
0 commit comments