def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True):
    """Run module ``cmd`` in-process, the way ``python -m cmd`` would.

    ``sys.argv`` is temporarily replaced with ``[cmd, *args]`` and is always
    restored before this function returns or raises.

    Args:
        cmd: Name of the module handed to ``runpy.run_module``.
        *args: Command-line arguments for the module. Each one is converted
            with ``str()`` because ``sys.argv`` must only contain strings.
        capture_output: When true, stdout and stderr are redirected into a
            buffer whose contents are returned instead of being printed.
        restart_cli: When true and the command succeeded, replace the current
            process with a fresh copy of itself via ``os.execv``.
        raise_exception: When true, a failing command raises ``RuntimeError``
            instead of returning its non-zero status.

    Returns:
        A ``(code, output)`` tuple. ``code`` is an integer exit status and
        ``output`` is the captured text, or ``""`` when nothing was captured.

    Raises:
        RuntimeError: The command failed and ``raise_exception`` is true. The
            message is the captured output when there is any, otherwise a
            short description of the failure.
    """
    original_exe = sys.executable
    original_args = sys.argv
    buffer = io.StringIO() if capture_output else None
    code = 0
    failure = None  # (fallback message, causing exception) once the command fails
    try:
        sys.argv = [cmd, *(str(arg) for arg in args)]
        with contextlib.ExitStack() as redirect:
            if buffer is not None:
                redirect.enter_context(contextlib.redirect_stdout(buffer))
                redirect.enter_context(contextlib.redirect_stderr(buffer))
            try:
                run_module(cmd, run_name="__main__")
            except SystemExit as e:
                # A bare sys.exit() carries None and means success; a
                # non-integer payload (e.g. an error string) means failure.
                if e.code is None:
                    code = 0
                elif isinstance(e.code, int):
                    code = e.code
                else:
                    code = 1
                if code != 0:
                    failure = (f"Command {cmd} failed with exit code {e.code}.", e)
            except Exception as e:
                code = 1
                failure = (f"Command {cmd} failed: {e}", e)

        captured = buffer.getvalue() if buffer is not None else ""
        if failure is not None and raise_exception:
            message, cause = failure
            # Prefer what the command printed; fall back to the description so
            # the error is never an empty string.
            raise RuntimeError(captured if captured.strip() else message) from cause
        if restart_cli and code == 0:
            os.execv(original_exe, [original_exe] + original_args)
        return code, captured
    finally:
        sys.argv = original_args
@export
def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
    r"""Execute an ``\llm`` special command.

    The argument is classified into one of three paths:

    * a direct ``llm`` invocation whose output is captured (``-c``),
    * an ``llm`` subcommand run with its output left on the terminal
      (``install``/``uninstall``, any name in ``LLM_CLI_COMMANDS``, ``--help``),
    * a question answered with database context via ``sql_using_llm``.

    Args:
        text: The full special-command text, parsed with
            ``parse_special_command`` to obtain verbosity and argument.
        cur: Database cursor forwarded to ``sql_using_llm``.

    Returns:
        A ``(context, sql, duration)`` tuple: the explanatory text (possibly
        empty, depending on verbosity), the SQL extracted from the response,
        and the elapsed time in seconds.

    Raises:
        FinishIteration: With ``USAGE`` when no argument was given; with the
            raw response when a captured direct call contains no fenced SQL;
            with ``None`` after running a subcommand whose output is not
            captured.
        RuntimeError: Wraps any failure on the database-context path.
    """
    _, verbosity, arg = parse_special_command(text)
    if not arg.strip():
        raise FinishIteration(USAGE)

    try:
        parts = shlex.split(arg)
    except ValueError:
        # Free-form questions often contain a lone apostrophe ("what's ...").
        # The tokens are only used to classify the request, so a plain
        # whitespace split is a safe fallback.
        parts = arg.split()

    restart = False
    if "-c" in parts:
        capture_output = True
        use_context = False
    elif "prompt" in parts:
        capture_output = True
        use_context = True
    elif "install" in parts or "uninstall" in parts:
        capture_output = False
        use_context = False
        # Restart afterwards (see run_external_cmd's restart_cli).
        restart = True
    elif parts and parts[0] in LLM_CLI_COMMANDS:
        capture_output = False
        use_context = False
    elif parts and parts[0] == "--help":
        capture_output = False
        use_context = False
    else:
        capture_output = True
        use_context = True

    if not use_context:
        if not capture_output:
            run_external_cmd("llm", *parts, restart_cli=restart)
            raise FinishIteration(None)

        click.echo("Calling llm command")
        start = time()
        _, result = run_external_cmd("llm", *parts, capture_output=capture_output)
        end = time()
        match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
        if not match:
            # Nothing to execute: hand the raw answer back to the caller.
            raise FinishIteration(result)
        sql = match.group(1).strip()
        # Only the verbose form returns the explanation along with the SQL.
        return (result if verbosity == Verbosity.VERBOSE else "", sql, end - start)

    try:
        ensure_pgspecial_template()
        start = time()
        context, sql = sql_using_llm(cur=cur, question=arg)
        end = time()
        if verbosity == Verbosity.SUCCINCT:
            context = ""
        return (context, sql, end - start)
    except Exception as e:
        raise RuntimeError(e) from e
