1414
1515logger = logging .getLogger (__name__ )
1616
17- create_table_query = """
18- CREATE TABLE IF NOT EXISTS users(
19- id serial PRIMARY KEY,
20- name VARCHAR (50),
21- password VARCHAR (50),
22- email VARCHAR (355),
23- created_on TIMESTAMP,
24- last_login TIMESTAMP
25- );
26- """
27-
28- create_proc_query = """\
29- CREATE OR REPLACE FUNCTION test_proc(candidate VARCHAR(70))
30- RETURNS text AS $$
31- BEGIN
32- RETURN(SELECT name FROM users where email = candidate);
33- END;
34- $$ LANGUAGE plpgsql;
35- """
36-
37- drop_proc_query = "DROP FUNCTION IF EXISTS test_proc(VARCHAR(70));"
38-
39- db = psycopg2 .connect (host = testenv ['postgresql_host' ], port = testenv ['postgresql_port' ],
40- user = testenv ['postgresql_user' ], password = testenv ['postgresql_pw' ],
41- database = testenv ['postgresql_db' ])
42-
43- cursor = db .cursor ()
44- cursor .execute (create_table_query )
45- cursor .execute (drop_proc_query )
46- cursor .execute (create_proc_query )
47- db .commit ()
48- cursor .close ()
49- db .close ()
50-
5117
5218class TestPsycoPG2 (unittest .TestCase ):
5319 def setUp (self ):
5420 self .db = psycopg2 .connect (host = testenv ['postgresql_host' ], port = testenv ['postgresql_port' ],
5521 user = testenv ['postgresql_user' ], password = testenv ['postgresql_pw' ],
5622 database = testenv ['postgresql_db' ])
23+
24+ database_setup_query = """
25+ DROP TABLE IF EXISTS users;
26+ CREATE TABLE users(
27+ id serial PRIMARY KEY,
28+ name VARCHAR (50),
29+ password VARCHAR (50),
30+ email VARCHAR (355),
31+ created_on TIMESTAMP,
32+ last_login TIMESTAMP
33+ );
34+ INSERT INTO users(name, email) VALUES('kermit', '[email protected] '); 35+ DROP FUNCTION IF EXISTS test_proc(VARCHAR(70));
36+ CREATE FUNCTION test_proc(candidate VARCHAR(70))
37+ RETURNS text AS $$
38+ BEGIN
39+ RETURN(SELECT name FROM users where email = candidate);
40+ END;
41+ $$ LANGUAGE plpgsql;
42+ """
43+ cursor = self .db .cursor ()
44+ cursor .execute (database_setup_query )
45+ self .db .commit ()
46+ cursor .close ()
47+
48+
5749 self .cursor = self .db .cursor ()
5850 self .recorder = tracer .recorder
5951 self .recorder .clear_spans ()
6052 tracer .cur_ctx = None
6153
6254 def tearDown (self ):
63- """ Do nothing for now """
64- return None
55+ if self .cursor and not self .cursor .connection .closed :
56+ self .cursor .close ()
57+ if self .db and not self .db .closed :
58+ self .db .close ()
6559
6660 def test_vanilla_query (self ):
6761 self .assertTrue (psycopg2 .extras .register_uuid (None , self .db ))
6862 self .assertTrue (psycopg2 .extras .register_uuid (None , self .db .cursor ()))
6963
7064 self .cursor .execute ("""SELECT * from users""" )
65+ affected_rows = self .cursor .rowcount
66+ self .assertEqual (1 , affected_rows )
7167 result = self .cursor .fetchone ()
7268
7369 self .assertEqual (6 , len (result ))
@@ -78,9 +74,13 @@ def test_vanilla_query(self):
7874 def test_basic_query (self ):
7975 with tracer .start_active_span ('test' ):
8076 self .cursor .execute ("""SELECT * from users""" )
81- self .cursor .fetchone ()
77+ affected_rows = self .cursor .rowcount
78+ result = self .cursor .fetchone ()
8279 self .db .commit ()
8380
81+ self .assertEqual (1 , affected_rows )
82+ self .assertEqual (6 , len (result ))
83+
8484 spans = self .recorder .queued_spans ()
8585 self .assertEqual (2 , len (spans ))
8686
@@ -102,6 +102,9 @@ def test_basic_query(self):
102102 def test_basic_insert (self ):
103103 with tracer .start_active_span ('test' ):
104104 self .
cursor .
execute (
"""INSERT INTO users(name, email) VALUES(%s, %s)""" , (
'beaker' ,
'[email protected] ' ))
105+ affected_rows = self .cursor .rowcount
106+
107+ self .assertEqual (1 , affected_rows )
105108
106109 spans = self .recorder .queued_spans ()
107110 self .assertEqual (2 , len (spans ))
@@ -123,10 +126,13 @@ def test_basic_insert(self):
123126
124127 def test_executemany (self ):
125128 with tracer .start_active_span ('test' ):
126- result = self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
127- 129+ self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
130+ 131+ affected_rows = self .cursor .rowcount
128132 self .db .commit ()
129133
134+ self .assertEqual (2 , affected_rows )
135+
130136 spans = self .recorder .queued_spans ()
131137 self .assertEqual (2 , len (spans ))
132138
@@ -147,9 +153,9 @@ def test_executemany(self):
147153
148154 def test_call_proc (self ):
149155 with tracer .start_active_span ('test' ):
150- result = self .cursor .callproc ('test_proc' , ('beaker' ,))
156+ callproc_result = self .cursor .callproc ('test_proc' , ('beaker' ,))
151157
152- self .assertIsInstance (result , tuple )
158+ self .assertIsInstance (callproc_result , tuple )
153159
154160 spans = self .recorder .queued_spans ()
155161 self .assertEqual (2 , len (spans ))
@@ -170,14 +176,16 @@ def test_call_proc(self):
170176 self .assertEqual (db_span .data ["pg" ]["port" ], testenv ['postgresql_port' ])
171177
172178 def test_error_capture (self ):
173- result = None
179+ affected_rows = result = None
174180 try :
175181 with tracer .start_active_span ('test' ):
176- result = self .cursor .execute ("""SELECT * from blah""" )
182+ self .cursor .execute ("""SELECT * from blah""" )
183+ affected_rows = self .cursor .rowcount
177184 self .cursor .fetchone ()
178185 except Exception :
179186 pass
180187
188+ self .assertIsNone (affected_rows )
181189 self .assertIsNone (result )
182190
183191 spans = self .recorder .queued_spans ()
@@ -246,6 +254,11 @@ def test_connect_cursor_ctx_mgr(self):
246254 with self .db as connection :
247255 with connection .cursor () as cursor :
248256 cursor .execute ("""SELECT * from users""" )
257+ affected_rows = cursor .rowcount
258+ result = cursor .fetchone ()
259+
260+ self .assertEqual (1 , affected_rows )
261+ self .assertEqual (6 , len (result ))
249262
250263 spans = self .recorder .queued_spans ()
251264 self .assertEqual (2 , len (spans ))
@@ -270,6 +283,11 @@ def test_connect_ctx_mgr(self):
270283 with self .db as connection :
271284 cursor = connection .cursor ()
272285 cursor .execute ("""SELECT * from users""" )
286+ affected_rows = cursor .rowcount
287+ result = cursor .fetchone ()
288+
289+ self .assertEqual (1 , affected_rows )
290+ self .assertEqual (6 , len (result ))
273291
274292 spans = self .recorder .queued_spans ()
275293 self .assertEqual (2 , len (spans ))
@@ -294,6 +312,11 @@ def test_cursor_ctx_mgr(self):
294312 connection = self .db
295313 with connection .cursor () as cursor :
296314 cursor .execute ("""SELECT * from users""" )
315+ affected_rows = cursor .rowcount
316+ result = cursor .fetchone ()
317+
318+ self .assertEqual (1 , affected_rows )
319+ self .assertEqual (6 , len (result ))
297320
298321 spans = self .recorder .queued_spans ()
299322 self .assertEqual (2 , len (spans ))
0 commit comments