Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Upcoming Release (TBD)
======================

Features
--------

* Support SSL query parameters on DSNs.

Internal
--------

Expand Down
61 changes: 45 additions & 16 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import itertools
from random import choice
from time import time
from urllib.parse import unquote, urlparse
from urllib.parse import parse_qs, unquote, urlparse

from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors
from cli_helpers.utils import strip_ansi
Expand Down Expand Up @@ -1119,7 +1119,7 @@ def get_last_query(self):
@click.option(
"--ssl-verify-server-cert",
is_flag=True,
help=('Verify server\'s "Common Name" in its cert against hostname used when connecting. This option is disabled by default.'),
help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""),
)
# as of 2016-02-15 revocation list is not supported by underling PyMySQL
# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client)
Expand Down Expand Up @@ -1240,20 +1240,6 @@ def cli(
# Choose which ever one has a valid value.
database = dbname or database

ssl = {
"enable": ssl_enable,
"ca": ssl_ca and os.path.expanduser(ssl_ca),
"cert": ssl_cert and os.path.expanduser(ssl_cert),
"key": ssl_key and os.path.expanduser(ssl_key),
"capath": ssl_capath,
"cipher": ssl_cipher,
"tls_version": tls_version,
"check_hostname": ssl_verify_server_cert,
}

# remove empty ssl options
ssl = {k: v for k, v in ssl.items() if v is not None}

dsn_uri = None

# Treat the database argument as a DSN alias only if it matches a configured alias
Expand Down Expand Up @@ -1294,6 +1280,49 @@ def cli(
if not port:
port = uri.port

if uri.query:
dsn_params = parse_qs(uri.query)
else:
dsn_params = {}

if dsn_params.get('ssl'):
ssl_enable = ssl_enable or (dsn_params.get('ssl')[0].lower() == 'true')
if dsn_params.get('ssl_ca'):
ssl_ca = ssl_ca or dsn_params.get('ssl_ca')[0]
ssl_enable = True
if dsn_params.get('ssl_capath'):
ssl_capath = ssl_capath or dsn_params.get('ssl_capath')[0]
ssl_enable = True
if dsn_params.get('ssl_cert'):
ssl_cert = ssl_cert or dsn_params.get('ssl_cert')[0]
ssl_enable = True
if dsn_params.get('ssl_key'):
ssl_key = ssl_key or dsn_params.get('ssl_key')[0]
ssl_enable = True
if dsn_params.get('ssl_cipher'):
ssl_cipher = ssl_cipher or dsn_params.get('ssl_cipher')[0]
ssl_enable = True
if dsn_params.get('tls_version'):
tls_version = tls_version or dsn_params.get('tls_version')[0]
ssl_enable = True
if dsn_params.get('ssl_verify_server_cert'):
ssl_verify_server_cert = ssl_verify_server_cert or (dsn_params.get('ssl_verify_server_cert')[0].lower() == 'true')
ssl_enable = True

ssl = {
"enable": ssl_enable,
"ca": ssl_ca and os.path.expanduser(ssl_ca),
"cert": ssl_cert and os.path.expanduser(ssl_cert),
"key": ssl_key and os.path.expanduser(ssl_key),
"capath": ssl_capath,
"cipher": ssl_cipher,
"tls_version": tls_version,
"check_hostname": ssl_verify_server_cert,
}

# remove empty ssl options
ssl = {k: v for k, v in ssl.items() if v is not None}

if ssh_config_host:
ssh_config = read_ssh_config(ssh_config_path).lookup(ssh_config_host)
ssh_host = ssh_host if ssh_host else ssh_config.get("hostname")
Expand Down
31 changes: 31 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,37 @@ def run_query(self, query, new_line=True):
and MockMyCli.connect_args["database"] == "dsn_database"
)

# Use a DSN with query parameters
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database?ssl=True"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["user"] == "dsn_user"
and MockMyCli.connect_args["passwd"] == "dsn_passwd"
and MockMyCli.connect_args["host"] == "dsn_host"
and MockMyCli.connect_args["port"] == 6
and MockMyCli.connect_args["database"] == "dsn_database"
and MockMyCli.connect_args["ssl"]["enable"] is True
)

# When a user uses a DSN with query parameters, and used command line
# arguments, use the command line arguments.
result = runner.invoke(
mycli.main.cli,
args=[
"mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database?ssl=False",
"--ssl",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["user"] == "dsn_user"
and MockMyCli.connect_args["passwd"] == "dsn_passwd"
and MockMyCli.connect_args["host"] == "dsn_host"
and MockMyCli.connect_args["port"] == 6
and MockMyCli.connect_args["database"] == "dsn_database"
and MockMyCli.connect_args["ssl"]["enable"] is True
)


def test_ssh_config(monkeypatch):
# Setup classes to mock mycli.main.MyCli
Expand Down