Skip to content

Commit df44b06

Browse files
committed
fix discovering .cnf files that are included with !includedirs
1 parent a15df74 commit df44b06

File tree

4 files changed

+96
-31
lines changed

4 files changed

+96
-31
lines changed

mycli/config.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import shutil
2+
from copy import copy
23
from io import BytesIO, TextIOWrapper
34
import logging
45
import os
@@ -58,12 +59,34 @@ def read_config_file(f, list_values=True):
5859
return config
5960

6061

62+
def get_included_configs(config_path) -> list:
63+
"""Get a list of configuration files that are included into config_path
64+
with !includedir directive."""
65+
if not os.path.exists(config_path):
66+
return []
67+
included_configs = []
68+
with open(config_path) as f:
69+
include_directives = filter(
70+
lambda s: s.startswith('!includedir'),
71+
f
72+
)
73+
dirs = map(lambda s: s.strip().split()[-1], include_directives)
74+
dirs = filter(os.path.isdir, dirs)
75+
for dir in dirs:
76+
for filename in os.listdir(dir):
77+
if filename.endswith('.cnf'):
78+
included_configs.append(os.path.join(dir, filename))
79+
return included_configs
80+
81+
6182
def read_config_files(files, list_values=True):
6283
"""Read and merge a list of config files."""
6384

6485
config = ConfigObj(list_values=list_values)
65-
66-
for _file in files:
86+
_files = copy(files)
87+
while _files:
88+
_file = _files.pop(0)
89+
_files = get_included_configs(_file) + _files
6790
_config = read_config_file(_file, list_values=list_values)
6891
if bool(_config) is True:
6992
config.merge(_config)