@export
def parse_special_command(sql):
    r"""Split a special command into its name, verbosity and argument.

    The text up to the first space is the command; everything after it is the
    argument. A ``+`` in the command selects ``Verbosity.VERBOSE``; otherwise
    a trailing ``-`` selects ``Verbosity.SUCCINCT``; otherwise the verbosity
    is ``Verbosity.NORMAL``.

    Args:
        sql: The raw command text, e.g. ``\d+ my_table``.

    Returns:
        A ``(command, verbosity, arg)`` tuple where ``command`` has its
        verbosity modifiers removed and ``arg`` is stripped of surrounding
        whitespace.
    """
    command, _, arg = sql.partition(" ")
    command = command.strip()

    if "+" in command:
        verbosity = Verbosity.VERBOSE
    elif command.endswith("-"):
        # Only a trailing "-" is a modifier; a hyphen inside the command name
        # is part of the name and must not change the verbosity.
        verbosity = Verbosity.SUCCINCT
    else:
        verbosity = Verbosity.NORMAL

    # Every "+" is removed so a modifier placed mid-command still resolves to
    # the base command name; "-" is only removed from the end.
    command = command.replace("+", "").rstrip("-")
    return (command, verbosity, arg.strip())
@patch("pgspecial.llm.llm")
@patch("pgspecial.llm.run_external_cmd")
def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor):
    r"""
    \llm -c with a fenced SQL block in the response returns the extracted SQL.
    """
    # Return text containing a fenced SQL block
    sql_text = "SELECT * FROM users;"
    fenced = f"Here you go:\n```sql\n{sql_text}\n```"
    mock_run_cmd.return_value = (0, fenced)
    test_text = r"\llm -c 'Rewrite SQL'"
    result, sql, duration = handle_llm(test_text, executor)
    # Without verbose, result is empty, sql extracted
    assert sql == sql_text
    assert result == ""
    assert isinstance(duration, float)


@patch("pgspecial.llm.llm")
@patch("pgspecial.llm.run_external_cmd")
def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
    r"""
    \llm --help runs the llm CLI without capturing output and without restart.
    """
    test_text = r"\llm --help"
    with pytest.raises(FinishIteration) as exc_info:
        handle_llm(test_text, executor)
    mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
    assert exc_info.value.args[0] is None


@patch("pgspecial.llm.llm")
@patch("pgspecial.llm.run_external_cmd")
def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
    r"""
    \llm install <plugin> runs the llm CLI and requests a restart afterwards.
    """
    test_text = r"\llm install openai"
    with pytest.raises(FinishIteration) as exc_info:
        handle_llm(test_text, executor)
    mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
    assert exc_info.value.args[0] is None


