Skip to content

Commit 866f919

Browse files
committed
feat: Add logging module
fix: Update pytgpt import fix: Modify history file path handling fix: Correct SQL query syntax fix: Improve AI-generated command processing fix: Add redo functionality fix: Enhance error handling for system commands fix: Improve table display logic fix: Fix regex pattern for SELECT statements fix: Update requirements.txt
1 parent 0ce2987 commit 866f919

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

manager.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
import rich
88
import click
9+
import logging
910
import sqlite3
1011
import getpass
1112
import datetime
@@ -39,6 +40,12 @@
3940

4041
table_headers = ("_", "name", "type", "_", "_", "_")
4142

43+
logging.basicConfig(
44+
format="%(asctime)s - %(levelname)s : %(message)s",
45+
datefmt="%d-%b-%Y %H:%M:%S",
46+
level=logging.INFO,
47+
)
48+
4249

4350
def cli_error_handler(func):
4451
"""Decorator for handling exceptions accordingly"""
@@ -127,7 +134,7 @@ class TextToSql:
127134
def __init__(self, db_manager: Sqlite3Manager):
128135
"""Initializes `TextToSql`"""
129136
try:
130-
import pytgpt.auto as auto
137+
from pytgpt.auto import AUTO
131138
except ImportError:
132139
raise Exception(
133140
"Looks like pytgpt isn't installed. Install it before using TextToSql - "
@@ -136,7 +143,7 @@ def __init__(self, db_manager: Sqlite3Manager):
136143
history_file = Path.home() / ".sqlite-cli-manager-ai-chat-history.txt"
137144
if history_file.exists():
138145
os.remove(history_file)
139-
self.ai = auto.AUTO(filepath=history_file)
146+
self.ai = AUTO(filepath=str(history_file))
140147
assert isinstance(
141148
db_manager, Sqlite3Manager
142149
), f"db_manager must be an instance of {Sqlite3Manager} not {type(db_manager)}"
@@ -170,7 +177,7 @@ def context_prompt(self) -> str:
170177
"""
171178
\n
172179
For example:
173-
User: List top 10 entries in the Linux table where distro contains letter 'a'
180+
User: List first 10 entries in the Linux table where distro contains letter 'a'
174181
LLM : {SELECT * FROM Linux WHERE distro LIKE '%a%';}
175182
176183
User : Remove entries from table Linux whose id is greater than 10.
@@ -179,7 +186,7 @@ def context_prompt(self) -> str:
179186
If the user's request IS UNDOUBTEDBLY INCOMPLETE, seek clarification.
180187
For example:
181188
User: Add column to Linux table.
182-
LLM: What kind of data will be stored in the column and suggest column name if possible?
189+
LLM: Describe the data to be stored in the column and suggest column name if possible?
183190
User: The column will be storing maintainance status of the linux distros.
184191
LLM: {ALTER TABLE Linux ADD COLUMN is_maintained BOOLEAN;}
185192
@@ -211,7 +218,8 @@ def generate(self, prompt: str):
211218
"""Main method"""
212219
self.ai.intro = self.context_prompt
213220
assert prompt, f"Prompt cannot be null!"
214-
return self.process_response(self.ai.chat(prompt))
221+
ai_response = self.ai.chat(prompt)
222+
return self.process_response(ai_response)
215223

216224

217225
class HistoryCompletions(Completer):
@@ -475,12 +483,16 @@ def do_columns(self, line):
475483
else:
476484
click.secho("Table name is required.", fg="yellow")
477485

486+
def do_redo(self, line):
487+
"""Re-run previous sql command"""
488+
history = self.completer_session.history.get_strings()
489+
return self.default(history[-2], prompt_confirmation=True)
490+
478491
@cli_error_handler
479492
def default(self, line: str, prompt_confirmation: bool = False, ai_generated=False):
480493
"""Run sql statemnt against database"""
481-
482494
if line.startswith("./"):
483-
self.do_sys(line[2:])
495+
self.do_sys(line[2:].strip())
484496
return
485497

486498
elif line.startswith("!"):
@@ -495,20 +507,22 @@ def default(self, line: str, prompt_confirmation: bool = False, ai_generated=Fal
495507
return
496508

497509
elif line.startswith("/sql"):
498-
line = [line[4:]]
510+
line = [line[4:].strip()]
499511
elif line.startswith("/ai"):
500-
line = TextToSql(self.db_manager).generate(line[3:])
501-
ai_generated = prompt_confirmation = True
512+
line = TextToSql(self.db_manager).generate(line[3:].strip())
513+
prompt_confirmation = True
514+
ai_generated = True
502515
elif self.ai:
503516
line = self.text_to_sql.generate(line)
504-
ai_generated = prompt_confirmation = True
517+
ai_generated = True
518+
prompt_confirmation = True
505519
else:
506520
line = [line]
507521
self.__start_time = time.time()
508522
for sql_statement in line:
509523
if (
510-
not self.yes
511-
and prompt_confirmation
524+
prompt_confirmation
525+
and not self.yes
512526
and not click.confirm("[Exc] - " + sql_statement)
513527
):
514528
continue
@@ -568,8 +582,13 @@ def stdout_data(
568582
table.add_column("Index", justify="center")
569583

570584
def add_headers(header_values: list[str]):
571-
for header in header_values:
572-
table.add_column(header)
585+
if data and len(header_values) == len(data[0]):
586+
for header in header_values:
587+
table.add_column(header)
588+
else:
589+
logging.debug(
590+
f"No data to be displayed or length of data and headers don't match."
591+
)
573592

574593
if headers:
575594
add_headers(headers)
@@ -586,7 +605,7 @@ def add_headers(header_values: list[str]):
586605
specific_column_names_string = re.findall(
587606
r"^select\s+([\w_,\s]+)\s+from.+", *re_args
588607
)
589-
if re.match(r"^select\s+\.*", *re_args):
608+
if re.match(r"^select\s+\*.*", *re_args):
590609
table_name = re.findall(r".+from\s+([\w_]+).*", *re_args)
591610
if table_name:
592611
tbl_name = table_name[0]
@@ -615,7 +634,7 @@ def add_headers(header_values: list[str]):
615634
else:
616635
for index, entry in enumerate(data):
617636
table.add_row(*[str(index)] + [str(token) for token in entry])
618-
rich.print(table)
637+
rich.print(table)
619638

620639
@staticmethod
621640
@click.command()

requirements.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)