diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 83465c67..b8db8d57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,6 +49,8 @@ jobs: image: "${{ join(matrix.db, ':') }}" ports: - 3306:3306 + volumes: + - "/tmp/run-${{ join(matrix.db, '-') }}/:/socket-mount/" options: '--name=mysqld' env: MYSQL_ROOT_PASSWORD: rootpw @@ -104,6 +106,19 @@ jobs: docker container stop mysqld docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl" mysqld:/etc/mysql/ssl docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf + + # use custom socket path + # we need to ensure that the socket path is writable for the user running the DB process in the container + sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }} + + # mysql 5.7 container overrides the socket path in /etc/mysql/mysql.conf.d/mysqld.cnf + if [ "${{ join(matrix.db, '-') }}" = "mysql-5.7" ] + then + docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/mysql.conf.d/zz-aiomysql-socket.cnf + else + docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf + fi + docker container start mysqld # ensure server is started up @@ -119,7 +134,7 @@ jobs: run: | # timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs timeout --preserve-status --signal=INT --verbose 5m \ - pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306" + pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306" env: PYTHONUNBUFFERED: 1 DB: '${{ matrix.db[0] }}' diff --git a/CHANGES.txt b/CHANGES.txt index b537f279..9262950a 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -11,6 +11,7 @@ To be included in 1.0.0 (unreleased) * Ensure connections are properly closed before raising an OperationalError when the server connection is lost #660 * Ensure connections are properly closed before raising an InternalError when packet sequence numbers are out of sync #660 * Unix sockets are now internally considered secure, allowing sha256_password and caching_sha2_password auth methods to be used #695 +* Test suite now also tests unix socket connections #686 0.0.22 (2021-11-14) diff --git a/tests/conftest.py b/tests/conftest.py index a42172c1..d6b0a923 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,17 @@ def pytest_generate_tests(metafunc): mysql_addresses = [] ids = [] + opt_mysql_unix_socket = \ + list(metafunc.config.getoption("mysql_unix_socket")) + for i in range(len(opt_mysql_unix_socket)): + if "=" in opt_mysql_unix_socket[i]: + label, path = opt_mysql_unix_socket[i].split("=", 1) + mysql_addresses.append(path) + ids.append(label) + else: + mysql_addresses.append(opt_mysql_unix_socket[i]) + ids.append("unix{}".format(i)) + opt_mysql_address = list(metafunc.config.getoption("mysql_address")) for i in range(len(opt_mysql_address)): if "=" in opt_mysql_address[i]: @@ -143,6 +154,12 @@ def pytest_addoption(parser): default=[], help="list of addresses to connect to: [name=]host[:port]", ) + parser.addoption( + "--mysql-unix-socket", + action="append", + default=[], + help="list of unix sockets to connect to: [name=]/path/to/socket", + ) @pytest.fixture @@ -250,23 +267,30 @@ def ensure_mysql_version(request, mysql_image, mysql_tag): @pytest.fixture(scope='session') def mysql_server(mysql_image, mysql_tag, mysql_address): - ssl_directory = os.path.join(os.path.dirname(__file__), - 'ssl_resources', 'ssl') - ca_file = os.path.join(ssl_directory, 'ca.pem') + unix_socket = type(mysql_address) is str + + if not unix_socket: + ssl_directory = os.path.join(os.path.dirname(__file__), + 'ssl_resources', 'ssl') + ca_file = os.path.join(ssl_directory, 'ca.pem') - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - ctx.check_hostname = False - ctx.load_verify_locations(cafile=ca_file) - # ctx.verify_mode = ssl.CERT_NONE + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + ctx.check_hostname = False + ctx.load_verify_locations(cafile=ca_file) + # ctx.verify_mode = ssl.CERT_NONE server_params = { - 'host': mysql_address[0], - 'port': mysql_address[1], 'user': 'root', 'password': os.environ.get("MYSQL_ROOT_PASSWORD"), - 'ssl': ctx, } + if unix_socket: + server_params["unix_socket"] = mysql_address + else: + server_params["host"] = mysql_address[0] + server_params["port"] = mysql_address[1] + server_params["ssl"] = ctx + try: connection = pymysql.connect( db='mysql', @@ -275,21 +299,22 @@ def mysql_server(mysql_image, mysql_tag, mysql_address): **server_params) with connection.cursor() as cursor: - cursor.execute("SHOW VARIABLES LIKE '%ssl%';") + if not unix_socket: + cursor.execute("SHOW VARIABLES LIKE '%ssl%';") - result = cursor.fetchall() - result = {item['Variable_name']: - item['Value'] for item in result} + result = cursor.fetchall() + result = {item['Variable_name']: + item['Value'] for item in result} - assert result['have_ssl'] == "YES", \ - "SSL Not Enabled on MySQL" + assert result['have_ssl'] == "YES", \ + "SSL Not Enabled on MySQL" - cursor.execute("SHOW STATUS LIKE 'Ssl_version%'") + cursor.execute("SHOW STATUS LIKE 'Ssl_version%'") - result = cursor.fetchone() - # As we connected with TLS, it should start with that :D - assert result['Value'].startswith('TLS'), \ - "Not connected to the database with TLS" + result = cursor.fetchone() + # As we connected with TLS, it should start with that :D + assert result['Value'].startswith('TLS'), \ + "Not connected to the database with TLS" # Drop possibly existing old databases cursor.execute('DROP DATABASE IF EXISTS test_pymysql;') diff --git a/tests/fixtures/my.cnf.tmpl b/tests/fixtures/my.cnf.tcp.tmpl similarity index 100% rename from tests/fixtures/my.cnf.tmpl rename to tests/fixtures/my.cnf.tcp.tmpl diff --git a/tests/fixtures/my.cnf.unix.tmpl b/tests/fixtures/my.cnf.unix.tmpl new file mode 100644 index 00000000..2aad4432 --- /dev/null +++ b/tests/fixtures/my.cnf.unix.tmpl @@ -0,0 +1,16 @@ +# +# The MySQL database server configuration file. +# +[client] +user = {user} +socket = {unix_socket} +password = {password} +database = {db} +default-character-set = utf8 + +[client_with_unix_socket] +user = {user} +socket = {unix_socket} +password = {password} +database = {db} +default-character-set = utf8 diff --git a/tests/sa/test_sa_compiled_cache.py b/tests/sa/test_sa_compiled_cache.py index e8c0f5f2..38906551 100644 --- a/tests/sa/test_sa_compiled_cache.py +++ b/tests/sa/test_sa_compiled_cache.py @@ -15,12 +15,19 @@ @pytest.fixture() def make_engine(mysql_params, connection): async def _make_engine(**kwargs): + if "unix_socket" in mysql_params: + conn_args = {"unix_socket": mysql_params["unix_socket"]} + else: + conn_args = { + "host": mysql_params['host'], + "port": mysql_params['port'], + } + return (await sa.create_engine(db=mysql_params['db'], user=mysql_params['user'], password=mysql_params['password'], - host=mysql_params['host'], - port=mysql_params['port'], minsize=10, + **conn_args, **kwargs)) return _make_engine diff --git a/tests/sa/test_sa_default.py b/tests/sa/test_sa_default.py index 42c34f5b..e5f270ec 100644 --- a/tests/sa/test_sa_default.py +++ b/tests/sa/test_sa_default.py @@ -22,12 +22,19 @@ @pytest.fixture() def make_engine(mysql_params, connection): async def _make_engine(**kwargs): + if "unix_socket" in mysql_params: + conn_args = {"unix_socket": mysql_params["unix_socket"]} + else: + conn_args = { + "host": mysql_params['host'], + "port": mysql_params['port'], + } + return (await sa.create_engine(db=mysql_params['db'], user=mysql_params['user'], password=mysql_params['password'], - host=mysql_params['host'], - port=mysql_params['port'], minsize=10, + **conn_args, **kwargs)) return _make_engine diff --git a/tests/sa/test_sa_engine.py b/tests/sa/test_sa_engine.py index e514260d..ed74a96d 100644 --- a/tests/sa/test_sa_engine.py +++ b/tests/sa/test_sa_engine.py @@ -15,12 +15,19 @@ @pytest.fixture() def make_engine(connection, mysql_params): async def _make_engine(**kwargs): + if "unix_socket" in mysql_params: + conn_args = {"unix_socket": mysql_params["unix_socket"]} + else: + conn_args = { + "host": mysql_params['host'], + "port": mysql_params['port'], + } + return (await sa.create_engine(db=mysql_params['db'], user=mysql_params['user'], password=mysql_params['password'], - host=mysql_params['host'], - port=mysql_params['port'], minsize=10, + **conn_args, **kwargs)) return _make_engine diff --git a/tests/ssl_resources/socket.cnf b/tests/ssl_resources/socket.cnf new file mode 100644 index 00000000..32100e93 --- /dev/null +++ b/tests/ssl_resources/socket.cnf @@ -0,0 +1,2 @@ +[mysqld] +socket = /socket-mount/mysql.sock diff --git a/tests/test_connection.py b/tests/test_connection.py index 075039d0..af6788f3 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -10,7 +10,13 @@ @pytest.fixture() def fill_my_cnf(mysql_params): tests_root = os.path.abspath(os.path.dirname(__file__)) - path1 = os.path.join(tests_root, 'fixtures/my.cnf.tmpl') + + if "unix_socket" in mysql_params: + tmpl_path = "fixtures/my.cnf.unix.tmpl" + else: + tmpl_path = "fixtures/my.cnf.tcp.tmpl" + + path1 = os.path.join(tests_root, tmpl_path) path2 = os.path.join(tests_root, 'fixtures/my.cnf') with open(path1) as f1: tmpl = f1.read() @@ -31,8 +37,11 @@ async def test_config_file(fill_my_cnf, connection_creator, mysql_params): path = os.path.join(tests_root, 'fixtures/my.cnf') conn = await connection_creator(read_default_file=path) - assert conn.host == mysql_params['host'] - assert conn.port == mysql_params['port'] + if "unix_socket" in mysql_params: + assert conn.unix_socket == mysql_params["unix_socket"] + else: + assert conn.host == mysql_params['host'] + assert conn.port == mysql_params['port'] assert conn.user, mysql_params['user'] # make sure connection is working @@ -167,12 +176,15 @@ async def test_connection_gone_away(connection_creator): @pytest.mark.run_loop -async def test_connection_info_methods(connection_creator): +async def test_connection_info_methods(connection_creator, mysql_params): conn = await connection_creator() # trhead id is int assert isinstance(conn.thread_id(), int) assert conn.character_set_name() in ('latin1', 'utf8mb4') - assert str(conn.port) in conn.get_host_info() + if "unix_socket" in mysql_params: + assert mysql_params["unix_socket"] in conn.get_host_info() + else: + assert str(conn.port) in conn.get_host_info() assert isinstance(conn.get_server_info(), str) # protocol id is int assert isinstance(conn.get_proto_info(), int) @@ -200,8 +212,11 @@ async def test_connection_ping(connection_creator): @pytest.mark.run_loop async def test_connection_properties(connection_creator, mysql_params): conn = await connection_creator() - assert conn.host == mysql_params['host'] - assert conn.port == mysql_params['port'] + if "unix_socket" in mysql_params: + assert conn.unix_socket == mysql_params["unix_socket"] + else: + assert conn.host == mysql_params['host'] + assert conn.port == mysql_params['port'] assert conn.user == mysql_params['user'] assert conn.db == mysql_params['db'] assert conn.echo is False diff --git a/tests/test_issues.py b/tests/test_issues.py index 942bc8ed..c25e292f 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -184,7 +184,7 @@ async def test_issue_17(connection, connection_creator, mysql_params): async def test_issue_34(connection_creator): try: await connection_creator(host="localhost", port=1237, - user="root") + user="root", unix_socket=None) pytest.fail() except aiomysql.OperationalError as e: assert 2003 == e.args[0] diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py index eb57ec3d..0789d162 100644 --- a/tests/test_sha_connection.py +++ b/tests/test_sha_connection.py @@ -39,6 +39,13 @@ async def test_sha256_nopw(mysql_server, loop): @pytest.mark.mysql_version('mysql', '8.0') @pytest.mark.run_loop async def test_sha256_pw(mysql_server, loop): + # https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html + # Unlike caching_sha2_password, the sha256_password plugin does not treat + # shared-memory connections as secure, even though share-memory transport + # is secure by default. + if "unix_socket" in mysql_server['conn_params']: + pytest.skip("sha256_password is not supported on unix sockets") + connection_data = copy.copy(mysql_server['conn_params']) connection_data['user'] = 'user_sha256' connection_data['password'] = 'pass_sha256' diff --git a/tests/test_ssl.py b/tests/test_ssl.py index ff1ea740..140c164f 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -4,7 +4,10 @@ @pytest.mark.run_loop -async def test_tls_connect(mysql_server, loop): +async def test_tls_connect(mysql_server, loop, mysql_params): + if "unix_socket" in mysql_params: + pytest.skip("TLS is not supported on unix sockets") + async with create_pool(**mysql_server['conn_params'], loop=loop) as pool: async with pool.get() as conn: @@ -32,7 +35,10 @@ async def test_tls_connect(mysql_server, loop): # MySQL will get you to renegotiate if sent a cleartext password @pytest.mark.run_loop -async def test_auth_plugin_renegotiation(mysql_server, loop): +async def test_auth_plugin_renegotiation(mysql_server, loop, mysql_params): + if "unix_socket" in mysql_params: + pytest.skip("TLS is not supported on unix sockets") + async with create_pool(**mysql_server['conn_params'], auth_plugin='mysql_clear_password', loop=loop) as pool: