Skip to content

Commit 57e1bb8

Browse files
committed
Added initial logic for a new ssl_mode config/cli option
1 parent b0e2962 commit 57e1bb8

File tree

4 files changed

+105
-50
lines changed

4 files changed

+105
-50
lines changed

mycli/main.py

Lines changed: 88 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from prompt_toolkit.lexers import PygmentsLexer
3939
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
4040
import pymysql
41-
from pymysql.constants.ER import ERROR_CODE_ACCESS_DENIED, HANDSHAKE_ERROR
41+
from pymysql.constants.ER import HANDSHAKE_ERROR
4242
from pymysql.cursors import Cursor
4343
import sqlglot
4444
import sqlparse
@@ -155,6 +155,14 @@ def __init__(
155155
self.login_path_as_host = c["main"].as_bool("login_path_as_host")
156156
self.post_redirect_command = c['main'].get('post_redirect_command')
157157

158+
# set ssl_mode if a valid option is provided in a config file, otherwise None
159+
ssl_mode = c["ssl"].get("ssl_mode", None)
160+
if ssl_mode not in ("auto", "on", "off", None):
161+
self.echo(f"Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.", err=True, fg="red")
162+
self.ssl_mode = None
163+
else:
164+
self.ssl_mode = ssl_mode
165+
158166
# read from cli argument or user config file
159167
self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output")
160168
self.show_warnings = show_warnings or c["main"].as_bool("show_warnings")
@@ -524,37 +532,67 @@ def connect(
524532
# Connect to the database.
525533

526534
def _connect() -> None:
527-
conn_config = {
528-
"database": database,
529-
"user": user,
530-
"password": passwd,
531-
"host": host,
532-
"port": int_port,
533-
"socket": socket,
534-
"charset": charset,
535-
"local_infile": use_local_infile,
536-
"ssl": ssl_config_or_none,
537-
"ssh_user": ssh_user,
538-
"ssh_host": ssh_host,
539-
"ssh_port": int(ssh_port) if ssh_port else None,
540-
"ssh_password": ssh_password,
541-
"ssh_key_filename": ssh_key_filename,
542-
"init_command": init_command,
543-
}
544535
try:
545-
self.sqlexecute = SQLExecute(**conn_config)
536+
self.sqlexecute = SQLExecute(
537+
database,
538+
user,
539+
passwd,
540+
host,
541+
int_port,
542+
socket,
543+
charset,
544+
use_local_infile,
545+
ssl_config_or_none,
546+
ssh_user,
547+
ssh_host,
548+
int(ssh_port) if ssh_port else None,
549+
ssh_password,
550+
ssh_key_filename,
551+
init_command,
552+
)
546553
except pymysql.OperationalError as e:
547554
if e.args[0] == ERROR_CODE_ACCESS_DENIED:
548555
if password_from_file is not None:
549-
conn_config["password"] = password_from_file
556+
new_passwd = password_from_file
550557
else:
551-
conn_config["password"] = click.prompt(
558+
new_passwd = click.prompt(
552559
f"Password for {user}", hide_input=True, show_default=False, default='', type=str, err=True
553560
)
554-
self.sqlexecute = SQLExecute(**conn_config)
555-
elif e.args[0] == HANDSHAKE_ERROR:
556-
conn_config["ssl"] = None
557-
self.sqlexecute = SQLExecute(**conn_config)
561+
self.sqlexecute = SQLExecute(
562+
database,
563+
user,
564+
new_passwd,
565+
host,
566+
int_port,
567+
socket,
568+
charset,
569+
use_local_infile,
570+
ssl_config_or_none,
571+
ssh_user,
572+
ssh_host,
573+
int(ssh_port) if ssh_port else None,
574+
ssh_password,
575+
ssh_key_filename,
576+
init_command,
577+
)
578+
elif e.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto":
579+
self.sqlexecute = SQLExecute(
580+
database,
581+
user,
582+
passwd,
583+
host,
584+
int_port,
585+
socket,
586+
charset,
587+
use_local_infile,
588+
None,
589+
ssh_user,
590+
ssh_host,
591+
int(ssh_port) if ssh_port else None,
592+
ssh_password,
593+
ssh_key_filename,
594+
init_command,
595+
)
558596
else:
559597
raise e
560598

@@ -1387,6 +1425,7 @@ def get_last_query(self) -> str | None:
13871425
@click.option("--ssh-key-filename", help="Private key filename (identify file) for the ssh connection.")
13881426
@click.option("--ssh-config-path", help="Path to ssh configuration.", default=os.path.expanduser("~") + "/.ssh/config")
13891427
@click.option("--ssh-config-host", help="Host to connect to ssh server reading from ssh configuration.")
1428+
@click.option("--ssl-mode", "ssl_mode", default="auto", help="Set desired SSL behavior. auto=preferred, on=required, off=off.", type=str)
13901429
@click.option(
13911430
"--ssl/--no-ssl", "ssl_enable", is_flag=True, default=True, help="Enable SSL for connection (automatically enabled with other flags)."
13921431
)
@@ -1455,6 +1494,7 @@ def cli(
14551494
auto_vertical_output: bool,
14561495
show_warnings: bool,
14571496
local_infile: bool,
1497+
ssl_mode: str | None,
14581498
ssl_enable: bool,
14591499
ssl_ca: str | None,
14601500
ssl_capath: str | None,
@@ -1597,19 +1637,29 @@ def cli(
15971637
ssl_verify_server_cert = ssl_verify_server_cert or (params[0].lower() == 'true')
15981638
ssl_enable = True
15991639

1600-
ssl = {
1601-
"enable": ssl_enable,
1602-
"ca": ssl_ca and os.path.expanduser(ssl_ca),
1603-
"cert": ssl_cert and os.path.expanduser(ssl_cert),
1604-
"key": ssl_key and os.path.expanduser(ssl_key),
1605-
"capath": ssl_capath,
1606-
"cipher": ssl_cipher,
1607-
"tls_version": tls_version,
1608-
"check_hostname": ssl_verify_server_cert,
1609-
}
1610-
1611-
# remove empty ssl options
1612-
ssl = {k: v for k, v in ssl.items() if v is not None}
1640+
if ssl_mode not in ("auto", "on", "off"):
1641+
click.secho(f"Invalid option provided for --ssl-mode: {ssl_mode}. See --help for valid options.", err=True, fg="red")
1642+
sys.exit(1)
1643+
1644+
ssl_mode = ssl_mode or mycli.ssl_mode # cli option or config option
1645+
1646+
if ssl_mode in ("auto", "on"):
1647+
ssl = {
1648+
"mode": ssl_mode,
1649+
"enable": ssl_enable,
1650+
"ca": ssl_ca and os.path.expanduser(ssl_ca),
1651+
"cert": ssl_cert and os.path.expanduser(ssl_cert),
1652+
"key": ssl_key and os.path.expanduser(ssl_key),
1653+
"capath": ssl_capath,
1654+
"cipher": ssl_cipher,
1655+
"tls_version": tls_version,
1656+
"check_hostname": ssl_verify_server_cert,
1657+
}
1658+
1659+
# remove empty ssl options
1660+
ssl = {k: v for k, v in ssl.items() if v is not None}
1661+
else:
1662+
ssl = None
16131663

16141664
if ssh_config_host:
16151665
ssh_config = read_ssh_config(ssh_config_path).lookup(ssh_config_host)

mycli/myclirc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,11 @@ output.null = "#808080"
190190
[alias_dsn.init-commands]
191191
# Define one or more SQL statements per alias (semicolon-separated).
192192
# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'"
193+
194+
[ssl]
195+
# Sets the desired behavior for handling secure connections to the database server.
196+
# Possible values:
197+
# auto = SSL is preferred. Will attempt to connect via SSL, but will fallback to cleartext as needed.
198+
# on = SSL is required. Will attempt to connect via SSL and will fail if a secure connection is not established.
199+
# off = do not use SSL. Will fail if the server requires a secure connection.
200+
ssl_mode = auto

test/features/db_utils.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@ def create_db(hostname="localhost", port=3306, username=None, password=None, dbn
1616
:return:
1717
1818
"""
19-
ctx = ssl.create_default_context()
20-
ctx.check_hostname = False
21-
ctx.verify_mode = ssl.VerifyMode.CERT_NONE
2219
cn = pymysql.connect(
23-
host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor, ssl=ctx
20+
host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
2421
)
2522

2623
with cn.cursor() as cr:
@@ -44,9 +41,6 @@ def create_cn(hostname, port, password, username, dbname):
4441
:return: psycopg2.connection
4542
4643
"""
47-
ctx = ssl.create_default_context()
48-
ctx.check_hostname = False
49-
ctx.verify_mode = ssl.VerifyMode.CERT_NONE
5044
cn = pymysql.connect(
5145
host=hostname,
5246
port=port,
@@ -55,7 +49,6 @@ def create_cn(hostname, port, password, username, dbname):
5549
db=dbname,
5650
charset="utf8mb4",
5751
cursorclass=pymysql.cursors.DictCursor,
58-
ssl=ctx,
5952
)
6053

6154
return cn
@@ -71,9 +64,6 @@ def drop_db(hostname="localhost", port=3306, username=None, password=None, dbnam
7164
:param dbname: string
7265
7366
"""
74-
ctx = ssl.create_default_context()
75-
ctx.check_hostname = False
76-
ctx.verify_mode = ssl.VerifyMode.CERT_NONE
7767
cn = pymysql.connect(
7868
host=hostname,
7969
port=port,
@@ -82,7 +72,6 @@ def drop_db(hostname="localhost", port=3306, username=None, password=None, dbnam
8272
db=dbname,
8373
charset="utf8mb4",
8474
cursorclass=pymysql.cursors.DictCursor,
85-
ssl=ctx,
8675
)
8776

8877
with cn.cursor() as cr:

test/myclirc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,11 @@ global_limit = set sql_select_limit=9999
191191
[alias_dsn.init-commands]
192192
# Define one or more SQL statements per alias (semicolon-separated).
193193
# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'"
194+
195+
[ssl]
196+
# Sets the desired behavior for handling secure connections to the database server.
197+
# Possible values:
198+
# auto = SSL is preferred. Will attempt to connect via SSL, but will fallback to cleartext as needed.
199+
# on = SSL is required. Will attempt to connect via SSL and will fail if a secure connection is not established.
200+
# off = do not use SSL. Will fail if the server requires a secure connection.
201+
ssl_mode = auto

0 commit comments

Comments
 (0)