Skip to content

Commit 12203a5

Browse files
committed
feat: Add AI-powered SQL generation feature
fix: Improve error handling for AI integration patch: Update dependencies and add python-tgpt
1 parent 053bb76 commit 12203a5

File tree

2 files changed

+147
-8
lines changed

2 files changed

+147
-8
lines changed

manager.py

Lines changed: 145 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/python3
22
import os
3+
import re
34
import cmd
45
import sys
56
import time
@@ -12,8 +13,12 @@
1213
from pathlib import Path
1314
from colorama import Fore
1415
from functools import wraps
16+
17+
# Rich
1518
from rich.table import Table
1619
from rich.console import Console
20+
from rich.markdown import Markdown
21+
from rich.console import Console
1722

1823
# prompt toolkit
1924
from prompt_toolkit import PromptSession
@@ -73,14 +78,18 @@ def execute_sql_command(
7378
finally:
7479
cursor.close()
7580

76-
def tables(self):
81+
def tables(self, tables_only: bool = False):
7782
"""List tables available"""
7883
return self.execute_sql_command("PRAGMA table_list;")
7984

8085
def table_columns(self, table: str):
8186
"""List table columns and their metadata"""
8287
return self.execute_sql_command(f"PRAGMA table_info({table});")
8388

89+
def schema(self):
90+
"""Sqlite schema contents"""
91+
return self.execute_sql_command("SELECT * FROM sqlite_schema;")
92+
8493
def commit(self):
8594
"""Commit changes"""
8695
if self.db_connection and not self.auto_commit:
@@ -98,6 +107,95 @@ def __exit__(self) -> t.NoReturn:
98107
self.db_connection.close()
99108

100109

110+
class TextToSql:
111+
"""Generate SQL Statement based on given prompt"""
112+
113+
def __init__(self, db_manager: Sqlite3Manager):
114+
"""Initializes `TextToSql`"""
115+
try:
116+
import pytgpt.auto as auto
117+
except ImportError:
118+
raise Exception(
119+
"Looks like pytgpt isn't installed. Reistall it before using TextToSql - "
120+
'"pip install python-tgpt"'
121+
)
122+
self.ai = auto.AUTO(update_file=False)
123+
assert isinstance(
124+
db_manager, Sqlite3Manager
125+
), f"db_manager must be an instance of {Sqlite3Manager} not {type(db_manager)}"
126+
self.db_manager = db_manager
127+
self.sql_pattern = r"\{([\w\W]*)\}"
128+
129+
@property
130+
def context_prompt(self) -> str:
131+
_, table_schema = self.db_manager.execute_sql_command(
132+
"""SELECT tbl_name, sql FROM sqlite_schema WHERE type='table'
133+
AND NOT tbl_name LIKE '%sqlite%';
134+
"""
135+
)
136+
table_schema_text = "\n".join(
137+
[tbl_schema[0] + " - " + tbl_schema[1] for tbl_schema in table_schema]
138+
)
139+
prompt = (
140+
(
141+
"""You're going to act like TEXT to SQL translater.
142+
Action to be performed on the sqlite3 database will be provided and then
143+
you will generate a complete SQL statement for accomplishing the same.
144+
Enclose the sql statement in curly braces '{}'. DO NOT ADD ANY OTHER TEXT
145+
EXCEPT when seeking clarification or confirmation.
146+
147+
Given below are the database table names and the SQL statements used to create them:
148+
\n"""
149+
)
150+
+ "\n "
151+
+ table_schema_text
152+
+ (
153+
"""
154+
\n
155+
For example:
156+
User: List top 10 entries in the Linux table where distro contains letter 'a'LLM : {SELECT * FROM Linux WHERE distro LIKE '%a%';}
157+
158+
User : Remove entries from table Linux whose id is greater than 10.
159+
LLLM : {DELETE * FROM Linux WHERE id > 10;}
160+
161+
If the user's request IS UNDOUBTEDBLY INCOMPLETE, seek clarification.
162+
For example:
163+
User: Add column to Linux table.
164+
LLM: What kind of data will be stored in the column and suggest column name if possible?
165+
User: The column will be storing maintainance status of the linux distros.
166+
LLM: {ALTER TABLE Linux ADD COLUMN is_maintained BOOLEAN;}
167+
168+
If the user's request can be disastrous then seek clarification or confirmation accordingly.
169+
These actions might include DELETE, ALTER and DROP.
170+
"""
171+
)
172+
)
173+
174+
return prompt
175+
176+
def process_response(self, response: str) -> list[str]:
177+
"""Tries to extract the sql statement from ai response
178+
179+
Args:
180+
response (str): ai response
181+
"""
182+
if response.startswith("{") and not response.endswith("}"):
183+
response += "}"
184+
185+
sql_statements = re.findall(self.sql_pattern, response)
186+
if sql_statements:
187+
return [sql for sql in re.split(";", sql_statements[0]) if sql]
188+
else:
189+
Console().print(Markdown(response))
190+
return []
191+
192+
def generate(self, prompt: str):
193+
"""Main method"""
194+
self.ai.intro = self.context_prompt
195+
assert prompt, f"Prompt cannot be null!"
196+
return self.process_response(self.ai.chat(prompt))
197+
198+
101199
class HistoryCompletions(Completer):
102200
def __init__(self, session, disable_suggestions):
103201
self.session: PromptSession = session
@@ -130,6 +228,7 @@ def __init__(
130228
new_history_thread,
131229
json,
132230
color,
231+
ai,
133232
):
134233
super().__init__()
135234
self.__start_time = time.time()
@@ -146,6 +245,9 @@ def __init__(
146245
self.completer_session.completer = ThreadedCompleter(
147246
HistoryCompletions(self.completer_session, disable_suggestions)
148247
)
248+
self.ai = ai
249+
if self.ai:
250+
self.text_to_sql = TextToSql(self.db_manager)
149251

150252
@property
151253
def prompt(self):
@@ -282,6 +384,20 @@ def do_sys(self, line):
282384
"""
283385
os.system(line)
284386

387+
def do_sql(self, line):
388+
"""Execute sql statement"""
389+
self.default("/sql " + line)
390+
391+
def do_ai(self, line):
392+
"""Generate sql statements with AI and execute"""
393+
self.default("/ai " + line)
394+
395+
@cli_error_handler
396+
def do_schema(self, line):
397+
"""Show database schema"""
398+
success, tables = self.db_manager.schema()
399+
Commands.stdout_data(success, tables, json=self.json, color=self.color)
400+
285401
@cli_error_handler
286402
def do_tables(self, line):
287403
"""Show database tables"""
@@ -304,12 +420,20 @@ def default(self, line):
304420
"""Run sql statemnt against database"""
305421
if line.startswith("./"):
306422
self.do_sys(line[2:])
307-
423+
return
424+
elif line.startswith("/sql"):
425+
line = [line[4:]]
426+
elif line.startswith("/ai"):
427+
line = TextToSql(self.db_manager).generate(line[3:])
428+
elif self.ai:
429+
line = self.text_to_sql.generate(line)
308430
else:
309-
self.__start_time = time.time()
310-
success, response = self.db_manager.execute_sql_command(line)
431+
line = [line]
432+
self.__start_time = time.time()
433+
for sql_statement in line:
434+
success, response = self.db_manager.execute_sql_command(sql_statement)
311435
Commands.stdout_data(success, response, json=self.json, color=self.color)
312-
self.__end_time = time.time()
436+
self.__end_time = time.time()
313437

314438
def do_exit(self, line):
315439
"""Quit this program"""
@@ -388,12 +512,21 @@ def show_columns(database, table, json):
388512
"database", type=click.Path(exists=True, dir_okay=False, resolve_path=True)
389513
)
390514
@click.option("-s", "--sql", multiple=True, help="Sql statements", required=True)
515+
@click.option(
516+
"-i", "--ai", is_flag=True, help="Generate sql statements from prompt by AI"
517+
)
391518
@click.option("-j", "--json", is_flag=True, help="Stdout results in json format")
392519
@click.option("-q", "--quiet", is_flag=True, help="Do not stdout results")
393-
def execute(database, sql, json, quiet):
520+
def execute(database, sql, ai, json, quiet):
394521
"""Run sql statements against database [AUTO-COMMITS]"""
395522
db_manager = Sqlite3Manager(database, auto_commit=True)
396-
for sql_statement in sql:
523+
if ai:
524+
text_to_sql = TextToSql(db_manager)
525+
ai_gen_sql_statements = []
526+
for prompt in sql:
527+
ai_gen_sql_statements.extend(text_to_sql.generate(prompt))
528+
529+
for sql_statement in sql if not ai else ai_gen_sql_statements:
397530
success, tables = db_manager.execute_sql_command(sql_statement)
398531
if not quiet:
399532
Commands.stdout_data(success, tables, json=json)
@@ -409,6 +542,9 @@ def execute(database, sql, json, quiet):
409542
)
410543
@click.option("-j", "--json", help="Stdout results in json format", is_flag=True)
411544
@click.option("-a", "--auto-commit", is_flag=True, help="Enable auto-commit")
545+
@click.option(
546+
"-i", "--ai", is_flag=True, help="Generate sql statements from prompt by AI"
547+
)
412548
@click.option(
413549
"-C",
414550
"--disable-coloring",
@@ -429,6 +565,7 @@ def interactive(
429565
color,
430566
json,
431567
auto_commit,
568+
ai,
432569
disable_coloring,
433570
disable_suggestions,
434571
new_history_thread,
@@ -442,6 +579,7 @@ def interactive(
442579
new_history_thread=new_history_thread,
443580
json=json,
444581
color=color,
582+
ai=ai,
445583
)
446584
main.cmdloop()
447585

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
click==8.1.3
22
rich==13.3.4
33
colorama==0.4.6
4-
prompt-toolkit==3.0.48
4+
prompt-toolkit==3.0.48
5+
python-tgpt==0.7.7

0 commit comments

Comments
 (0)