|
38 | 38 | from prompt_toolkit.lexers import PygmentsLexer |
39 | 39 | from prompt_toolkit.shortcuts import CompleteStyle, PromptSession |
40 | 40 | import pymysql |
41 | | -from pymysql.constants.ER import ERROR_CODE_ACCESS_DENIED, HANDSHAKE_ERROR |
| 41 | +from pymysql.constants.ER import HANDSHAKE_ERROR |
42 | 42 | from pymysql.cursors import Cursor |
43 | 43 | import sqlglot |
44 | 44 | import sqlparse |
@@ -155,6 +155,14 @@ def __init__( |
155 | 155 | self.login_path_as_host = c["main"].as_bool("login_path_as_host") |
156 | 156 | self.post_redirect_command = c['main'].get('post_redirect_command') |
157 | 157 |
|
| 158 | + # set ssl_mode if a valid option is provided in a config file, otherwise None |
| 159 | + ssl_mode = c["ssl"].get("ssl_mode", None) |
| 160 | + if ssl_mode not in ("auto", "on", "off", None): |
| 161 | + self.echo(f"Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.", err=True, fg="red") |
| 162 | + self.ssl_mode = None |
| 163 | + else: |
| 164 | + self.ssl_mode = ssl_mode |
| 165 | + |
158 | 166 | # read from cli argument or user config file |
159 | 167 | self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") |
160 | 168 | self.show_warnings = show_warnings or c["main"].as_bool("show_warnings") |
@@ -524,37 +532,67 @@ def connect( |
524 | 532 | # Connect to the database. |
525 | 533 |
|
526 | 534 | def _connect() -> None: |
527 | | - conn_config = { |
528 | | - "database": database, |
529 | | - "user": user, |
530 | | - "password": passwd, |
531 | | - "host": host, |
532 | | - "port": int_port, |
533 | | - "socket": socket, |
534 | | - "charset": charset, |
535 | | - "local_infile": use_local_infile, |
536 | | - "ssl": ssl_config_or_none, |
537 | | - "ssh_user": ssh_user, |
538 | | - "ssh_host": ssh_host, |
539 | | - "ssh_port": int(ssh_port) if ssh_port else None, |
540 | | - "ssh_password": ssh_password, |
541 | | - "ssh_key_filename": ssh_key_filename, |
542 | | - "init_command": init_command, |
543 | | - } |
544 | 535 | try: |
545 | | - self.sqlexecute = SQLExecute(**conn_config) |
| 536 | + self.sqlexecute = SQLExecute( |
| 537 | + database, |
| 538 | + user, |
| 539 | + passwd, |
| 540 | + host, |
| 541 | + int_port, |
| 542 | + socket, |
| 543 | + charset, |
| 544 | + use_local_infile, |
| 545 | + ssl_config_or_none, |
| 546 | + ssh_user, |
| 547 | + ssh_host, |
| 548 | + int(ssh_port) if ssh_port else None, |
| 549 | + ssh_password, |
| 550 | + ssh_key_filename, |
| 551 | + init_command, |
| 552 | + ) |
546 | 553 | except pymysql.OperationalError as e: |
547 | 554 | if e.args[0] == ERROR_CODE_ACCESS_DENIED: |
548 | 555 | if password_from_file is not None: |
549 | | - conn_config["password"] = password_from_file |
| 556 | + new_passwd = password_from_file |
550 | 557 | else: |
551 | | - conn_config["password"] = click.prompt( |
| 558 | + new_passwd = click.prompt( |
552 | 559 | f"Password for {user}", hide_input=True, show_default=False, default='', type=str, err=True |
553 | 560 | ) |
554 | | - self.sqlexecute = SQLExecute(**conn_config) |
555 | | - elif e.args[0] == HANDSHAKE_ERROR: |
556 | | - conn_config["ssl"] = None |
557 | | - self.sqlexecute = SQLExecute(**conn_config) |
| 561 | + self.sqlexecute = SQLExecute( |
| 562 | + database, |
| 563 | + user, |
| 564 | + new_passwd, |
| 565 | + host, |
| 566 | + int_port, |
| 567 | + socket, |
| 568 | + charset, |
| 569 | + use_local_infile, |
| 570 | + ssl_config_or_none, |
| 571 | + ssh_user, |
| 572 | + ssh_host, |
| 573 | + int(ssh_port) if ssh_port else None, |
| 574 | + ssh_password, |
| 575 | + ssh_key_filename, |
| 576 | + init_command, |
| 577 | + ) |
| 578 | + elif e.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": |
| 579 | + self.sqlexecute = SQLExecute( |
| 580 | + database, |
| 581 | + user, |
| 582 | + passwd, |
| 583 | + host, |
| 584 | + int_port, |
| 585 | + socket, |
| 586 | + charset, |
| 587 | + use_local_infile, |
| 588 | + None, |
| 589 | + ssh_user, |
| 590 | + ssh_host, |
| 591 | + int(ssh_port) if ssh_port else None, |
| 592 | + ssh_password, |
| 593 | + ssh_key_filename, |
| 594 | + init_command, |
| 595 | + ) |
558 | 596 | else: |
559 | 597 | raise e |
560 | 598 |
|
@@ -1387,6 +1425,7 @@ def get_last_query(self) -> str | None: |
1387 | 1425 | @click.option("--ssh-key-filename", help="Private key filename (identify file) for the ssh connection.") |
1388 | 1426 | @click.option("--ssh-config-path", help="Path to ssh configuration.", default=os.path.expanduser("~") + "/.ssh/config") |
1389 | 1427 | @click.option("--ssh-config-host", help="Host to connect to ssh server reading from ssh configuration.") |
| 1428 | +@click.option("--ssl-mode", "ssl_mode", default="auto", help="Set desired SSL behavior. auto=preferred, on=required, off=off.", type=str) |
1390 | 1429 | @click.option( |
1391 | 1430 | "--ssl/--no-ssl", "ssl_enable", is_flag=True, default=True, help="Enable SSL for connection (automatically enabled with other flags)." |
1392 | 1431 | ) |
@@ -1455,6 +1494,7 @@ def cli( |
1455 | 1494 | auto_vertical_output: bool, |
1456 | 1495 | show_warnings: bool, |
1457 | 1496 | local_infile: bool, |
| 1497 | + ssl_mode: str | None, |
1458 | 1498 | ssl_enable: bool, |
1459 | 1499 | ssl_ca: str | None, |
1460 | 1500 | ssl_capath: str | None, |
@@ -1597,19 +1637,29 @@ def cli( |
1597 | 1637 | ssl_verify_server_cert = ssl_verify_server_cert or (params[0].lower() == 'true') |
1598 | 1638 | ssl_enable = True |
1599 | 1639 |
|
1600 | | - ssl = { |
1601 | | - "enable": ssl_enable, |
1602 | | - "ca": ssl_ca and os.path.expanduser(ssl_ca), |
1603 | | - "cert": ssl_cert and os.path.expanduser(ssl_cert), |
1604 | | - "key": ssl_key and os.path.expanduser(ssl_key), |
1605 | | - "capath": ssl_capath, |
1606 | | - "cipher": ssl_cipher, |
1607 | | - "tls_version": tls_version, |
1608 | | - "check_hostname": ssl_verify_server_cert, |
1609 | | - } |
1610 | | - |
1611 | | - # remove empty ssl options |
1612 | | - ssl = {k: v for k, v in ssl.items() if v is not None} |
| 1640 | + if ssl_mode not in ("auto", "on", "off"): |
| 1641 | + click.secho(f"Invalid option provided for --ssl-mode: {ssl_mode}. See --help for valid options.", err=True, fg="red") |
| 1642 | + sys.exit(1) |
| 1643 | + |
| 1644 | + ssl_mode = ssl_mode or mycli.ssl_mode # cli option or config option |
| 1645 | + |
| 1646 | + if ssl_mode in ("auto", "on"): |
| 1647 | + ssl = { |
| 1648 | + "mode": ssl_mode, |
| 1649 | + "enable": ssl_enable, |
| 1650 | + "ca": ssl_ca and os.path.expanduser(ssl_ca), |
| 1651 | + "cert": ssl_cert and os.path.expanduser(ssl_cert), |
| 1652 | + "key": ssl_key and os.path.expanduser(ssl_key), |
| 1653 | + "capath": ssl_capath, |
| 1654 | + "cipher": ssl_cipher, |
| 1655 | + "tls_version": tls_version, |
| 1656 | + "check_hostname": ssl_verify_server_cert, |
| 1657 | + } |
| 1658 | + |
| 1659 | + # remove empty ssl options |
| 1660 | + ssl = {k: v for k, v in ssl.items() if v is not None} |
| 1661 | + else: |
| 1662 | + ssl = None |
1613 | 1663 |
|
1614 | 1664 | if ssh_config_host: |
1615 | 1665 | ssh_config = read_ssh_config(ssh_config_path).lookup(ssh_config_host) |
|
0 commit comments