@@ -388,18 +388,22 @@ def on_version(self, message):
388
388
389
389
# Connection helper methods
390
390
391
- def wait_until (self , test_function , timeout = 60 ):
391
+ def wait_until (self , test_function_in , * , timeout = 60 , check_connected = True ):
392
+ def test_function ():
393
+ if check_connected :
394
+ assert self .is_connected
395
+ return test_function_in ()
396
+
392
397
wait_until (test_function , timeout = timeout , lock = mininode_lock , timeout_factor = self .timeout_factor )
393
398
394
399
def wait_for_disconnect (self , timeout = 60 ):
395
400
test_function = lambda : not self .is_connected
396
- self .wait_until (test_function , timeout = timeout )
401
+ self .wait_until (test_function , timeout = timeout , check_connected = False )
397
402
398
403
# Message receiving helper methods
399
404
400
405
def wait_for_tx (self , txid , timeout = 60 ):
401
406
def test_function ():
402
- assert self .is_connected
403
407
if not self .last_message .get ('tx' ):
404
408
return False
405
409
return self .last_message ['tx' ].tx .rehash () == txid
@@ -408,14 +412,12 @@ def test_function():
408
412
409
413
def wait_for_block (self , blockhash , timeout = 60 ):
410
414
def test_function ():
411
- assert self .is_connected
412
415
return self .last_message .get ("block" ) and self .last_message ["block" ].block .rehash () == blockhash
413
416
414
417
self .wait_until (test_function , timeout = timeout )
415
418
416
419
def wait_for_header (self , blockhash , timeout = 60 ):
417
420
def test_function ():
418
- assert self .is_connected
419
421
last_headers = self .last_message .get ('headers' )
420
422
if not last_headers :
421
423
return False
@@ -425,7 +427,6 @@ def test_function():
425
427
426
428
def wait_for_merkleblock (self , blockhash , timeout = 60 ):
427
429
def test_function ():
428
- assert self .is_connected
429
430
last_filtered_block = self .last_message .get ('merkleblock' )
430
431
if not last_filtered_block :
431
432
return False
@@ -437,9 +438,7 @@ def wait_for_getdata(self, hash_list, timeout=60):
437
438
"""Waits for a getdata message.
438
439
439
440
The object hashes in the inventory vector must match the provided hash_list."""
440
-
441
441
def test_function ():
442
- assert self .is_connected
443
442
last_data = self .last_message .get ("getdata" )
444
443
if not last_data :
445
444
return False
@@ -454,9 +453,7 @@ def wait_for_getheaders(self, timeout=60):
454
453
value must be explicitly cleared before calling this method, or this will return
455
454
immediately with success. TODO: change this method to take a hash value and only
456
455
return true if the correct block header has been requested."""
457
-
458
456
def test_function ():
459
- assert self .is_connected
460
457
return self .last_message .get ("getheaders" )
461
458
462
459
self .wait_until (test_function , timeout = timeout )
@@ -467,7 +464,6 @@ def wait_for_inv(self, expected_inv, timeout=60):
467
464
raise NotImplementedError ("wait_for_inv() will only verify the first inv object" )
468
465
469
466
def test_function ():
470
- assert self .is_connected
471
467
return self .last_message .get ("inv" ) and \
472
468
self .last_message ["inv" ].inv [0 ].type == expected_inv [0 ].type and \
473
469
self .last_message ["inv" ].inv [0 ].hash == expected_inv [0 ].hash
@@ -478,7 +474,7 @@ def wait_for_verack(self, timeout=60):
478
474
def test_function ():
479
475
return "verack" in self .last_message
480
476
481
- self .wait_until (test_function , timeout = timeout )
477
+ self .wait_until (test_function , timeout = timeout , check_connected = False )
482
478
483
479
# Message sending helper functions
484
480
@@ -491,7 +487,6 @@ def sync_with_ping(self, timeout=60):
491
487
self .send_message (msg_ping (nonce = self .ping_counter ))
492
488
493
489
def test_function ():
494
- assert self .is_connected
495
490
return self .last_message .get ("pong" ) and self .last_message ["pong" ].nonce == self .ping_counter
496
491
497
492
self .wait_until (test_function , timeout = timeout )
@@ -609,7 +604,11 @@ def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False,
609
604
self .send_message (msg_block (block = b ))
610
605
else :
611
606
self .send_message (msg_headers ([CBlockHeader (block ) for block in blocks ]))
612
- self .wait_until (lambda : blocks [- 1 ].sha256 in self .getdata_requests , timeout = timeout )
607
+ self .wait_until (
608
+ lambda : blocks [- 1 ].sha256 in self .getdata_requests ,
609
+ timeout = timeout ,
610
+ check_connected = success ,
611
+ )
613
612
614
613
if expect_disconnect :
615
614
self .wait_for_disconnect (timeout = timeout )
@@ -677,6 +676,6 @@ def wait_for_broadcast(self, txns, timeout=60):
677
676
The mempool should mark unbroadcast=False for these transactions.
678
677
"""
679
678
# Wait until invs have been received (and getdatas sent) for each txid.
680
- self .wait_until (lambda : set (self .tx_invs_received .keys ()) == set ([int (tx , 16 ) for tx in txns ]), timeout )
679
+ self .wait_until (lambda : set (self .tx_invs_received .keys ()) == set ([int (tx , 16 ) for tx in txns ]), timeout = timeout )
681
680
# Flush messages and wait for the getdatas to be processed
682
681
self .sync_with_ping ()
0 commit comments