from .. import COMPOSE_FILE, Smoketest, requires_docker, spacetime
from ..docker import DockerManager

-import re
import time
from typing import Callable
import unittest

-def get_int(text):
-    digits = re.search(r'\d+', text)
-    if digits is None:
-        raise Exception("no numbers found in string")
-    return int(digits.group())
-
def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
    """Retry a function on failure with delay."""
    for attempt in range(1, max_retries + 1):
@@ -25,6 +18,21 @@ def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
    print("Max retries reached. Skipping the exception.")
    return False

+def parse_sql_result(res: str) -> list[dict]:
+    """Parse tabular output from an SQL query into a list of dicts."""
+    lines = res.splitlines()
+    headers = lines[0].split('|') if '|' in lines[0] else [lines[0]]
+    headers = [header.strip() for header in headers]
+    rows = []
+    for row in lines[2:]:
+        cols = [col.strip() for col in row.split('|')]
+        rows.append(dict(zip(headers, cols)))
+    return rows
+
+def int_vals(rows: list[dict]) -> list[dict]:
+    """For each dict in the list, cast all values to int."""
+    return [{k: int(v) for k, v in row.items()} for row in rows]
+
class Cluster:
    """Manages leader-related operations and state for SpaceTime database cluster."""

@@ -35,56 +43,47 @@ def __init__(self, docker_manager, smoketest: Smoketest):
        # Ensure all containers are up.
        self.docker.compose("up", "-d")

-    def read_controldb(self, sql):
-        """Helper method to read from control database."""
-        return self.test.spacetime("sql", "spacetime-control", sql)
+    def sql(self, sql: str) -> list[dict]:
+        """Query the test database."""
+        res = self.test.sql(sql)
+        return parse_sql_result(str(res))
+
+    def read_controldb(self, sql: str) -> list[dict]:
+        """Query the control database."""
+        res = self.test.spacetime("sql", "spacetime-control", sql)
+        return parse_sql_result(str(res))

    def get_db_id(self):
        """Query database ID."""
        sql = f"select id from database where database_identity=0x{self.test.database_identity}"
-        db_id_tb = self.read_controldb(sql)
-        return get_int(db_id_tb)
-
+        res = self.read_controldb(sql)
+        return int(res[0]['id'])

def get_all_replicas (self ):
50
63
"""Get all replica nodes in the cluster."""
51
64
database_id = self .get_db_id ()
52
65
sql = f"select id, node_id from replica where database_id={ database_id } "
53
- replica_tb = self .read_controldb (sql )
54
- replicas = []
55
- for line in str (replica_tb ).splitlines ()[2 :]:
56
- replica_id , node_id = line .split ('|' )
57
- replicas .append ({
58
- 'replica_id' : int (replica_id ),
59
- 'node_id' : int (node_id )
60
- })
61
- return replicas
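+        # Rows come back keyed by the selected column names,
+        # e.g. [{'id': 3, 'node_id': 2}] (values illustrative).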
+        return int_vals(self.read_controldb(sql))

    def get_leader_info(self):
        """Get current leader's node information including ID, hostname, and container ID."""

        database_id = self.get_db_id()
-        # Query leader replica ID
-        sql = f"select leader from replication_state where database_id={database_id}"
-        leader_tb = self.read_controldb(sql)
-        leader_id = get_int(leader_tb)
-
-        # Query leader node ID
-        sql = f"select node_id from replica where id={leader_id}"
-        leader_node_tb = self.read_controldb(sql)
-        leader_node_id = get_int(leader_node_tb)
-
-        # Query leader hostname
-        sql = f"select network_addr from node_v2 where id={leader_node_id}"
-        leader_host_tb = str(self.read_controldb(sql))
-        lines = leader_host_tb.splitlines()
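+        # Collapse the three lookups above into one query:
+        # replication_state.leader is a replica id, and replica.node_id
+        # identifies the node whose network_addr we need.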
+        sql = f"""\
+        select node_v2.id, node_v2.network_addr from node_v2 \
+        join replica on replica.node_id=node_v2.id \
+        join replication_state on replication_state.leader=replica.id \
+        where replication_state.database_id={database_id} \
+        """
+        rows = self.read_controldb(sql)
+        if not rows:
+            raise Exception("Could not find current leader's node")

+        leader_node_id = int(rows[0]['id'])
        hostname = ""
-        if len(lines) == 3:  # actual row starts from 3rd line
-            leader_row = lines[2]
-            if "(some =" in leader_row:
-                address = leader_row.split('"')[1]
-                hostname = address.split(':')[0]
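+        # network_addr is rendered as an option, e.g. (some = "10.0.0.5:3000");
+        # the address value is illustrative. Keep just the host part.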
+        if "(some =" in rows[0]['network_addr']:
+            address = rows[0]['network_addr'].split('"')[1]
+            hostname = address.split(':')[0]

        # Find container ID
        container_id = ""
@@ -114,15 +113,16 @@ def wait_for_leader_change(self, previous_leader_node, max_attempts=10, delay=2)
            time.sleep(delay)
        return None

-    def ensure_leader_health(self, id, wait_time=2):
+    def ensure_leader_health(self, id):
        """Verify leader is healthy by inserting a row."""
-        if wait_time:
-            time.sleep(wait_time)

        retry(lambda: self.test.call("start", id, 1))
-        add_table = str(self.test.sql(f"SELECT id FROM counter where id={id}"))
-        if str(id) not in add_table:
+        rows = self.sql(f"select id from counter where id={id}")
+        if len(rows) < 1 or int(rows[0]['id']) != id:
            raise ValueError(f"Could not find {id} in counter table")
+        # Wait for at least one tick to ensure buffers are flushed.
+        # TODO: Replace with confirmed read.
+        time.sleep(0.6)


    def fail_leader(self, action='kill'):
@@ -247,31 +247,42 @@ def start(self, id: int, count: int):
        """Send a message to the database."""
        retry(lambda: self.call("start", id, count))

+    def collect_counter_rows(self):
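+        """Fetch all rows from the counter table, values cast to int."""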
+        return int_vals(self.cluster.sql("select * from counter"))
+
+
class LeaderElection(ReplicationTest):
    def test_leader_election_in_loop(self):
        """This test fails a leader, waits for a new leader to be elected, and verifies that commits replicated to the new leader"""
        iterations = 5
        row_ids = [101 + i for i in range(iterations * 2)]
        for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]):
            cur_leader = self.cluster.wait_for_leader_change(None)
+            print(f"ensure leader health {first_id}")
            self.cluster.ensure_leader_health(first_id)

-            print("killing current leader: {}", cur_leader)
+            print(f"killing current leader: {cur_leader}")
            container_id = self.cluster.fail_leader()

            self.assertIsNotNone(container_id)

            next_leader = self.cluster.wait_for_leader_change(cur_leader)
            self.assertNotEqual(cur_leader, next_leader)
            # this checks that leader election happened
+            print(f"ensure_leader_health {second_id}")
            self.cluster.ensure_leader_health(second_id)
            # restart the old leader, so that we can maintain quorum for next iteration
+            print(f"reconnect leader {container_id}")
            self.cluster.restore_leader(container_id, 'start')

-        # verify if all past rows are present in new leader
-        for row_id in row_ids:
-            table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}")
-            self.assertIn(f"{row_id}", str(table))
+        # Ensure we have a current leader
+        last_row_id = row_ids[-1] + 1
+        self.cluster.ensure_leader_health(last_row_id)
+        row_ids.append(last_row_id)
+
+        # Verify that all inserted rows are present
+        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
+        self.assertEqual(set(stored_row_ids), set(row_ids))

class LeaderDisconnect(ReplicationTest):
    def test_leader_c_disconnect_in_loop(self):
@@ -300,12 +311,15 @@ def test_leader_c_disconnect_in_loop(self):
            # restart the old leader, so that we can maintain quorum for next iteration
            print(f"reconnect leader {container_id}")
            self.cluster.restore_leader(container_id, 'connect')
-            time.sleep(1)

-        # verify if all past rows are present in new leader
-        for row_id in row_ids:
-            table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}")
-            self.assertIn(f"{row_id}", str(table))
+        # Ensure we have a current leader
+        last_row_id = row_ids[-1] + 1
+        self.cluster.ensure_leader_health(last_row_id)
+        row_ids.append(last_row_id)
+
+        # Verify that all inserted rows are present
+        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
+        self.assertEqual(set(stored_row_ids), set(row_ids))


    @unittest.skip("drain_node not yet supported")
@@ -342,18 +356,16 @@ def test_prefer_leader(self):
            if replica['node_id'] != cur_leader_node_id:
                prefer_replica = replica
                break
-        prefer_replica_id = prefer_replica['replica_id']
+        prefer_replica_id = prefer_replica['id']
        self.spacetime("call", "spacetime-control", "prefer_leader", f"{prefer_replica_id}")

        next_leader_node_id = self.cluster.wait_for_leader_change(cur_leader_node_id)
        self.cluster.ensure_leader_health(402)
        self.assertEqual(prefer_replica['node_id'], next_leader_node_id)

-
        # verify that all past rows are present in the new leader
-        for row_id in [401, 402]:
-            table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}")
-            self.assertIn(f"{row_id}", str(table))
+        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
+        self.assertEqual(set(stored_row_ids), set([401, 402]))


class ManyTransactions(ReplicationTest):