1- # type: ignore
1+ from __future__ import annotations
22
3+ import datetime
34import enum
45import logging
56import re
7+ import ssl
8+ from typing import Any , Generator
69
710import pymysql
811from pymysql .constants import FIELD_TYPE
912from pymysql .converters import conversions , convert_date , convert_datetime , convert_timedelta , decoders
13+ from pymysql .cursors import Cursor
1014
11- from mycli .packages import special
15+ from mycli .packages .special import iocommands
16+ from mycli .packages .special .main import CommandNotFound , execute
1217
1318try :
1419 import paramiko # noqa: F401
@@ -34,13 +39,13 @@ class ServerSpecies(enum.Enum):
3439
3540
3641class ServerInfo :
37- def __init__ (self , species , version_str ) :
42+ def __init__ (self , species : ServerSpecies | None , version_str : str ) -> None :
3843 self .species = species
3944 self .version_str = version_str
4045 self .version = self .calc_mysql_version_value (version_str )
4146
4247 @staticmethod
43- def calc_mysql_version_value (version_str ) -> int :
48+ def calc_mysql_version_value (version_str : str ) -> int :
4449 if not version_str or not isinstance (version_str , str ):
4550 return 0
4651 try :
@@ -51,7 +56,7 @@ def calc_mysql_version_value(version_str) -> int:
5156 return int (major ) * 10_000 + int (minor ) * 100 + int (patch )
5257
5358 @classmethod
54- def from_version_string (cls , version_string ) :
59+ def from_version_string (cls , version_string : str ) -> ServerInfo :
5560 if not version_string :
5661 return cls (ServerSpecies .MySQL , "" )
5762
@@ -73,7 +78,7 @@ def from_version_string(cls, version_string):
7378
7479 return cls (detected_species , parsed_version )
7580
76- def __str__ (self ):
81+ def __str__ (self ) -> str :
7782 if self .species :
7883 return f"{ self .species .value } { self .version_str } "
7984 else :
@@ -100,22 +105,22 @@ class SQLExecute:
100105
101106 def __init__ (
102107 self ,
103- database ,
104- user ,
105- password ,
106- host ,
107- port ,
108- socket ,
109- charset ,
110- local_infile ,
111- ssl ,
112- ssh_user ,
113- ssh_host ,
114- ssh_port ,
115- ssh_password ,
116- ssh_key_filename ,
117- init_command = None ,
118- ):
108+ database : str | None ,
109+ user : str | None ,
110+ password : str | None ,
111+ host : str | None ,
112+ port : int | None ,
113+ socket : str | None ,
114+ charset : str | None ,
115+ local_infile : str | None ,
116+ ssl : dict [ str , Any ] | None ,
117+ ssh_user : str | None ,
118+ ssh_host : str | None ,
119+ ssh_port : int | None ,
120+ ssh_password : str | None ,
121+ ssh_key_filename : str | None ,
122+ init_command : str | None = None ,
123+ ) -> None :
119124 self .dbname = database
120125 self .user = user
121126 self .password = password
@@ -125,8 +130,8 @@ def __init__(
125130 self .charset = charset
126131 self .local_infile = local_infile
127132 self .ssl = ssl
128- self .server_info = None
129- self .connection_id = None
133+ self .server_info : ServerInfo | None = None
134+ self .connection_id : int | None = None
130135 self .ssh_user = ssh_user
131136 self .ssh_host = ssh_host
132137 self .ssh_port = ssh_port
@@ -213,7 +218,7 @@ def connect(
213218 defer_connect = True
214219
215220 client_flag = pymysql .constants .CLIENT .INTERACTIVE
216- if init_command and len (list (special .split_queries (init_command ))) > 1 :
221+ if init_command and len (list (iocommands .split_queries (init_command ))) > 1 :
217222 client_flag |= pymysql .constants .CLIENT .MULTI_STATEMENTS
218223
219224 ssl_context = None
@@ -277,7 +282,7 @@ def connect(
277282 self .reset_connection_id ()
278283 self .server_info = ServerInfo .from_version_string (conn .server_version )
279284
280- def run (self , statement ) :
285+ def run (self , statement : str ) -> Generator [ tuple , None , None ] :
281286 """Execute the sql in the database and return the results. The results
282287 are a list of tuples. Each tuple has 4 values
283288 (title, rows, headers, status).
@@ -294,26 +299,26 @@ def run(self, statement):
294299 if statement .startswith ("\\ fs" ):
295300 components = [statement ]
296301 else :
297- components = special .split_queries (statement )
302+ components = iocommands .split_queries (statement )
298303
299304 for sql in components :
300305 # \G is treated specially since we have to set the expanded output.
301306 if sql .endswith ("\\ G" ):
302- special .set_expanded_output (True )
307+ iocommands .set_expanded_output (True )
303308 sql = sql [:- 2 ].strip ()
304309 # \g is treated specially since we might want collapsed output when
305310 # auto vertical output is enabled
306311 elif sql .endswith ('\\ g' ):
307- special .set_expanded_output (False )
308- special .set_forced_horizontal_output (True )
312+ iocommands .set_expanded_output (False )
313+ iocommands .set_forced_horizontal_output (True )
309314 sql = sql [:- 2 ].strip ()
310315
311316 cur = self .conn .cursor ()
312317 try : # Special command
313318 _logger .debug ("Trying a dbspecial command. sql: %r" , sql )
314- for result in special . execute (cur , sql ):
319+ for result in execute (cur , sql ):
315320 yield result
316- except special . CommandNotFound : # Regular SQL
321+ except CommandNotFound : # Regular SQL
317322 _logger .debug ("Regular sql statement. sql: %r" , sql )
318323 cur .execute (sql )
319324 while True :
@@ -325,23 +330,24 @@ def run(self, statement):
325330 if not cur .nextset () or (not cur .rowcount and cur .description is None ):
326331 break
327332
328- def get_result (self , cursor ) :
333+ def get_result (self , cursor : Cursor ) -> tuple :
329334 """Get the current result's data from the cursor."""
330335 title = headers = None
331336
332337 # cursor.description is not None for queries that return result sets,
333338 # e.g. SELECT or SHOW.
334- if cursor .description is not None :
339+ if cursor .description :
335340 headers = [x [0 ] for x in cursor .description ]
336- status = "{0} row{1} in set"
341+ plural = '' if cursor .rowcount == 1 else 's'
342+ status = f'{ cursor .rowcount } row{ plural } in set'
337343 else :
338344 _logger .debug ("No rows in result." )
339- status = "Query OK, {0} row{1} affected"
340- status = status . format ( cursor . rowcount , "" if cursor .rowcount == 1 else "s" )
345+ plural = '' if cursor . rowcount == 1 else 's'
346+ status = f'Query OK, { cursor .rowcount } row { plural } affected'
341347
342348 return (title , cursor if cursor .description else None , headers , status )
343349
344- def tables (self ):
350+ def tables (self ) -> Generator [ tuple [ str ], None , None ] :
345351 """Yields table names"""
346352
347353 with self .conn .cursor () as cur :
@@ -350,21 +356,21 @@ def tables(self):
350356 for row in cur :
351357 yield row
352358
353- def table_columns (self ):
359+ def table_columns (self ) -> Generator [ tuple [ str , str ], None , None ] :
354360 """Yields (table name, column name) pairs"""
355361 with self .conn .cursor () as cur :
356362 _logger .debug ("Columns Query. sql: %r" , self .table_columns_query )
357363 cur .execute (self .table_columns_query % self .dbname )
358364 for row in cur :
359365 yield row
360366
361- def databases (self ):
367+ def databases (self ) -> list [ str ] :
362368 with self .conn .cursor () as cur :
363369 _logger .debug ("Databases Query. sql: %r" , self .databases_query )
364370 cur .execute (self .databases_query )
365371 return [x [0 ] for x in cur .fetchall ()]
366372
367- def functions (self ):
373+ def functions (self ) -> Generator [ tuple [ str , str ], None , None ] :
368374 """Yields tuples of (schema_name, function_name)"""
369375
370376 with self .conn .cursor () as cur :
@@ -373,47 +379,50 @@ def functions(self):
373379 for row in cur :
374380 yield row
375381
376- def show_candidates (self ):
382+ def show_candidates (self ) -> Generator [ tuple , None , None ] :
377383 with self .conn .cursor () as cur :
378384 _logger .debug ("Show Query. sql: %r" , self .show_candidates_query )
379385 try :
380386 cur .execute (self .show_candidates_query )
381387 except pymysql .DatabaseError as e :
382388 _logger .error ("No show completions due to %r" , e )
383- yield ""
389+ yield ()
384390 else :
385391 for row in cur :
386392 yield (row [0 ].split (None , 1 )[- 1 ],)
387393
388- def users (self ):
394+ def users (self ) -> Generator [ tuple , None , None ] :
389395 with self .conn .cursor () as cur :
390396 _logger .debug ("Users Query. sql: %r" , self .users_query )
391397 try :
392398 cur .execute (self .users_query )
393399 except pymysql .DatabaseError as e :
394400 _logger .error ("No user completions due to %r" , e )
395- yield ""
401+ yield ()
396402 else :
397403 for row in cur :
398404 yield row
399405
400- def now (self ):
406+ def now (self ) -> datetime . datetime :
401407 with self .conn .cursor () as cur :
402408 _logger .debug ("Now Query. sql: %r" , self .now_query )
403409 cur .execute (self .now_query )
404- return cur .fetchone ()[0 ]
410+ if one := cur .fetchone ():
411+ return one [0 ]
412+ else :
413+ return datetime .datetime .now ()
405414
406- def get_connection_id (self ):
415+ def get_connection_id (self ) -> int | None :
407416 if not self .connection_id :
408417 self .reset_connection_id ()
409418 return self .connection_id
410419
411- def reset_connection_id (self ):
420+ def reset_connection_id (self ) -> None :
412421 # Remember current connection id
413422 _logger .debug ("Get current connection id" )
414423 try :
415424 res = self .run ("select connection_id()" )
416- for title , cur , headers , status in res :
425+ for _title , cur , _headers , _status in res :
417426 self .connection_id = cur .fetchone ()[0 ]
418427 except Exception as e :
419428 # See #1054
@@ -422,13 +431,11 @@ def reset_connection_id(self):
422431 else :
423432 _logger .debug ("Current connection id: %s" , self .connection_id )
424433
425- def change_db (self , db ) :
434+ def change_db (self , db : str ) -> None :
426435 self .conn .select_db (db )
427436 self .dbname = db
428437
429- def _create_ssl_ctx (self , sslp ):
430- import ssl
431-
438+ def _create_ssl_ctx (self , sslp : dict ) -> ssl .SSLContext :
432439 ca = sslp .get ("ca" )
433440 capath = sslp .get ("capath" )
434441 hasnoca = ca is None and capath is None
0 commit comments