Skip to content

Commit 6532b0d

Browse files
committed
Normalised type of _auth_plugin_used. Added SHA265 tests
1 parent b9b7b9a commit 6532b0d

File tree

4 files changed

+63
-4
lines changed

4 files changed

+63
-4
lines changed

aiomysql/connection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,10 +798,10 @@ async def _process_auth(self, plugin_name, auth_packet):
798798
# These auth plugins do their own packet handling
799799
if plugin_name == b"caching_sha2_password":
800800
await self.caching_sha2_password_auth(auth_packet)
801-
self._auth_plugin_used = plugin_name
801+
self._auth_plugin_used = plugin_name.decode()
802802
elif plugin_name == b"sha256_password":
803803
await self.sha256_password_auth(auth_packet)
804-
self._auth_plugin_used = plugin_name
804+
self._auth_plugin_used = plugin_name.decode()
805805
else:
806806

807807
if plugin_name == b"mysql_native_password":
@@ -832,7 +832,7 @@ async def _process_auth(self, plugin_name, auth_packet):
832832
pkt = await self._read_packet()
833833
pkt.check_error()
834834

835-
self._auth_plugin_used = plugin_name
835+
self._auth_plugin_used = plugin_name.decode()
836836

837837
return pkt
838838

tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,14 @@ def docker():
218218
return APIClient(version='auto')
219219

220220

221+
@pytest.fixture(autouse=True)
222+
def ensure_mysql_verison(request, mysql_tag):
223+
if request.node.get_marker('mysql_verison'):
224+
if request.node.get_marker('mysql_verison').args[0] != mysql_tag:
225+
pytest.skip('Not applicable for '
226+
'MySQL version: {0}'.format(mysql_tag))
227+
228+
221229
@pytest.fixture(scope='session')
222230
def mysql_server(unused_port, docker, session_id, mysql_tag, request):
223231
if not request.config.option.no_pull:
@@ -295,6 +303,32 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request):
295303
assert result['Value'].startswith('TLS'), \
296304
"Not connected to the database with TLS"
297305

306+
# Create Databases
307+
cursor.execute('CREATE DATABASE test_pymysql '
308+
'DEFAULT CHARACTER SET utf8 '
309+
'DEFAULT COLLATE utf8_general_ci;')
310+
cursor.execute('CREATE DATABASE test_pymysql2 '
311+
'DEFAULT CHARACTER SET utf8 '
312+
'DEFAULT COLLATE utf8_general_ci;')
313+
314+
# Do MySQL8+ Specific Setup
315+
if mysql_tag in ('8.0',):
316+
# Create Users to test SHA256
317+
cursor.execute('CREATE USER user_sha256 '
318+
'IDENTIFIED WITH "sha256_password" '
319+
'BY "pass_sha256"')
320+
cursor.execute('CREATE USER nopass_sha256 '
321+
'IDENTIFIED WITH "sha256_password"')
322+
cursor.execute('CREATE USER user_caching_sha2 '
323+
'IDENTIFIED '
324+
'WITH "caching_sha2_password" '
325+
'BY "pass_caching_sha2"')
326+
cursor.execute('CREATE USER nopass_caching_sha2 '
327+
'IDENTIFIED '
328+
'WITH "caching_sha2_password" '
329+
'PASSWORD EXPIRE NEVER')
330+
cursor.execute('FLUSH PRIVILEGES')
331+
298332
break
299333
except Exception as err:
300334
time.sleep(delay)

tests/test_sha_connection.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import copy
2+
from aiomysql import create_pool
3+
4+
import pytest
5+
6+
7+
@pytest.mark.mysql_verison('8.0')
8+
@pytest.mark.run_loop
9+
@pytest.mark.parametrize("user,password,plugin", [
10+
("nopass_sha256", None, 'sha256_password'),
11+
("user_sha256", 'pass_sha256', 'sha256_password'),
12+
("nopass_caching_sha2", None, 'caching_sha2_password'),
13+
("user_caching_sha2", 'pass_caching_sha2', 'caching_sha2_password'),
14+
])
15+
async def test_sha(mysql_server, loop, user, password, plugin):
16+
connection_data = copy.copy(mysql_server['conn_params'])
17+
connection_data['user'] = user
18+
connection_data['password'] = password
19+
20+
async with create_pool(**connection_data,
21+
loop=loop) as pool:
22+
async with pool.get() as conn:
23+
# User doesnt have any permissions to look at DBs
24+
# But as 8.0 will default to caching_sha2_password
25+
assert conn._auth_plugin_used == plugin

tests/test_ssl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,5 @@ async def test_auth_plugin_renegotiation(mysql_server, loop):
5454
'Server did not ask for native auth'
5555
# Check we actually used the servers default plugin
5656
assert conn._auth_plugin_used in (
57-
b'mysql_native_password', b'caching_sha2_password'), \
57+
'mysql_native_password', 'caching_sha2_password'), \
5858
'Client did not renegotiate with server\'s default auth'

0 commit comments

Comments
 (0)