Skip to content

Commit e9dbdf4

Browse files
authored
Set correct database on reconnect when applicable. (#1439)
* Set correct database on reconnect when applicable. * Added test for reconnect to verify previously selected database is still selected
1 parent fc4ce84 commit e9dbdf4

File tree

5 files changed

+46
-6
lines changed

5 files changed

+46
-6
lines changed

changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Internal
88
Bug Fixes
99
--------
1010
* Update watch query output to display the correct execution time on all iterations (#763).
11-
11+
* Use correct database (if applicable) when reconnecting after a connection loss (#1437).
1212

1313
1.44.1 (2026/01/10)
1414
==============

mycli/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,9 @@ def reconnect(self, database: str = "") -> bool:
11441144
self.logger.debug("Attempting to reconnect.")
11451145
self.echo("Reconnecting...", fg="yellow")
11461146
self.sqlexecute.conn.ping(reconnect=True)
1147+
# if a database is currently selected, set it on the conn again
1148+
if self.sqlexecute.dbname:
1149+
self.sqlexecute.conn.select_db(self.sqlexecute.dbname)
11471150
self.logger.debug("Reconnected successfully.")
11481151
self.echo("Reconnected successfully.", fg="yellow")
11491152
self.sqlexecute.reset_connection_id()

test/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import pytest
44

55
import mycli.sqlexecute
6-
from test.utils import CHARSET, HOST, PASSWORD, PORT, SSH_HOST, SSH_PORT, SSH_USER, USER, create_db, db_connection
6+
from test.utils import CHARSET, DATABASE, HOST, PASSWORD, PORT, SSH_HOST, SSH_PORT, SSH_USER, USER, create_db, db_connection
77

88

99
@pytest.fixture(scope="function")
1010
def connection():
11-
create_db("mycli_test_db")
12-
connection = db_connection("mycli_test_db")
11+
create_db(DATABASE)
12+
connection = db_connection(DATABASE)
1313
yield connection
1414

1515
connection.close()
@@ -24,7 +24,7 @@ def cursor(connection):
2424
@pytest.fixture
2525
def executor(connection):
2626
return mycli.sqlexecute.SQLExecute(
27-
database="mycli_test_db",
27+
database=DATABASE,
2828
user=USER,
2929
host=HOST,
3030
password=PASSWORD,

test/test_main.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99

1010
import click
1111
from click.testing import CliRunner
12+
from pymysql.err import OperationalError
1213

1314
from mycli.main import MyCli, cli, thanks_picker
1415
import mycli.packages.special
1516
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
1617
from mycli.sqlexecute import ServerInfo, SQLExecute
17-
from test.utils import HOST, PASSWORD, PORT, USER, dbtest, run
18+
from test.utils import DATABASE, HOST, PASSWORD, PORT, USER, dbtest, run
1819

1920
test_dir = os.path.abspath(os.path.dirname(__file__))
2021
project_dir = os.path.dirname(test_dir)
@@ -94,6 +95,41 @@ def test_ssl_mode_overrides_no_ssl(executor, capsys):
9495
assert ssl_cipher
9596

9697

98+
@dbtest
99+
def test_reconnect_database_is_selected(executor, capsys):
100+
m = MyCli()
101+
m.register_special_commands()
102+
m.sqlexecute = SQLExecute(
103+
None,
104+
USER,
105+
PASSWORD,
106+
HOST,
107+
PORT,
108+
None,
109+
None,
110+
None,
111+
None,
112+
None,
113+
None,
114+
None,
115+
None,
116+
None,
117+
None,
118+
)
119+
try:
120+
next(m.sqlexecute.run(f"use {DATABASE}"))
121+
next(m.sqlexecute.run(f"kill {m.sqlexecute.connection_id}"))
122+
except OperationalError:
123+
pass # expected as the connection was killed
124+
except Exception as e:
125+
raise e
126+
m.reconnect()
127+
try:
128+
next(m.sqlexecute.run("show tables")).results.fetchall()
129+
except Exception as e:
130+
raise e
131+
132+
97133
@dbtest
98134
def test_reconnect_no_database(executor, capsys):
99135
m = MyCli()

test/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from mycli.main import special
1313

14+
DATABASE = "mycli_test_db"
1415
PASSWORD = os.getenv("PYTEST_PASSWORD")
1516
USER = os.getenv("PYTEST_USER", "root")
1617
HOST = os.getenv("PYTEST_HOST", "localhost")

0 commit comments

Comments
 (0)