Skip to content

Commit faeccb2

Browse files
authored
Merge pull request #1229 from dbcli/amjith/add-llm-support
Implement \llm command.
2 parents b3385a1 + 9aa1f50 commit faeccb2

File tree

12 files changed

+614
-56
lines changed

12 files changed

+614
-56
lines changed

mycli/main.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command
5252
from mycli.packages.parseutils import is_destructive, is_dropping_database
5353
from mycli.packages.prompt_utils import confirm, confirm_destructive_query
54-
from mycli.packages.special.favoritequeries import FavoriteQueries
5554
from mycli.packages.special.main import ArgType
5655
from mycli.packages.tabular_output import sql_format
5756
from mycli.packages.toolkit.history import FileHistoryWithTimestamp
@@ -128,8 +127,6 @@ def __init__(
128127
special.set_timing_enabled(c["main"].as_bool("timing"))
129128
self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0)
130129

131-
FavoriteQueries.instance = FavoriteQueries.from_config(self.config)
132-
133130
self.dsn_alias = None
134131
self.main_formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
135132
self.redirect_formatter = TabularOutputFormatter(format_name=c["main"].get("redirect_format", "csv"))
@@ -681,6 +678,47 @@ def get_continuation(width, *_):
681678
def show_suggestion_tip():
682679
return iterations < 2
683680

681+
def output_res(res, start):
682+
result_count = 0
683+
mutating = False
684+
for title, cur, headers, status in res:
685+
logger.debug("headers: %r", headers)
686+
logger.debug("rows: %r", cur)
687+
logger.debug("status: %r", status)
688+
threshold = 1000
689+
if is_select(status) and cur and cur.rowcount > threshold:
690+
self.echo(
691+
"The result set has more than {} rows.".format(threshold),
692+
fg="red",
693+
)
694+
if not confirm("Do you want to continue?"):
695+
self.echo("Aborted!", err=True, fg="red")
696+
break
697+
698+
if self.auto_vertical_output:
699+
max_width = self.prompt_app.output.get_size().columns
700+
else:
701+
max_width = None
702+
703+
formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width)
704+
705+
t = time() - start
706+
try:
707+
if result_count > 0:
708+
self.echo("")
709+
try:
710+
self.output(formatted, status)
711+
except KeyboardInterrupt:
712+
pass
713+
self.echo("Time: %0.03fs" % t)
714+
except KeyboardInterrupt:
715+
pass
716+
717+
start = time()
718+
result_count += 1
719+
mutating = mutating or is_mutating(status)
720+
return mutating
721+
684722
def one_iteration(text=None):
685723
if text is None:
686724
try:
@@ -707,6 +745,27 @@ def one_iteration(text=None):
707745
logger.error("traceback: %r", traceback.format_exc())
708746
self.echo(str(e), err=True, fg="red")
709747
return
748+
# LLM command support
749+
while special.is_llm_command(text):
750+
try:
751+
start = time()
752+
cur = sqlexecute.conn.cursor()
753+
context, sql, duration = special.handle_llm(text, cur)
754+
if context:
755+
click.echo("LLM Response:")
756+
click.echo(context)
757+
click.echo("---")
758+
click.echo(f"Time: {duration:.2f} seconds")
759+
text = self.prompt_app.prompt(default=sql)
760+
except KeyboardInterrupt:
761+
return
762+
except special.FinishIteration as e:
763+
return output_res(e.results, start) if e.results else None
764+
except RuntimeError as e:
765+
logger.error("sql: %r, error: %r", text, e)
766+
logger.error("traceback: %r", traceback.format_exc())
767+
self.echo(str(e), err=True, fg="red")
768+
return
710769

711770
if not text.strip():
712771
return

mycli/packages/completion_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def suggest_special(text: str) -> list[dict[str, Any]]:
107107
]
108108
elif cmd in ["\\.", "source"]:
109109
return [{"type": "file_name"}]
110+
if cmd in ["\\llm", "\\ai"]:
111+
return [{"type": "llm"}]
110112

111113
return [{"type": "keyword"}, {"type": "special"}]
112114

mycli/packages/special/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ def export(defn: Callable):
1515
from mycli.packages.special import (
1616
dbcommands, # noqa: E402 F401
1717
iocommands, # noqa: E402 F401
18+
llm, # noqa: E402 F401
1819
)

mycli/packages/special/iocommands.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Any, Generator
1111

1212
import click
13+
from configobj import ConfigObj
1314
from pymysql.cursors import Cursor
1415
import pyperclip
1516
import sqlparse
@@ -36,6 +37,13 @@
3637
'stdout_mode': None,
3738
}
3839
delimiter_command = DelimiterCommand()
40+
favoritequeries = FavoriteQueries(ConfigObj())
41+
42+
43+
@export
44+
def set_favorite_queries(config):
45+
global favoritequeries
46+
favoritequeries = FavoriteQueries(config)
3947

4048

4149
@export
@@ -261,7 +269,7 @@ def execute_favorite_query(cur: Cursor, arg: str, **_) -> Generator[tuple, None,
261269
name, _separator, arg_str = arg.partition(" ")
262270
args = shlex.split(arg_str)
263271

264-
query = FavoriteQueries.instance.get(name)
272+
query = favoritequeries.get(name)
265273
if query is None:
266274
message = "No favorite query: %s" % (name)
267275
yield (None, None, None, message)
@@ -286,10 +294,10 @@ def list_favorite_queries() -> list[tuple]:
286294
Returns (title, rows, headers, status)"""
287295

288296
headers = ["Name", "Query"]
289-
rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()]
297+
rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]
290298

291299
if not rows:
292-
status = "\nNo favorite queries found." + FavoriteQueries.instance.usage
300+
status = "\nNo favorite queries found." + favoritequeries.usage
293301
else:
294302
status = ""
295303
return [("", rows, headers, status)]
@@ -316,7 +324,7 @@ def save_favorite_query(arg: str, **_) -> list[tuple]:
316324
"""Save a new favorite query.
317325
Returns (title, rows, headers, status)"""
318326

319-
usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage
327+
usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
320328
if not arg:
321329
return [(None, None, None, usage)]
322330

@@ -326,18 +334,18 @@ def save_favorite_query(arg: str, **_) -> list[tuple]:
326334
if (not name) or (not query):
327335
return [(None, None, None, usage + "Err: Both name and query are required.")]
328336

329-
FavoriteQueries.instance.save(name, query)
337+
favoritequeries.save(name, query)
330338
return [(None, None, None, "Saved.")]
331339

332340

333341
@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
334342
def delete_favorite_query(arg: str, **_) -> list[tuple]:
335343
"""Delete an existing favorite query."""
336-
usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage
344+
usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
337345
if not arg:
338346
return [(None, None, None, usage)]
339347

340-
status = FavoriteQueries.instance.delete(arg)
348+
status = favoritequeries.delete(arg)
341349

342350
return [(None, None, None, status)]
343351

0 commit comments

Comments
 (0)