mycli/main.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def connect(self, database='', user='', passwd='', host='', port='',
386386
if port or host:
387387
socket = ''
388388
else:
389-
socket = socket or cnf['socket']
389+
socket = socket or cnf['socket'] or guess_socket_location()
390390
user = user or cnf['user'] or os.getenv('USER')
391391
host = host or cnf['host']
392392
port = port or cnf['port']
@@ -430,27 +430,13 @@ def _connect():
430430
else:
431431
raise e
432432

433-
def _fallback_to_tcp_ip():
434-
self.echo(
435-
'Retrying over TCP/IP', err=True)
436-
437-
# Else fall back to TCP/IP localhost
438-
nonlocal socket, host, port
439-
socket = ""
440-
host = 'localhost'
441-
port = 3306
442-
_connect()
443-
444433
try:
445-
if (host is None) and not WIN:
446-
# Try a sensible default socket first (simplifies auth)
447-
# If we get a connection error, try tcp/ip localhost
434+
if not WIN and socket:
435+
socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
436+
self.echo(
437+
f"Connecting to socket {socket}, owned by user {socket_owner}")
448438
try:
449-
socket = socket or guess_socket_location()
450439
_connect()
451-
except FileNotFoundError:
452-
self.echo('Failed to find socket file at default locations')
453-
_fallback_to_tcp_ip()
454440
except OperationalError as e:
455441
# These are "Can't open socket" and 2x "Can't connect"
456442
if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
@@ -461,15 +447,16 @@ def _fallback_to_tcp_ip():
461447
self.echo(
462448
"Failed to connect to local MySQL server through socket '{}':".format(socket))
463449
self.echo(str(e), err=True)
464-
_fallback_to_tcp_ip()
450+
self.echo(
451+
'Retrying over TCP/IP', err=True)
452+
453+
# Else fall back to TCP/IP localhost
454+
socket = ""
455+
host = 'localhost'
456+
port = 3306
457+
_connect()
465458
else:
466459
raise e
467-
else:
468-
socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
469-
self.echo(
470-
"Using socket {}, owned by user {}".format(
471-
socket, socket_owner)
472-
)
473460
else:
474461
host = host or 'localhost'
475462
port = port or 3306
@@ -1009,6 +996,9 @@ def get_last_query(self):
1009996
@click.option('--ssh-port', default=22, help='Port to connect to ssh server.')
1010997
@click.option('--ssh-password', help='Password to connect to ssh server.')
1011998
@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.')
999+
@click.option('--ssh-config-path', help='Path to ssh configuration.',
1000+
default=os.path.expanduser('~') + '/.ssh/config')
1001+
@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.')
10121002
@click.option('--ssl-ca', help='CA file in PEM format.',
10131003
type=click.Path(exists=True))
10141004
@click.option('--ssl-capath', help='CA directory.')
@@ -1030,6 +1020,8 @@ def get_last_query(self):
10301020
help='Use DSN configured into the [alias_dsn] section of myclirc file.')
10311021
@click.option('--list-dsn', 'list_dsn', is_flag=True,
10321022
help='list of DSN configured into the [alias_dsn] section of myclirc file.')
1023+
@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True,
1024+
help='list ssh configurations in the ssh config (requires paramiko).')
10331025
@click.option('-R', '--prompt', 'prompt',
10341026
help='Prompt format (Default: "{0}").'.format(
10351027
MyCli.default_prompt))
@@ -1062,7 +1054,7 @@ def cli(database, user, host, port, socket, password, dbname,
10621054
ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
10631055
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
10641056
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
1065-
ssh_key_filename):
1057+
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host):
10661058
"""A MySQL terminal client with auto-completion and syntax highlighting.
10671059
10681060
\b
@@ -1099,6 +1091,31 @@ def cli(database, user, host, port, socket, password, dbname,
10991091
else:
11001092
click.secho(alias)
11011093
sys.exit(0)
1094+
if list_ssh_config:
1095+
if not paramiko:
1096+
click.secho(
1097+
"This features requires paramiko. Please install paramiko and try again.",
1098+
err=True, fg='red'
1099+
)
1100+
exit(1)
1101+
try:
1102+
ssh_config = paramiko.config.SSHConfig().from_path(ssh_config_path)
1103+
except paramiko.ssh_exception.ConfigParseError as err:
1104+
click.secho('Invalid SSH configuration file. '
1105+
'Please check the SSH configuration file.',
1106+
err=True, fg='red')
1107+
exit(1)
1108+
except FileNotFoundError as e:
1109+
click.secho(str(e), err=True, fg='red')
1110+
exit(1)
1111+
for host in ssh_config.get_hostnames():
1112+
if verbose:
1113+
host_config = ssh_config.lookup(host)
1114+
click.secho("{} : {}".format(
1115+
host, host_config.get('hostname')))
1116+
else:
1117+
click.secho(host)
1118+
sys.exit(0)
11021119
# Choose which ever one has a valid value.
11031120
database = dbname or database
11041121

@@ -1149,6 +1166,32 @@ def cli(database, user, host, port, socket, password, dbname,
11491166
if not port:
11501167
port = uri.port
11511168

1169+
if ssh_config_host:
1170+
if not paramiko:
1171+
click.secho(
1172+
"This features requires paramiko. Please install paramiko and try again.",
1173+
err=True, fg='red'
1174+
)
1175+
exit(1)
1176+
try:
1177+
ssh_config = paramiko.config.SSHConfig().from_path(ssh_config_path)
1178+
except paramiko.ssh_exception.ConfigParseError as err:
1179+
click.secho('Invalid SSH configuration file. '
1180+
'Please check the SSH configuration file.',
1181+
err=True, fg='red')
1182+
exit(1)
1183+
except FileNotFoundError as e:
1184+
click.secho(str(e), err=True, fg='red')
1185+
exit(1)
1186+
ssh_config = ssh_config.lookup(ssh_config_host)
1187+
ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
1188+
ssh_user = ssh_user if ssh_user else ssh_config.get('user')
1189+
if ssh_config.get('port') and ssh_port == 22:
1190+
# port has a default value, overwrite it if it's in the config
1191+
ssh_port = int(ssh_config.get('port'))
1192+
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
1193+
'identityfile', [''])[0]
1194+
11521195
if not paramiko and ssh_host:
11531196
click.secho(
11541197
"Cannot use SSH transport because paramiko isn't installed, "

mycli/packages/filepaths.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,4 @@ def guess_socket_location():
103103
if name.startswith("mysql") and ext in ('.socket', '.sock'):
104104
return os.path.join(r, filename)
105105
dirs[:] = [d for d in dirs if d.startswith("mysql")]
106-
raise FileNotFoundError
106+
return None

mycli/sqlexecute.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ def run(self, statement):
191191
if not cur.nextset() or (not cur.rowcount and cur.description is None):
192192
break
193193

194-
195194
def get_result(self, cursor):
196195
"""Get the current result's data from the cursor."""
197196
title = headers = None

0 commit comments

Comments
 (0)