66import time
77import rich
88import click
9+ import logging
910import sqlite3
1011import getpass
1112import datetime
3940
4041table_headers = ("_" , "name" , "type" , "_" , "_" , "_" )
4142
43+ logging .basicConfig (
44+ format = "%(asctime)s - %(levelname)s : %(message)s" ,
45+ datefmt = "%d-%b-%Y %H:%M:%S" ,
46+ level = logging .INFO ,
47+ )
48+
4249
4350def cli_error_handler (func ):
4451 """Decorator for handling exceptions accordingly"""
@@ -127,7 +134,7 @@ class TextToSql:
127134 def __init__ (self , db_manager : Sqlite3Manager ):
128135 """Initializes `TextToSql`"""
129136 try :
130- import pytgpt .auto as auto
137+ from pytgpt .auto import AUTO
131138 except ImportError :
132139 raise Exception (
133140 "Looks like pytgpt isn't installed. Install it before using TextToSql - "
@@ -136,7 +143,7 @@ def __init__(self, db_manager: Sqlite3Manager):
136143 history_file = Path .home () / ".sqlite-cli-manager-ai-chat-history.txt"
137144 if history_file .exists ():
138145 os .remove (history_file )
139- self .ai = auto . AUTO (filepath = history_file )
146+ self .ai = AUTO (filepath = str ( history_file ) )
140147 assert isinstance (
141148 db_manager , Sqlite3Manager
142149 ), f"db_manager must be an instance of { Sqlite3Manager } not { type (db_manager )} "
@@ -170,7 +177,7 @@ def context_prompt(self) -> str:
170177 """
171178 \n
172179 For example:
173- User: List top 10 entries in the Linux table where distro contains letter 'a'
180+ User: List first 10 entries in the Linux table where distro contains letter 'a'
174181 LLM : {SELECT * FROM Linux WHERE distro LIKE '%a%';}
175182
176183 User : Remove entries from table Linux whose id is greater than 10.
@@ -179,7 +186,7 @@ def context_prompt(self) -> str:
179186 If the user's request IS UNDOUBTEDBLY INCOMPLETE, seek clarification.
180187 For example:
181188 User: Add column to Linux table.
182- LLM: What kind of data will be stored in the column and suggest column name if possible?
189+ LLM: Describe the data to be stored in the column and suggest column name if possible?
183190 User: The column will be storing maintainance status of the linux distros.
184191 LLM: {ALTER TABLE Linux ADD COLUMN is_maintained BOOLEAN;}
185192
@@ -211,7 +218,8 @@ def generate(self, prompt: str):
211218 """Main method"""
212219 self .ai .intro = self .context_prompt
213220 assert prompt , f"Prompt cannot be null!"
214- return self .process_response (self .ai .chat (prompt ))
221+ ai_response = self .ai .chat (prompt )
222+ return self .process_response (ai_response )
215223
216224
217225class HistoryCompletions (Completer ):
@@ -475,12 +483,16 @@ def do_columns(self, line):
475483 else :
476484 click .secho ("Table name is required." , fg = "yellow" )
477485
486+ def do_redo (self , line ):
487+ """Re-run previous sql command"""
488+ history = self .completer_session .history .get_strings ()
489+ return self .default (history [- 2 ], prompt_confirmation = True )
490+
478491 @cli_error_handler
479492 def default (self , line : str , prompt_confirmation : bool = False , ai_generated = False ):
480493 """Run sql statemnt against database"""
481-
482494 if line .startswith ("./" ):
483- self .do_sys (line [2 :])
495+ self .do_sys (line [2 :]. strip () )
484496 return
485497
486498 elif line .startswith ("!" ):
@@ -495,20 +507,22 @@ def default(self, line: str, prompt_confirmation: bool = False, ai_generated=Fal
495507 return
496508
497509 elif line .startswith ("/sql" ):
498- line = [line [4 :]]
510+ line = [line [4 :]. strip () ]
499511 elif line .startswith ("/ai" ):
500- line = TextToSql (self .db_manager ).generate (line [3 :])
501- ai_generated = prompt_confirmation = True
512+ line = TextToSql (self .db_manager ).generate (line [3 :].strip ())
513+ prompt_confirmation = True
514+ ai_generated = True
502515 elif self .ai :
503516 line = self .text_to_sql .generate (line )
504- ai_generated = prompt_confirmation = True
517+ ai_generated = True
518+ prompt_confirmation = True
505519 else :
506520 line = [line ]
507521 self .__start_time = time .time ()
508522 for sql_statement in line :
509523 if (
510- not self . yes
511- and prompt_confirmation
524+ prompt_confirmation
525+ and not self . yes
512526 and not click .confirm ("[Exc] - " + sql_statement )
513527 ):
514528 continue
@@ -568,8 +582,13 @@ def stdout_data(
568582 table .add_column ("Index" , justify = "center" )
569583
570584 def add_headers (header_values : list [str ]):
571- for header in header_values :
572- table .add_column (header )
585+ if data and len (header_values ) == len (data [0 ]):
586+ for header in header_values :
587+ table .add_column (header )
588+ else :
589+ logging .debug (
590+ f"No data to be displayed or length of data and headers don't match."
591+ )
573592
574593 if headers :
575594 add_headers (headers )
@@ -586,7 +605,7 @@ def add_headers(header_values: list[str]):
586605 specific_column_names_string = re .findall (
587606 r"^select\s+([\w_,\s]+)\s+from.+" , * re_args
588607 )
589- if re .match (r"^select\s+\.*" , * re_args ):
608+ if re .match (r"^select\s+\* .*" , * re_args ):
590609 table_name = re .findall (r".+from\s+([\w_]+).*" , * re_args )
591610 if table_name :
592611 tbl_name = table_name [0 ]
@@ -615,7 +634,7 @@ def add_headers(header_values: list[str]):
615634 else :
616635 for index , entry in enumerate (data ):
617636 table .add_row (* [str (index )] + [str (token ) for token in entry ])
618- rich .print (table )
637+ rich .print (table )
619638
620639 @staticmethod
621640 @click .command ()
0 commit comments