@patch("pgspecial.llm.llm")
@patch("pgspecial.llm.ensure_pgspecial_template")
@patch("pgspecial.llm.sql_using_llm")
def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
    r"""
    \llm prompt 'question' should use template and call sql_using_llm
    """
    mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
    test_text = r"\llm prompt 'Test?'"
    context, sql, duration = handle_llm(test_text, executor)
    mock_ensure_template.assert_called_once()
    mock_sql_using_llm.assert_called()
    assert context == "CTX"
    assert sql == "SELECT 1;"
    assert isinstance(duration, float)


@patch("pgspecial.llm.llm")
@patch("pgspecial.llm.ensure_pgspecial_template")
@patch("pgspecial.llm.sql_using_llm")
def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
    r"""
    \llm 'question' treats as prompt and returns SQL
    """
    mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
    test_text = r"\llm 'Top 10?'"
    context, sql, duration = handle_llm(test_text, executor)
    mock_ensure_template.assert_called_once()
    mock_sql_using_llm.assert_called()
    assert context == "CTX2"
    assert sql == "SELECT 2;"
    assert isinstance(duration, float)


@patch("pgspecial.llm.llm")
@patch("pgspecial.llm.ensure_pgspecial_template")
@patch("pgspecial.llm.sql_using_llm")
def test_llm_command_question_succinct(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
    r"""
    \llm- returns succinct (empty) context and SQL
    """
    mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
    test_text = r"\llm- 'Succinct?'"
    context, sql, duration = handle_llm(test_text, executor)
    assert context == ""
    assert sql == "SELECT 42;"
    assert isinstance(duration, float)


def test_is_llm_command():
    # Valid llm command variants
    for cmd in ["\\llm", "\\ai"]:
        assert is_llm_command(cmd + " 'x'")
    # Invalid commands
    assert not is_llm_command("select * from table;")


def test_sql_using_llm_no_connection():
    # Should error if no database cursor provided
    with pytest.raises(RuntimeError) as exc_info:
        sql_using_llm(None, question="test")
    assert "Connect to a database" in str(exc_info.value)


@patch("pgspecial.llm.run_external_cmd")
def test_sql_using_llm_success(mock_run_cmd):
    # Dummy cursor simulating database schema and sample data for PostgreSQL
    class DummyCursor:
        def __init__(self):
            self._last = []

        def execute(self, query):
            # Normalize whitespace and case so the multi-line queries can be
            # matched by substring.
            q = " ".join(query.split()).lower()
            if "from information_schema.columns" in q:
                self._last = [
                    ("public", "table1", "col1 integer, col2 text"),
                    ("public", "table2", "colA character varying"),
                ]
            elif "from information_schema.tables" in q:
                self._last = [("public", "table1"), ("public", "table2")]
            elif q.startswith('select * from "public"."table'):
                self.description = [("col1", None), ("col2", None)]
                self._row = (1, "abc")

        def fetchall(self):
            return getattr(self, "_last", [])

        def fetchone(self):
            return getattr(self, "_row", None)

    dummy_cur = DummyCursor()
    # Simulate llm CLI returning a fenced SQL result
    sql_text = "SELECT 1, 'abc';"
    fenced = f"Note\n```sql\n{sql_text}\n```"
    mock_run_cmd.return_value = (0, fenced)
    result, sql = sql_using_llm(dummy_cur, question="dummy")
    assert result == fenced
    assert sql == sql_text


# Test handle_llm supports alias prefixes without args.
# The prefixes are raw strings with a single backslash so they match what a
# user actually types (\llm, \ai).
@pytest.mark.parametrize("prefix", [r"\llm", r"\ai"])
def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
    from pgspecial import llm as llm_module

    monkeypatch.setattr(llm_module, "llm", object())
    with pytest.raises(FinishIteration) as exc_info:
        handle_llm(prefix, executor)
    assert exc_info.value.args[0] == USAGE