3535get_arg = lambda e : e .args [1 ] if e .args and len (e .args ) > 1 else str (e )
3636"""An ugly anonymous function to extract exception message"""
3737
38+ table_column_headers = ("cid" , "name" , "type" , "notnull" , "default" , "pk" )
39+
40+ table_headers = ("_" , "name" , "type" , "_" , "_" , "_" )
41+
3842
3943def cli_error_handler (func ):
4044 """Decorator for handling exceptions accordingly"""
@@ -68,19 +72,29 @@ def execute_sql_command(
6872 ) -> t .Tuple [t .Any ]:
6973 """Run sql statements against database"""
7074 try :
71- cursor = sqlite3 . Cursor ( self .db_connection )
75+ cursor = self .db_connection . cursor ( )
7276 cursor .execute (statement )
7377 if commit :
7478 self .commit ()
75- return (True , cursor .fetchall ())
79+ resp = (True , cursor .fetchall ())
7680 except Exception as e :
77- return (False , e )
81+ resp = (False , e )
7882 finally :
7983 cursor .close ()
84+ return resp
8085
81- def tables (self ):
86+ def tables (self , tbl_names_only : bool = False ):
8287 """List tables available"""
83- return self .execute_sql_command ("PRAGMA table_list;" )
88+ return (
89+ [
90+ entry [0 ]
91+ for entry in self .execute_sql_command (
92+ "SELECT tbl_name FROM sqlite_schema WHERE type='table'"
93+ )[1 ]
94+ ]
95+ if tbl_names_only
96+ else self .execute_sql_command ("PRAGMA table_list;" )
97+ )
8498
8599 def table_columns (self , table : str ):
86100 """List table columns and their metadata"""
@@ -119,7 +133,10 @@ def __init__(self, db_manager: Sqlite3Manager):
119133 "Looks like pytgpt isn't installed. Install it before using TextToSql - "
120134 '"pip install python-tgpt"'
121135 )
122- self .ai = auto .AUTO (update_file = False )
136+ history_file = Path .home () / ".sqlite-cli-manager-ai-chat-history.txt"
137+ if history_file .exists ():
138+ os .remove (history_file )
139+ self .ai = auto .AUTO (filepath = history_file )
123140 assert isinstance (
124141 db_manager , Sqlite3Manager
125142 ), f"db_manager must be an instance of { Sqlite3Manager } not { type (db_manager )} "
@@ -198,14 +215,34 @@ def generate(self, prompt: str):
198215
199216
200217class HistoryCompletions (Completer ):
201- def __init__ (self , session , disable_suggestions ):
218+ def __init__ (self , session , disable_suggestions , db_manager : Sqlite3Manager ):
202219 self .session : PromptSession = session
203220 self .disable_suggestions = disable_suggestions
221+ self .db_manager = db_manager
204222
205223 def get_completions (self , document : Document , complete_event ):
206224 if self .disable_suggestions :
207225 return
208226 text = document .text
227+ processed_text = text .lower ().strip ()
228+ if processed_text .endswith ("from" ):
229+ # Suggest available table names
230+ for table in self .db_manager .tables (tbl_names_only = True ):
231+ yield Completion (text + " " + table , start_position = - len (text ))
232+
233+ elif processed_text .endswith ("where" ):
234+ # Suggest columns for a particular table
235+ db_tables = self .db_manager .tables (tbl_names_only = True )
236+ target_table = re .findall (
237+ r".+from\s([\w_]+)\s.*" , text , flags = re .IGNORECASE
238+ )
239+ if target_table and target_table [0 ] in db_tables :
240+ for column in [
241+ entry [1 ]
242+ for entry in self .db_manager .table_columns (target_table [0 ])[1 ]
243+ ]:
244+ yield Completion (text + " " + column , start_position = - len (text ))
245+
209246 history = self .session .history .get_strings ()
210247 for entry in reversed (list (set (history ))):
211248 if entry .startswith (text ):
@@ -246,8 +283,8 @@ def __init__(
246283 os .remove (history_file_path )
247284 history = FileHistory (history_file_path )
248285 self .completer_session = PromptSession (history = history )
249- self .completer_session .completer = ThreadedCompleter (
250- HistoryCompletions ( self .completer_session , disable_suggestions )
286+ self .completer_session .completer = HistoryCompletions (
287+ self .completer_session , disable_suggestions , self . db_manager
251288 )
252289 self .ai = ai
253290 if self .ai :
@@ -400,13 +437,26 @@ def do_ai(self, line):
400437 def do_schema (self , line ):
401438 """Show database schema"""
402439 success , tables = self .db_manager .schema ()
403- Commands .stdout_data (success , tables , json = self .json , color = self .color )
440+ Commands .stdout_data (
441+ success ,
442+ tables ,
443+ json = self .json ,
444+ color = self .color ,
445+ tbl = "sqlite_schema" ,
446+ db_manager = self .db_manager ,
447+ )
404448
405449 @cli_error_handler
406450 def do_tables (self , line ):
407451 """Show database tables"""
408452 success , tables = self .db_manager .tables ()
409- Commands .stdout_data (success , tables , json = self .json , color = self .color )
453+ Commands .stdout_data (
454+ success ,
455+ tables ,
456+ json = self .json ,
457+ color = self .color ,
458+ headers = table_headers ,
459+ )
410460
411461 @cli_error_handler
412462 def do_columns (self , line ):
@@ -415,7 +465,13 @@ def do_columns(self, line):
415465 columns <table-name>"""
416466 if line :
417467 success , tables = self .db_manager .table_columns (line )
418- Commands .stdout_data (success , tables , json = self .json , color = self .color )
468+ Commands .stdout_data (
469+ success ,
470+ tables ,
471+ json = self .json ,
472+ color = self .color ,
473+ headers = table_column_headers ,
474+ )
419475 else :
420476 click .secho ("Table name is required." , fg = "yellow" )
421477
@@ -459,7 +515,14 @@ def default(self, line: str, prompt_confirmation: bool = False, ai_generated=Fal
459515 if ai_generated :
460516 self .completer_session .history .append_string (sql_statement )
461517 success , response = self .db_manager .execute_sql_command (sql_statement )
462- Commands .stdout_data (success , response , json = self .json , color = self .color )
518+ Commands .stdout_data (
519+ success ,
520+ response ,
521+ json = self .json ,
522+ color = self .color ,
523+ sql_query = sql_statement ,
524+ db_manager = self .db_manager ,
525+ )
463526 self .__end_time = time .time ()
464527
465528 def do_exit (self , line ):
@@ -478,35 +541,81 @@ def stdout_data(
478541 color : str = "cyan" ,
479542 title : str = None ,
480543 json : bool = False ,
544+ headers : list [str ] = None ,
545+ sql_query : str = None ,
546+ db_manager : Sqlite3Manager = None ,
547+ tbl : str = None ,
481548 ):
482- """Stdout info .
549+ """Stdout table data if any .
483550
484551 Args:
485552 data (t.List[t.Tuple[t.Any]]):
486553 color (str, optional):. Defaults to 'cyan'.
487554 title (str, optional): Table title. Defaults to None.
488555 json (bool, optional): Output in Json format. Defaults to False.
556+ sql_query (str, optional): Sql statement used to make the query.
557+ db_manager (Sqlite3Manager, optional)
558+ tbl (str, optional): Table name where * has been sourced from.
489559 """
490560
491561 if not success :
492562 raise data
493563
494564 elif data and data [0 ]:
565+
566+ table = Table (title = title , show_lines = True , show_header = True , style = color )
567+ ref_data = data [0 ]
568+ table .add_column ("Index" , justify = "center" )
569+
570+ def add_headers (header_values : list [str ]):
571+ for header in header_values :
572+ table .add_column (header )
573+
574+ if headers :
575+ add_headers (headers )
576+
577+ elif tbl and db_manager :
578+ # extract column names
579+ success , entries = db_manager .table_columns (tbl )
580+ if success :
581+ add_headers ([entry [1 ] for entry in entries ])
582+
583+ elif sql_query and db_manager :
584+ re_args = (sql_query , re .IGNORECASE )
585+ if re .match (r"^select.*" , * re_args ):
586+ specific_column_names_string = re .findall (
587+ r"^select\s+([\w_,\s]+)\s+from.+" , * re_args
588+ )
589+ if re .match (r"^select\s+\.*" , * re_args ):
590+ table_name = re .findall (r".+from\s+([\w_]+).*" , * re_args )
591+ if table_name :
592+ tbl_name = table_name [0 ]
593+ success , entries = db_manager .table_columns (tbl_name )
594+ if success :
595+ headers = [entry [1 ] for entry in entries ]
596+
597+ elif specific_column_names_string :
598+ headers = re .findall (r"\w+" , specific_column_names_string [0 ])
599+
600+ if headers :
601+ add_headers (headers )
602+ else :
603+ add_headers ([f"Col. { x + 1 } " for x in range (len (ref_data ))])
604+
495605 if json :
496606 entry_items = {}
497607 for index , entry in enumerate (data ):
608+ if headers :
609+ entry = dict (zip (headers , entry ))
610+
498611 entry_items [index ] = entry
499612 rich .print_json (data = entry_items )
500- return
501613
502- table = Table (title = title , show_lines = True , show_header = True , style = color )
503- ref_data = data [0 ]
504- table .add_column ("Index" , justify = "center" )
505- for x in range (len (ref_data )):
506- table .add_column (f"Col. { x + 1 } " )
507- for index , entry in enumerate (data ):
508- table .add_row (* [str (index )] + [str (token ) for token in entry ])
509- rich .print (table )
614+ return
615+ else :
616+ for index , entry in enumerate (data ):
617+ table .add_row (* [str (index )] + [str (token ) for token in entry ])
618+ rich .print (table )
510619
511620 @staticmethod
512621 @click .command ()
@@ -518,7 +627,7 @@ def show_tables(database, json):
518627 """List tables contained in the database"""
519628 db_manager = Sqlite3Manager (database )
520629 success , tables = db_manager .tables ()
521- Commands .stdout_data (success , tables , json = json )
630+ Commands .stdout_data (success , tables , json = json , headers = table_headers )
522631
523632 @staticmethod
524633 @click .command ()
@@ -531,14 +640,16 @@ def show_columns(database, table, json):
531640 """List columns for a particular table"""
532641 db_manager = Sqlite3Manager (database )
533642 success , tables = db_manager .table_columns (table )
534- Commands .stdout_data (success , tables , json = json )
643+ Commands .stdout_data (success , tables , json = json , headers = table_column_headers )
535644
536645 @staticmethod
537646 @click .command ()
538647 @click .argument (
539648 "database" , type = click .Path (exists = True , dir_okay = False , resolve_path = True )
540649 )
541- @click .option ("-s" , "--sql" , multiple = True , help = "Sql statements" , required = True )
650+ @click .option (
651+ "-s" , "--sql" , multiple = True , help = "Sql statement or prompt" , required = True
652+ )
542653 @click .option (
543654 "-i" , "--ai" , is_flag = True , help = "Generate sql statements from prompt by AI"
544655 )
@@ -556,7 +667,9 @@ def execute(database, sql, ai, json, quiet):
556667 for sql_statement in sql if not ai else ai_gen_sql_statements :
557668 success , tables = db_manager .execute_sql_command (sql_statement )
558669 if not quiet :
559- Commands .stdout_data (success , tables , json = json )
670+ Commands .stdout_data (
671+ success , tables , json = json , sql_query = sql_statement
672+ )
560673
561674 @staticmethod
562675 @click .command ()
0 commit comments