Skip to content

Commit ea99c55

Browse files
committed
add username arg where necessary, more tests
1 parent 63b6bd6 commit ea99c55

File tree

2 files changed

+51
-25
lines changed

2 files changed

+51
-25
lines changed

testgres/testgres.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,17 @@ def __init__(self,
164164
parent_node,
165165
dbname,
166166
host="127.0.0.1",
167-
user=None,
167+
username=None,
168168
password=None):
169169

170170
# Use default user if not specified
171-
user = user or _default_username()
171+
username = username or _default_username()
172172

173173
self.parent_node = parent_node
174174

175175
self.connection = pglib.connect(
176176
database=dbname,
177-
user=user,
177+
user=username,
178178
port=parent_node.port,
179179
host=host,
180180
password=password)
@@ -251,10 +251,17 @@ class NodeBackup(object):
251251
def log_file(self):
252252
return os.path.join(self.base_dir, BACKUP_LOG_FILE)
253253

254-
def __init__(self, node, base_dir=None, xlog_method=DEFAULT_XLOG_METHOD):
254+
def __init__(self,
255+
node,
256+
base_dir=None,
257+
username=None,
258+
xlog_method=DEFAULT_XLOG_METHOD):
259+
255260
if not node.status():
256261
raise BackupException('Node must be running')
257262

263+
# set default arguments
264+
username = username or _default_username()
258265
base_dir = base_dir or tempfile.mkdtemp()
259266

260267
# create directory if needed
@@ -269,7 +276,8 @@ def __init__(self, node, base_dir=None, xlog_method=DEFAULT_XLOG_METHOD):
269276
_params = [
270277
"-D{}".format(data_dir),
271278
"-p{}".format(node.port),
272-
"-X", xlog_method
279+
"-U{}".format(username),
280+
"-X{}".format(xlog_method)
273281
]
274282
_execute_utility("pg_basebackup", _params, self.log_file)
275283

@@ -782,7 +790,7 @@ def dump(self, dbname, filename=None):
782790

783791
return filename
784792

785-
def restore(self, dbname, filename):
793+
def restore(self, dbname, filename, username=None):
786794
"""
787795
Restore database from pg_dump's file.
788796
@@ -791,9 +799,9 @@ def restore(self, dbname, filename):
791799
filename: database dump taken by pg_dump (str).
792800
"""
793801

794-
self.psql(dbname, filename=filename)
802+
self.psql(dbname=dbname, filename=filename, username=username)
795803

796-
def poll_query_until(self, dbname, query):
804+
def poll_query_until(self, dbname, query, username=None, max_attempts=60, sleep_time=1):
797805
"""
798806
Run a query once a second until it returs True.
799807
@@ -802,16 +810,17 @@ def poll_query_until(self, dbname, query):
802810
query: query to be executed (str).
803811
"""
804812

805-
max_attemps = 60
806813
attemps = 0
814+
while attemps < max_attempts:
815+
res = self.execute(dbname=dbname,
816+
query=query,
817+
username=username,
818+
commit=True)
807819

808-
while attemps < max_attemps:
809-
ret = self.safe_psql(dbname, query)
810-
# TODO: fix psql so that it returns result without newline
811-
if ret == six.b("t\n"):
812-
return
820+
if res[0][0]:
821+
return # done
813822

814-
time.sleep(1)
823+
time.sleep(sleep_time)
815824
attemps += 1
816825

817826
raise QueryException('Query timeout')
@@ -836,19 +845,21 @@ def execute(self, dbname, query, username=None, commit=False):
836845
node_con.commit()
837846
return res
838847

839-
def backup(self, xlog_method=DEFAULT_XLOG_METHOD):
848+
def backup(self, username=None, xlog_method=DEFAULT_XLOG_METHOD):
840849
"""
841850
Perform pg_basebackup.
842851
843852
Args:
844-
xlog_method: a method for collecting the logs ('fetch' | 'stream')
853+
username: database user name (str).
854+
xlog_method: a method for collecting the logs ('fetch' | 'stream').
845855
846856
Returns:
847857
A smart object of type NodeBackup.
848858
"""
849859

850-
# NodeBackup will handle this
851-
return NodeBackup(self, xlog_method=xlog_method)
860+
return NodeBackup(node=self,
861+
username=username,
862+
xlog_method=xlog_method)
852863

853864
def pgbench_init(self, dbname='postgres', scale=1, options=[]):
854865
"""
@@ -905,7 +916,9 @@ def connect(self, dbname='postgres', username=None):
905916
An instance of NodeConnection.
906917
"""
907918

908-
return NodeConnection(parent_node=self, dbname=dbname, user=username)
919+
return NodeConnection(parent_node=self,
920+
dbname=dbname,
921+
username=username)
909922

910923

911924
def _default_username():

testgres/tests/test_simple.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_backup_simple(self):
160160
master.psql('postgres',
161161
'create table test as select generate_series(1, 4) i')
162162

163-
with master.backup('stream') as backup:
163+
with master.backup(xlog_method='stream') as backup:
164164
with backup.spawn_primary('slave') as slave:
165165
slave.start()
166166
res = slave.execute('postgres',
@@ -171,12 +171,12 @@ def test_backup_multiple(self):
171171
with get_new_node('node') as node:
172172
node.init(allow_streaming=True).start()
173173

174-
with node.backup('fetch') as backup1, \
175-
node.backup('fetch') as backup2:
174+
with node.backup(xlog_method='fetch') as backup1, \
175+
node.backup(xlog_method='fetch') as backup2:
176176

177177
self.assertNotEqual(backup1.base_dir, backup2.base_dir)
178178

179-
with node.backup('fetch') as backup:
179+
with node.backup(xlog_method='fetch') as backup:
180180
with backup.spawn_primary('node1', destroy=False) as node1, \
181181
backup.spawn_primary('node2', destroy=False) as node2:
182182

@@ -186,7 +186,7 @@ def test_backup_exhaust(self):
186186
with get_new_node('node') as node:
187187
node.init(allow_streaming=True).start()
188188

189-
with node.backup('fetch') as backup:
189+
with node.backup(xlog_method='fetch') as backup:
190190
with backup.spawn_primary('node1') as node1:
191191
pass
192192

@@ -264,6 +264,19 @@ def test_users(self):
264264
value = node.safe_psql('postgres', 'select 1', username='test_user')
265265
self.assertEqual(value, six.b('1\n'))
266266

267+
def test_poll_query_until(self):
268+
with get_new_node('master') as node:
269+
node.init().start()
270+
271+
get_time = 'select extract(epoch from now())'
272+
check_time = 'select extract(epoch from now()) - {} >= 5'
273+
274+
start_time = node.execute('postgres', get_time)[0][0]
275+
node.poll_query_until('postgres', check_time.format(start_time))
276+
end_time = node.execute('postgres', get_time)[0][0]
277+
278+
self.assertTrue(end_time - start_time >= 5)
279+
267280
def test_logging(self):
268281
regex = re.compile('.+?LOG:.*')
269282
logfile = tempfile.NamedTemporaryFile('w', delete=True)

0 commit comments

Comments
 (0)