Skip to content

Commit be58c66

Browse files
committed
add typehints to sqlexecute.py
* add typehints * import annotations for Python 3.9 compatibility * always import ssl (looks like this was a Python 2.x compat trick) * import iocommands, and import from special.main, instead of the toplevel "special" * include check for cursor.description equaling the empty string * use f-strings for result status feedback * yield empty tuples instead of empty strings for generators which yield tuples * check whether the now() query returned a value and return a native Python datetime if not * prefix underscores to some unused variables
1 parent 26dfd6d commit be58c66

File tree

1 file changed

+61
-54
lines changed

1 file changed

+61
-54
lines changed

mycli/sqlexecute.py

Lines changed: 61 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
# type: ignore
1+
from __future__ import annotations
22

3+
import datetime
34
import enum
45
import logging
56
import re
7+
import ssl
8+
from typing import Any, Generator
69

710
import pymysql
811
from pymysql.constants import FIELD_TYPE
912
from pymysql.converters import conversions, convert_date, convert_datetime, convert_timedelta, decoders
13+
from pymysql.cursors import Cursor
1014

11-
from mycli.packages import special
15+
from mycli.packages.special import iocommands
16+
from mycli.packages.special.main import CommandNotFound, execute
1217

1318
try:
1419
import paramiko # noqa: F401
@@ -34,13 +39,13 @@ class ServerSpecies(enum.Enum):
3439

3540

3641
class ServerInfo:
37-
def __init__(self, species, version_str):
42+
def __init__(self, species: ServerSpecies | None, version_str: str) -> None:
3843
self.species = species
3944
self.version_str = version_str
4045
self.version = self.calc_mysql_version_value(version_str)
4146

4247
@staticmethod
43-
def calc_mysql_version_value(version_str) -> int:
48+
def calc_mysql_version_value(version_str: str) -> int:
4449
if not version_str or not isinstance(version_str, str):
4550
return 0
4651
try:
@@ -51,7 +56,7 @@ def calc_mysql_version_value(version_str) -> int:
5156
return int(major) * 10_000 + int(minor) * 100 + int(patch)
5257

5358
@classmethod
54-
def from_version_string(cls, version_string):
59+
def from_version_string(cls, version_string: str) -> ServerInfo:
5560
if not version_string:
5661
return cls(ServerSpecies.MySQL, "")
5762

@@ -73,7 +78,7 @@ def from_version_string(cls, version_string):
7378

7479
return cls(detected_species, parsed_version)
7580

76-
def __str__(self):
81+
def __str__(self) -> str:
7782
if self.species:
7883
return f"{self.species.value} {self.version_str}"
7984
else:
@@ -100,22 +105,22 @@ class SQLExecute:
100105

