1212
1313logger = logging .getLogger (__name__ )
1414
15- create_table_query = 'CREATE TABLE IF NOT EXISTS users(id serial primary key, \
16- name varchar(40) NOT NULL, email varchar(40) NOT NULL)'
17-
18- create_proc_query = """
19- CREATE PROCEDURE test_proc(IN t VARCHAR(255))
20- BEGIN
21- SELECT name FROM users WHERE name = t;
22- END
23- """
24-
25- db = pymysql .connect (host = testenv ['mysql_host' ], port = testenv ['mysql_port' ],
26- user = testenv ['mysql_user' ], passwd = testenv ['mysql_pw' ],
27- db = testenv ['mysql_db' ])
28-
29- cursor = db .cursor ()
30- cursor .execute (create_table_query )
31-
32- while cursor .nextset () is not None :
33- pass
34-
35- cursor .execute ('DROP PROCEDURE IF EXISTS test_proc' )
36-
37- while cursor .nextset () is not None :
38- pass
39-
40- cursor .execute (create_proc_query )
41-
42- while cursor .nextset () is not None :
43- pass
44-
45- cursor .close ()
46- db .close ()
47-
4815
4916class TestPyMySQL (unittest .TestCase ):
5017 def setUp (self ):
5118 self .db = pymysql .connect (host = testenv ['mysql_host' ], port = testenv ['mysql_port' ],
5219 user = testenv ['mysql_user' ], passwd = testenv ['mysql_pw' ],
5320 db = testenv ['mysql_db' ])
21+ database_setup_query = """
22+ DROP TABLE IF EXISTS users; |
23+ CREATE TABLE users(
24+ id serial primary key,
25+ name varchar(40) NOT NULL,
26+ email varchar(40) NOT NULL
27+ ); |
28+ INSERT INTO users(name, email) VALUES('kermit', '[email protected] '); | 29+ DROP PROCEDURE IF EXISTS test_proc; |
30+ CREATE PROCEDURE test_proc(IN t VARCHAR(255))
31+ BEGIN
32+ SELECT name FROM users WHERE name = t;
33+ END
34+ """
35+ setup_cursor = self .db .cursor ()
36+ for s in database_setup_query .split ('|' ):
37+ setup_cursor .execute (s )
38+
5439 self .cursor = self .db .cursor ()
5540 self .recorder = tracer .recorder
5641 self .recorder .clear_spans ()
5742 tracer .cur_ctx = None
5843
5944 def tearDown (self ):
60- """ Do nothing for now """
61- return None
45+ if self .cursor and self .cursor .connection .open :
46+ self .cursor .close ()
47+ if self .db and self .db .open :
48+ self .db .close ()
6249
6350 def test_vanilla_query (self ):
64- self .cursor .execute ("""SELECT * from users""" )
51+ affected_rows = self .cursor .execute ("""SELECT * from users""" )
52+ self .assertEqual (1 , affected_rows )
6553 result = self .cursor .fetchone ()
6654 self .assertEqual (3 , len (result ))
6755
@@ -70,10 +58,11 @@ def test_vanilla_query(self):
7058
7159 def test_basic_query (self ):
7260 with tracer .start_active_span ('test' ):
73- result = self .cursor .execute ("""SELECT * from users""" )
74- self .cursor .fetchone ()
61+ affected_rows = self .cursor .execute ("""SELECT * from users""" )
62+ result = self .cursor .fetchone ()
7563
76- self .assertTrue (result >= 0 )
64+ self .assertEqual (1 , affected_rows )
65+ self .assertEqual (3 , len (result ))
7766
7867 spans = self .recorder .queued_spans ()
7968 self .assertEqual (2 , len (spans ))
@@ -95,10 +84,11 @@ def test_basic_query(self):
9584
9685 def test_query_with_params (self ):
9786 with tracer .start_active_span ('test' ):
98- result = self .cursor .execute ("""SELECT * from users where id=1""" )
99- self .cursor .fetchone ()
87+ affected_rows = self .cursor .execute ("""SELECT * from users where id=1""" )
88+ result = self .cursor .fetchone ()
10089
101- self .assertTrue (result >= 0 )
90+ self .assertEqual (1 , affected_rows )
91+ self .assertEqual (3 , len (result ))
10292
10393 spans = self .recorder .queued_spans ()
10494 self .assertEqual (2 , len (spans ))
@@ -120,11 +110,11 @@ def test_query_with_params(self):
120110
121111 def test_basic_insert (self ):
122112 with tracer .start_active_span ('test' ):
123- result = self .cursor .execute (
113+ affected_rows = self .cursor .execute (
124114 """INSERT INTO users(name, email) VALUES(%s, %s)""" ,
125115126116
127- self .assertEqual (1 , result )
117+ self .assertEqual (1 , affected_rows )
128118
129119 spans = self .recorder .queued_spans ()
130120 self .assertEqual (2 , len (spans ))
@@ -146,11 +136,11 @@ def test_basic_insert(self):
146136
147137 def test_executemany (self ):
148138 with tracer .start_active_span ('test' ):
149- result = self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
139+ affected_rows = self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
150140151141 self .db .commit ()
152142
153- self .assertEqual (2 , result )
143+ self .assertEqual (2 , affected_rows )
154144
155145 spans = self .recorder .queued_spans ()
156146 self .assertEqual (2 , len (spans ))
@@ -172,9 +162,9 @@ def test_executemany(self):
172162
173163 def test_call_proc (self ):
174164 with tracer .start_active_span ('test' ):
175- result = self .cursor .callproc ('test_proc' , ('beaker' ,))
165+ callproc_result = self .cursor .callproc ('test_proc' , ('beaker' ,))
176166
177- self .assertTrue ( result )
167+ self .assertIsInstance ( callproc_result , tuple )
178168
179169 spans = self .recorder .queued_spans ()
180170 self .assertEqual (2 , len (spans ))
@@ -195,15 +185,14 @@ def test_call_proc(self):
195185 self .assertEqual (db_span .data ["mysql" ]["port" ], testenv ['mysql_port' ])
196186
197187 def test_error_capture (self ):
198- result = None
188+ affected_rows = None
199189 try :
200190 with tracer .start_active_span ('test' ):
201- result = self .cursor .execute ("""SELECT * from blah""" )
202- self .cursor .fetchone ()
191+ affected_rows = self .cursor .execute ("""SELECT * from blah""" )
203192 except Exception :
204193 pass
205194
206- self .assertIsNone (result )
195+ self .assertIsNone (affected_rows )
207196
208197 spans = self .recorder .queued_spans ()
209198 self .assertEqual (2 , len (spans ))
@@ -228,8 +217,9 @@ def test_connect_cursor_ctx_mgr(self):
228217 with tracer .start_active_span ("test" ):
229218 with self .db as connection :
230219 with connection .cursor () as cursor :
231- cursor .execute ("""SELECT * from users""" )
220+ affected_rows = cursor .execute ("""SELECT * from users""" )
232221
222+ self .assertEqual (1 , affected_rows )
233223 spans = self .recorder .queued_spans ()
234224 self .assertEqual (2 , len (spans ))
235225
0 commit comments