Skip to content

Commit 819d4ec

Browse files
committed
Rewrote tests to use pytest exclusively
1 parent 99c6198 commit 819d4ec

15 files changed

+2520
-2640
lines changed

tests/conftest.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,6 @@ def pytest_generate_tests(metafunc):
3434
if 'loop_type' in metafunc.fixturenames:
3535
loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio']
3636
metafunc.parametrize("loop_type", loop_type)
37-
#
38-
# if 'mysql_tag' in metafunc.fixturenames:
39-
# tags = set(metafunc.config.option.mysql_tag)
40-
# if not tags:
41-
# tags = ['5.6', '8.0']
42-
# elif 'all' in tags:
43-
# tags = ['5.6', '5.7', '8.0']
44-
# else:
45-
# tags = list(tags)
46-
# metafunc.parametrize("mysql_tag", tags, scope='session')
4737

4838

4939
# This is here unless someone fixes the generate_tests bit
@@ -172,7 +162,10 @@ def f(**kw):
172162
yield f
173163

174164
for conn in connections:
175-
loop.run_until_complete(conn.ensure_closed())
165+
try:
166+
loop.run_until_complete(conn.ensure_closed())
167+
except ConnectionResetError:
168+
pass
176169

177170

178171
@pytest.yield_fixture
@@ -248,13 +241,13 @@ def mysql_server(unused_port, docker, session_id,
248241
tls_cnf = os.path.join(os.path.dirname(__file__),
249242
'ssl_resources', 'tls.cnf')
250243

251-
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
244+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
252245
ctx.check_hostname = False
253246
ctx.load_verify_locations(cafile=ca_file)
254247
# ctx.verify_mode = ssl.CERT_NONE
255248

256249
container_args = dict(
257-
image='mysql:{}'.format(mysql_tag),
250+
image='{}:{}'.format(mysql_image, mysql_tag),
258251
name='aiomysql-test-server-{}-{}'.format(mysql_tag, session_id),
259252
ports=[3306],
260253
detach=True,

tests/fixtures/my.cnf.tmpl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ port = {port}
77
host = {host}
88
password = {password}
99
database = {db}
10-
socket = /var/run/mysqld/mysqld.sock
1110
default-character-set = utf8
1211

1312
[client_with_unix_socket]
1413
user = {user}
14+
port = {port}
15+
host = {host}
1516
password = {password}
1617
database = {db}
17-
socket = /var/run/mysqld/mysqld.sock
1818
default-character-set = utf8

tests/sa/test_sa_compiled_cache.py

Lines changed: 113 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import asyncio
2-
from aiomysql import sa
1+
import pytest
32
from sqlalchemy import bindparam
3+
from sqlalchemy import MetaData, Table, Column, Integer, String
44

5-
import os
6-
import unittest
5+
from aiomysql import sa
76

8-
from sqlalchemy import MetaData, Table, Column, Integer, String
97

108
meta = MetaData()
119
tbl = Table('sa_tbl_cache_test', meta,
@@ -14,125 +12,113 @@
1412
Column('val', String(255)))
1513

1614

17-
class TestCompiledCache(unittest.TestCase):
18-
def setUp(self):
19-
self.loop = asyncio.new_event_loop()
20-
asyncio.set_event_loop(None)
21-
self.host = os.environ.get('MYSQL_HOST', 'localhost')
22-
self.port = int(os.environ.get('MYSQL_PORT', 3306))
23-
self.user = os.environ.get('MYSQL_USER', 'root')
24-
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
25-
self.password = os.environ.get('MYSQL_PASSWORD', '')
26-
self.engine = self.loop.run_until_complete(self.make_engine())
27-
self.loop.run_until_complete(self.start())
28-
29-
def tearDown(self):
30-
self.engine.terminate()
31-
self.loop.run_until_complete(self.engine.wait_closed())
32-
self.loop.close()
33-
34-
async def make_engine(self, **kwargs):
35-
return (await sa.create_engine(db=self.db,
36-
user=self.user,
37-
password=self.password,
38-
host=self.host,
39-
port=self.port,
40-
loop=self.loop,
41-
minsize=10,
42-
**kwargs))
43-
44-
async def start(self):
45-
async with self.engine.acquire() as conn:
46-
tx = await conn.begin()
47-
await conn.execute("DROP TABLE IF EXISTS "
48-
"sa_tbl_cache_test")
49-
await conn.execute("CREATE TABLE sa_tbl_cache_test"
50-
"(id serial, val varchar(255))")
51-
await conn.execute(tbl.insert().values(val='some_val_1'))
52-
await conn.execute(tbl.insert().values(val='some_val_2'))
53-
await conn.execute(tbl.insert().values(val='some_val_3'))
54-
await tx.commit()
55-
56-
def test_cache(self):
57-
async def go():
58-
cache = dict()
59-
engine = await self.make_engine(compiled_cache=cache)
60-
async with engine.acquire() as conn:
61-
# check select with params not added to cache
62-
q = tbl.select().where(tbl.c.val == 'some_val_1')
63-
cursor = await conn.execute(q)
64-
row = await cursor.fetchone()
65-
self.assertEqual('some_val_1', row.val)
66-
self.assertEqual(0, len(cache))
67-
68-
# check select with bound params added to cache
69-
select_by_val = tbl.select().where(
70-
tbl.c.val == bindparam('value')
71-
)
72-
cursor = await conn.execute(
73-
select_by_val, {'value': 'some_val_3'}
74-
)
75-
row = await cursor.fetchone()
76-
self.assertEqual('some_val_3', row.val)
77-
self.assertEqual(1, len(cache))
78-
79-
cursor = await conn.execute(
80-
select_by_val, value='some_val_2'
81-
)
82-
row = await cursor.fetchone()
83-
self.assertEqual('some_val_2', row.val)
84-
self.assertEqual(1, len(cache))
85-
86-
select_all = tbl.select()
87-
cursor = await conn.execute(select_all)
88-
rows = await cursor.fetchall()
89-
self.assertEqual(3, len(rows))
90-
self.assertEqual(2, len(cache))
91-
92-
# check insert with bound params not added to cache
93-
await conn.execute(tbl.insert().values(val='some_val_4'))
94-
self.assertEqual(2, len(cache))
95-
96-
# check insert with bound params added to cache
97-
q = tbl.insert().values(val=bindparam('value'))
98-
await conn.execute(q, value='some_val_5')
99-
self.assertEqual(3, len(cache))
100-
101-
await conn.execute(q, value='some_val_6')
102-
self.assertEqual(3, len(cache))
103-
104-
await conn.execute(q, {'value': 'some_val_7'})
105-
self.assertEqual(3, len(cache))
106-
107-
cursor = await conn.execute(select_all)
108-
rows = await cursor.fetchall()
109-
self.assertEqual(7, len(rows))
110-
self.assertEqual(3, len(cache))
111-
112-
# check update with params not added to cache
113-
q = tbl.update().where(
114-
tbl.c.val == 'some_val_1'
115-
).values(val='updated_val_1')
116-
await conn.execute(q)
117-
self.assertEqual(3, len(cache))
118-
cursor = await conn.execute(
119-
select_by_val, value='updated_val_1'
120-
)
121-
row = await cursor.fetchone()
122-
self.assertEqual('updated_val_1', row.val)
123-
124-
# check update with bound params added to cache
125-
q = tbl.update().where(
126-
tbl.c.val == bindparam('value')
127-
).values(val=bindparam('update'))
128-
await conn.execute(
129-
q, value='some_val_2', update='updated_val_2'
130-
)
131-
self.assertEqual(4, len(cache))
132-
cursor = await conn.execute(
133-
select_by_val, value='updated_val_2'
134-
)
135-
row = await cursor.fetchone()
136-
self.assertEqual('updated_val_2', row.val)
137-
138-
self.loop.run_until_complete(go())
15+
@pytest.fixture()
16+
def make_engine(mysql_params, connection):
17+
async def _make_engine(**kwargs):
18+
return (await sa.create_engine(db=mysql_params['db'],
19+
user=mysql_params['user'],
20+
password=mysql_params['password'],
21+
host=mysql_params['host'],
22+
port=mysql_params['port'],
23+
minsize=10,
24+
**kwargs))
25+
26+
return _make_engine
27+
28+
29+
async def start(engine):
30+
async with engine.acquire() as conn:
31+
tx = await conn.begin()
32+
await conn.execute("DROP TABLE IF EXISTS "
33+
"sa_tbl_cache_test")
34+
await conn.execute("CREATE TABLE sa_tbl_cache_test"
35+
"(id serial, val varchar(255))")
36+
await conn.execute(tbl.insert().values(val='some_val_1'))
37+
await conn.execute(tbl.insert().values(val='some_val_2'))
38+
await conn.execute(tbl.insert().values(val='some_val_3'))
39+
await tx.commit()
40+
41+
42+
@pytest.mark.run_loop
43+
async def test_dialect(make_engine):
44+
cache = dict()
45+
engine = await make_engine(compiled_cache=cache)
46+
await start(engine)
47+
48+
async with engine.acquire() as conn:
49+
# check select with params not added to cache
50+
q = tbl.select().where(tbl.c.val == 'some_val_1')
51+
cursor = await conn.execute(q)
52+
row = await cursor.fetchone()
53+
assert 'some_val_1' == row.val
54+
assert 0 == len(cache)
55+
56+
# check select with bound params added to cache
57+
select_by_val = tbl.select().where(
58+
tbl.c.val == bindparam('value')
59+
)
60+
cursor = await conn.execute(
61+
select_by_val, {'value': 'some_val_3'}
62+
)
63+
row = await cursor.fetchone()
64+
assert 'some_val_3' == row.val
65+
assert 1 == len(cache)
66+
67+
cursor = await conn.execute(
68+
select_by_val, value='some_val_2'
69+
)
70+
row = await cursor.fetchone()
71+
assert 'some_val_2' == row.val
72+
assert 1 == len(cache)
73+
74+
select_all = tbl.select()
75+
cursor = await conn.execute(select_all)
76+
rows = await cursor.fetchall()
77+
assert 3 == len(rows)
78+
assert 2 == len(cache)
79+
80+
# check insert with bound params not added to cache
81+
await conn.execute(tbl.insert().values(val='some_val_4'))
82+
assert 2 == len(cache)
83+
84+
# check insert with bound params added to cache
85+
q = tbl.insert().values(val=bindparam('value'))
86+
await conn.execute(q, value='some_val_5')
87+
assert 3 == len(cache)
88+
89+
await conn.execute(q, value='some_val_6')
90+
assert 3 == len(cache)
91+
92+
await conn.execute(q, {'value': 'some_val_7'})
93+
assert 3 == len(cache)
94+
95+
cursor = await conn.execute(select_all)
96+
rows = await cursor.fetchall()
97+
assert 7 == len(rows)
98+
assert 3 == len(cache)
99+
100+
# check update with params not added to cache
101+
q = tbl.update().where(
102+
tbl.c.val == 'some_val_1'
103+
).values(val='updated_val_1')
104+
await conn.execute(q)
105+
assert 3 == len(cache)
106+
cursor = await conn.execute(
107+
select_by_val, value='updated_val_1'
108+
)
109+
row = await cursor.fetchone()
110+
assert 'updated_val_1' == row.val
111+
112+
# check update with bound params added to cache
113+
q = tbl.update().where(
114+
tbl.c.val == bindparam('value')
115+
).values(val=bindparam('update'))
116+
await conn.execute(
117+
q, value='some_val_2', update='updated_val_2'
118+
)
119+
assert 4 == len(cache)
120+
cursor = await conn.execute(
121+
select_by_val, value='updated_val_2'
122+
)
123+
row = await cursor.fetchone()
124+
assert 'updated_val_2' == row.val

0 commit comments

Comments
 (0)