101106
def __init__(
102107
self,
103-
database,
104-
user,
105-
password,
106-
host,
107-
port,
108-
socket,
109-
charset,
110-
local_infile,
111-
ssl,
112-
ssh_user,
113-
ssh_host,
114-
ssh_port,
115-
ssh_password,
116-
ssh_key_filename,
117-
init_command=None,
118-
):
108+
database: str | None,
109+
user: str | None,
110+
password: str | None,
111+
host: str | None,
112+
port: int | None,
113+
socket: str | None,
114+
charset: str | None,
115+
local_infile: str | None,
116+
ssl: dict[str, Any] | None,
117+
ssh_user: str | None,
118+
ssh_host: str | None,
119+
ssh_port: int | None,
120+
ssh_password: str | None,
121+
ssh_key_filename: str | None,
122+
init_command: str | None = None,
123+
) -> None:
119124
self.dbname = database
120125
self.user = user
121126
self.password = password
@@ -125,8 +130,8 @@ def __init__(
125130
self.charset = charset
126131
self.local_infile = local_infile
127132
self.ssl = ssl
128-
self.server_info = None
129-
self.connection_id = None
133+
self.server_info: ServerInfo | None = None
134+
self.connection_id: int | None = None
130135
self.ssh_user = ssh_user
131136
self.ssh_host = ssh_host
132137
self.ssh_port = ssh_port
@@ -213,7 +218,7 @@ def connect(
213218
defer_connect = True
214219

215220
client_flag = pymysql.constants.CLIENT.INTERACTIVE
216-
if init_command and len(list(special.split_queries(init_command))) > 1:
221+
if init_command and len(list(iocommands.split_queries(init_command))) > 1:
217222
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
218223

219224
ssl_context = None
@@ -277,7 +282,7 @@ def connect(
277282
self.reset_connection_id()
278283
self.server_info = ServerInfo.from_version_string(conn.server_version)
279284

280-
def run(self, statement):
285+
def run(self, statement: str) -> Generator[tuple, None, None]:
281286
"""Execute the sql in the database and return the results. The results
282287
are a list of tuples. Each tuple has 4 values
283288
(title, rows, headers, status).
@@ -294,26 +299,26 @@ def run(self, statement):
294299
if statement.startswith("\\fs"):
295300
components = [statement]
296301
else:
297-
components = special.split_queries(statement)
302+
components = iocommands.split_queries(statement)
298303

299304
for sql in components:
300305
# \G is treated specially since we have to set the expanded output.
301306
if sql.endswith("\\G"):
302-
special.set_expanded_output(True)
307+
iocommands.set_expanded_output(True)
303308
sql = sql[:-2].strip()
304309
# \g is treated specially since we might want collapsed output when
305310
# auto vertical output is enabled
306311
elif sql.endswith('\\g'):
307-
special.set_expanded_output(False)
308-
special.set_forced_horizontal_output(True)
312+
iocommands.set_expanded_output(False)
313+
iocommands.set_forced_horizontal_output(True)
309314
sql = sql[:-2].strip()
310315

311316
cur = self.conn.cursor()
312317
try: # Special command
313318
_logger.debug("Trying a dbspecial command. sql: %r", sql)
314-
for result in special.execute(cur, sql):
319+
for result in execute(cur, sql):
315320
yield result
316-
except special.CommandNotFound: # Regular SQL
321+
except CommandNotFound: # Regular SQL
317322
_logger.debug("Regular sql statement. sql: %r", sql)
318323
cur.execute(sql)
319324
while True:
@@ -325,23 +330,24 @@ def run(self, statement):
325330
if not cur.nextset() or (not cur.rowcount and cur.description is None):
326331
break
327332

328-
def get_result(self, cursor):
333+
def get_result(self, cursor: Cursor) -> tuple:
329334
"""Get the current result's data from the cursor."""
330335
title = headers = None
331336

332337
# cursor.description is not None for queries that return result sets,
333338
# e.g. SELECT or SHOW.
334-
if cursor.description is not None:
339+
if cursor.description:
335340
headers = [x[0] for x in cursor.description]
336-
status = "{0} row{1} in set"
341+
plural = '' if cursor.rowcount == 1 else 's'
342+
status = f'{cursor.rowcount} row{plural} in set'
337343
else:
338344
_logger.debug("No rows in result.")
339-
status = "Query OK, {0} row{1} affected"
340-
status = status.format(cursor.rowcount, "" if cursor.rowcount == 1 else "s")
345+
plural = '' if cursor.rowcount == 1 else 's'
346+
status = f'Query OK, {cursor.rowcount} row{plural} affected'
341347

342348
return (title, cursor if cursor.description else None, headers, status)
343349

344-
def tables(self):
350+
def tables(self) -> Generator[tuple[str], None, None]:
345351
"""Yields table names"""
346352

347353
with self.conn.cursor() as cur:
@@ -350,21 +356,21 @@ def tables(self):
350356
for row in cur:
351357
yield row
352358

353-
def table_columns(self):
359+
def table_columns(self) -> Generator[tuple[str, str], None, None]:
354360
"""Yields (table name, column name) pairs"""
355361
with self.conn.cursor() as cur:
356362
_logger.debug("Columns Query. sql: %r", self.table_columns_query)
357363
cur.execute(self.table_columns_query % self.dbname)
358364
for row in cur:
359365
yield row
360366

361-
def databases(self):
367+
def databases(self) -> list[str]:
362368
with self.conn.cursor() as cur:
363369
_logger.debug("Databases Query. sql: %r", self.databases_query)
364370
cur.execute(self.databases_query)
365371
return [x[0] for x in cur.fetchall()]
366372

367-
def functions(self):
373+
def functions(self) -> Generator[tuple[str, str], None, None]:
368374
"""Yields tuples of (schema_name, function_name)"""
369375

370376
with self.conn.cursor() as cur:
@@ -373,47 +379,50 @@ def functions(self):
373379
for row in cur:
374380
yield row
375381

376-
def show_candidates(self):
382+
def show_candidates(self) -> Generator[tuple, None, None]:
377383
with self.conn.cursor() as cur:
378384
_logger.debug("Show Query. sql: %r", self.show_candidates_query)
379385
try:
380386
cur.execute(self.show_candidates_query)
381387
except pymysql.DatabaseError as e:
382388
_logger.error("No show completions due to %r", e)
383-
yield ""
389+
yield ()
384390
else:
385391
for row in cur:
386392
yield (row[0].split(None, 1)[-1],)
387393

388-
def users(self):
394+
def users(self) -> Generator[tuple, None, None]:
389395
with self.conn.cursor() as cur:
390396
_logger.debug("Users Query. sql: %r", self.users_query)
391397
try:
392398
cur.execute(self.users_query)
393399
except pymysql.DatabaseError as e:
394400
_logger.error("No user completions due to %r", e)
395-
yield ""
401+
yield ()
396402
else:
397403
for row in cur:
398404
yield row
399405

400-
def now(self):
406+
def now(self) -> datetime.datetime:
401407
with self.conn.cursor() as cur:
402408
_logger.debug("Now Query. sql: %r", self.now_query)
403409
cur.execute(self.now_query)
404-
return cur.fetchone()[0]
410+
if one := cur.fetchone():
411+
return one[0]
412+
else:
413+
return datetime.datetime.now()
405414

406-
def get_connection_id(self):
415+
def get_connection_id(self) -> int | None:
407416
if not self.connection_id:
408417
self.reset_connection_id()
409418
return self.connection_id
410419

411-
def reset_connection_id(self):
420+
def reset_connection_id(self) -> None:
412421
# Remember current connection id
413422
_logger.debug("Get current connection id")
414423
try:
415424
res = self.run("select connection_id()")
416-
for title, cur, headers, status in res:
425+
for _title, cur, _headers, _status in res:
417426
self.connection_id = cur.fetchone()[0]
418427
except Exception as e:
419428
# See #1054
@@ -422,13 +431,11 @@ def reset_connection_id(self):
422431
else:
423432
_logger.debug("Current connection id: %s", self.connection_id)
424433

425-
def change_db(self, db):
434+
def change_db(self, db: str) -> None:
426435
self.conn.select_db(db)
427436
self.dbname = db
428437

429-
def _create_ssl_ctx(self, sslp):
430-
import ssl
431-
438+
def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext:
432439
ca = sslp.get("ca")
433440
capath = sslp.get("capath")
434441
hasnoca = ca is None and capath is None

0 commit comments

Comments
 (0)