@@ -378,18 +378,22 @@ def on_version(self, message):
378
378
379
379
# Connection helper methods
380
380
381
- def wait_until (self , test_function , timeout = 60 ):
381
+ def wait_until (self , test_function_in , * , timeout = 60 , check_connected = True ):
382
+ def test_function ():
383
+ if check_connected :
384
+ assert self .is_connected
385
+ return test_function_in ()
386
+
382
387
wait_until (test_function , timeout = timeout , lock = mininode_lock , timeout_factor = self .timeout_factor )
383
388
384
389
def wait_for_disconnect (self , timeout = 60 ):
385
390
test_function = lambda : not self .is_connected
386
- self .wait_until (test_function , timeout = timeout )
391
+ self .wait_until (test_function , timeout = timeout , check_connected = False )
387
392
388
393
# Message receiving helper methods
389
394
390
395
def wait_for_tx (self , txid , timeout = 60 ):
391
396
def test_function ():
392
- assert self .is_connected
393
397
if not self .last_message .get ('tx' ):
394
398
return False
395
399
return self .last_message ['tx' ].tx .rehash () == txid
@@ -398,14 +402,12 @@ def test_function():
398
402
399
403
def wait_for_block (self , blockhash , timeout = 60 ):
400
404
def test_function ():
401
- assert self .is_connected
402
405
return self .last_message .get ("block" ) and self .last_message ["block" ].block .rehash () == blockhash
403
406
404
407
self .wait_until (test_function , timeout = timeout )
405
408
406
409
def wait_for_header (self , blockhash , timeout = 60 ):
407
410
def test_function ():
408
- assert self .is_connected
409
411
last_headers = self .last_message .get ('headers' )
410
412
if not last_headers :
411
413
return False
@@ -415,7 +417,6 @@ def test_function():
415
417
416
418
def wait_for_merkleblock (self , blockhash , timeout = 60 ):
417
419
def test_function ():
418
- assert self .is_connected
419
420
last_filtered_block = self .last_message .get ('merkleblock' )
420
421
if not last_filtered_block :
421
422
return False
@@ -427,9 +428,7 @@ def wait_for_getdata(self, hash_list, timeout=60):
427
428
"""Waits for a getdata message.
428
429
429
430
The object hashes in the inventory vector must match the provided hash_list."""
430
-
431
431
def test_function ():
432
- assert self .is_connected
433
432
last_data = self .last_message .get ("getdata" )
434
433
if not last_data :
435
434
return False
@@ -444,9 +443,7 @@ def wait_for_getheaders(self, timeout=60):
444
443
value must be explicitly cleared before calling this method, or this will return
445
444
immediately with success. TODO: change this method to take a hash value and only
446
445
return true if the correct block header has been requested."""
447
-
448
446
def test_function ():
449
- assert self .is_connected
450
447
return self .last_message .get ("getheaders" )
451
448
452
449
self .wait_until (test_function , timeout = timeout )
@@ -457,7 +454,6 @@ def wait_for_inv(self, expected_inv, timeout=60):
457
454
raise NotImplementedError ("wait_for_inv() will only verify the first inv object" )
458
455
459
456
def test_function ():
460
- assert self .is_connected
461
457
return self .last_message .get ("inv" ) and \
462
458
self .last_message ["inv" ].inv [0 ].type == expected_inv [0 ].type and \
463
459
self .last_message ["inv" ].inv [0 ].hash == expected_inv [0 ].hash
@@ -468,7 +464,7 @@ def wait_for_verack(self, timeout=60):
468
464
def test_function ():
469
465
return self .message_count ["verack" ]
470
466
471
- self .wait_until (test_function , timeout = timeout )
467
+ self .wait_until (test_function , timeout = timeout , check_connected = False )
472
468
473
469
# Message sending helper functions
474
470
@@ -481,7 +477,6 @@ def sync_with_ping(self, timeout=60):
481
477
self .send_message (msg_ping (nonce = self .ping_counter ))
482
478
483
479
def test_function ():
484
- assert self .is_connected
485
480
return self .last_message .get ("pong" ) and self .last_message ["pong" ].nonce == self .ping_counter
486
481
487
482
self .wait_until (test_function , timeout = timeout )
@@ -599,7 +594,11 @@ def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False,
599
594
self .send_message (msg_block (block = b ))
600
595
else :
601
596
self .send_message (msg_headers ([CBlockHeader (block ) for block in blocks ]))
602
- self .wait_until (lambda : blocks [- 1 ].sha256 in self .getdata_requests , timeout = timeout )
597
+ self .wait_until (
598
+ lambda : blocks [- 1 ].sha256 in self .getdata_requests ,
599
+ timeout = timeout ,
600
+ check_connected = success ,
601
+ )
603
602
604
603
if expect_disconnect :
605
604
self .wait_for_disconnect (timeout = timeout )
@@ -667,6 +666,6 @@ def wait_for_broadcast(self, txns, timeout=60):
667
666
The mempool should mark unbroadcast=False for these transactions.
668
667
"""
669
668
# Wait until invs have been received (and getdatas sent) for each txid.
670
- self .wait_until (lambda : set (self .tx_invs_received .keys ()) == set ([int (tx , 16 ) for tx in txns ]), timeout )
669
+ self .wait_until (lambda : set (self .tx_invs_received .keys ()) == set ([int (tx , 16 ) for tx in txns ]), timeout = timeout )
671
670
# Flush messages and wait for the getdatas to be processed
672
671
self .sync_with_ping ()
0 commit comments