diff --git a/changelog.md b/changelog.md
index d6df9e6e..9851dfce 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,6 +1,11 @@
 Upcoming (TBD)
 ==============
 
+Features
+--------
+* Limit size of LLM prompts and cache LLM prompt data.
+
+
 Internal
 --------
 * Include LLM dependencies in tox configuration.
diff --git a/mycli/main.py b/mycli/main.py
index 2b41908f..41f11453 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -795,9 +795,10 @@ def one_iteration(text: str | None = None) -> None:
             while special.is_llm_command(text):
                 start = time()
                 try:
+                    assert isinstance(self.sqlexecute, SQLExecute)
                     assert sqlexecute.conn is not None
                     cur = sqlexecute.conn.cursor()
-                    context, sql, duration = special.handle_llm(text, cur)
+                    context, sql, duration = special.handle_llm(text, cur, sqlexecute.dbname or '')
                     if context:
                         click.echo("LLM Response:")
                         click.echo(context)
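The call site above passes `sqlexecute.dbname or ''` because `dbname` is `None` until a schema is selected; the empty string is then rejected by the new guard in `sql_using_llm` (see the llm.py hunk below). A minimal sketch of that behaviour, assuming this patch is applied and mycli is importable:

```python
# Minimal sketch, not part of the patch: with no schema selected, dbname is ''
# and the LLM helpers fail fast instead of building a prompt against an empty schema.
from mycli.packages.special.llm import sql_using_llm


class DummyCursor:
    """Stand-in for a live cursor; never used because the dbname guard fires first."""


try:
    sql_using_llm(DummyCursor(), question="top 10 customers", dbname='')
except RuntimeError as exc:
    print(exc)  # -> Choose a schema and try again.
```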
diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py
index d19b8c41..309a7057 100644
--- a/mycli/packages/special/llm.py
+++ b/mycli/packages/special/llm.py
@@ -212,7 +212,7 @@ def cli_commands() -> list[str]:
     return list(cli.commands.keys())
 
 
-def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
+def handle_llm(text: str, cur: Cursor, dbname: str) -> tuple[str, str | None, float]:
     _, verbosity, arg = parse_special_command(text)
     if not LLM_IMPORTED:
         output = [(None, None, None, NEED_DEPENDENCIES)]
@@ -261,7 +261,7 @@ def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
     try:
         ensure_mycli_template()
         start = time()
-        context, sql = sql_using_llm(cur=cur, question=arg)
+        context, sql = sql_using_llm(cur=cur, question=arg, dbname=dbname)
         end = time()
         if verbosity == Verbosity.SUCCINCT:
             context = ""
@@ -275,45 +275,81 @@ def is_llm_command(command: str) -> bool:
     return cmd in ("\\llm", "\\ai")
 
 
-def sql_using_llm(
-    cur: Cursor | None,
-    question: str | None = None,
-) -> tuple[str, str | None]:
-    if cur is None:
-        raise RuntimeError("Connect to a database and try again.")
-    schema_query = """
-        SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
+def truncate_list_elements(row: list) -> list:
+    target_size = 100000
+    width = 1024
+    while width >= 0:
+        truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row]
+        if sum(sys.getsizeof(x) for x in truncated_row) <= target_size:
+            break
+        width -= 100
+    return truncated_row
+
+
+def truncate_table_lines(table: list[str]) -> list[str]:
+    target_size = 100000
+    truncated_table = []
+    running_sum = 0
+    while table and running_sum <= target_size:
+        line = table.pop(0)
+        running_sum += sys.getsizeof(line)
+        truncated_table.append(line)
+    return truncated_table
+
+
+@functools.cache
+def get_schema(cur: Cursor, dbname: str) -> str:
+    click.echo("Preparing schema information to feed the LLM")
+    schema_query = f"""
+        SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')') AS `schema`
         FROM information_schema.columns
-        WHERE table_schema = DATABASE()
+        WHERE table_schema = '{dbname}'
         GROUP BY table_name
         ORDER BY table_name
     """
-    tables_query = "SHOW TABLES"
-    sample_row_query = "SELECT * FROM `{table}` LIMIT 1"
-    click.echo("Preparing schema information to feed the llm")
     cur.execute(schema_query)
-    db_schema = "\n".join([row[0] for (row,) in cur.fetchall()])
+    db_schema = [row[0] for (row,) in cur.fetchall()]
+    return '\n'.join(truncate_table_lines(db_schema))
+
+
+@functools.cache
+def get_sample_data(cur: Cursor, dbname: str) -> dict[str, Any]:
+    click.echo("Preparing sample data to feed the LLM")
+    tables_query = "SHOW TABLES"
+    sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1"
     cur.execute(tables_query)
     sample_data = {}
     for (table_name,) in cur.fetchall():
         try:
-            cur.execute(sample_row_query.format(table=table_name))
+            cur.execute(sample_row_query.format(dbname=dbname, table=table_name))
         except Exception:
             continue
         cols = [desc[0] for desc in cur.description]
         row = cur.fetchone()
         if row is None:
             continue
-        sample_data[table_name] = list(zip(cols, row))
+        sample_data[table_name] = list(zip(cols, truncate_list_elements(list(row))))
+    return sample_data
+
+
+def sql_using_llm(
+    cur: Cursor | None,
+    question: str | None,
+    dbname: str = '',
+) -> tuple[str, str | None]:
+    if cur is None:
+        raise RuntimeError("Connect to a database and try again.")
+    if dbname == '':
+        raise RuntimeError("Choose a schema and try again.")
     args = [
         "--template",
         LLM_TEMPLATE_NAME,
         "--param",
         "db_schema",
-        db_schema,
+        get_schema(cur, dbname),
         "--param",
         "sample_data",
-        sample_data,
+        get_sample_data(cur, dbname),
         "--param",
         "question",
         question,
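`get_schema` and `get_sample_data` are memoised with `functools.cache`, which keys the cache on the positional arguments, i.e. on the `(cur, dbname)` pair. A standalone sketch of that behaviour (the names below are illustrative stand-ins, not mycli code): the data is fetched on the first call for a given cursor and schema name and reused afterwards, while a different `dbname` or a new cursor object produces a fresh cache entry.

```python
# Standalone sketch of functools.cache keying; FakeCursor and the printed
# messages are illustrative stand-ins, not part of the patch.
import functools


@functools.cache
def get_schema(cur, dbname):
    print(f"querying information_schema for {dbname} ...")  # runs only on a cache miss
    return f"schema of {dbname}"


class FakeCursor:
    """Hashable stand-in for a database cursor."""


cur = FakeCursor()
get_schema(cur, "sakila")           # miss: the body runs
get_schema(cur, "sakila")           # hit: served from the cache, nothing printed
get_schema(cur, "world")            # different dbname -> separate cache entry
get_schema(FakeCursor(), "sakila")  # new cursor object -> another miss
```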
diff --git a/test/test_llm_special.py b/test/test_llm_special.py
index a7fa578a..23b88644 100644
--- a/test/test_llm_special.py
+++ b/test/test_llm_special.py
@@ -26,7 +26,7 @@ def test_llm_command_without_args(mock_llm, executor):
     assert mock_llm is not None
     test_text = r"\llm"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql')
 
     # Should return usage message when no args provided
     assert exc_info.value.args[0] == [(None, None, None, USAGE)]
@@ -38,7 +38,7 @@ def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
     mock_run_cmd.return_value = (0, "Hello, no SQL today.")
     test_text = r"\llm -c 'Something?'"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql')
 
     # Expect raw output when no SQL fence found
     assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]
@@ -51,7 +51,7 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor
     fenced = f"Here you go:\n```sql\n{sql_text}\n```"
     mock_run_cmd.return_value = (0, fenced)
     test_text = r"\llm -c 'Rewrite SQL'"
-    result, sql, duration = handle_llm(test_text, executor)
+    result, sql, duration = handle_llm(test_text, executor, 'mysql')
     # Without verbose, result is empty, sql extracted
     assert sql == sql_text
     assert result == ""
@@ -64,7 +64,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
     # 'models' is a known subcommand
     test_text = r"\llm models"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql')
 
     mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
     assert exc_info.value.args[0] is None
@@ -74,7 +74,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
 def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
     test_text = r"\llm --help"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql')
 
     mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
     assert exc_info.value.args[0] is None
@@ -84,7 +84,7 @@ def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
 def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
     test_text = r"\llm install openai"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql')
 
     mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
     assert exc_info.value.args[0] is None
@@ -98,7 +98,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
     """
     mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
     test_text = r"\llm prompt 'Test?'"
-    context, sql, duration = handle_llm(test_text, executor)
+    context, sql, duration = handle_llm(test_text, executor, 'mysql')
     mock_ensure_template.assert_called_once()
     mock_sql_using_llm.assert_called()
     assert context == "CTX"
@@ -115,7 +115,7 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
     """
     mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
     test_text = r"\llm 'Top 10?'"
-    context, sql, duration = handle_llm(test_text, executor)
+    context, sql, duration = handle_llm(test_text, executor, 'mysql')
     mock_ensure_template.assert_called_once()
     mock_sql_using_llm.assert_called()
     assert context == "CTX2"
@@ -132,7 +132,7 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
     """
     mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
     test_text = r"\llm- 'Succinct?'"
-    context, sql, duration = handle_llm(test_text, executor)
+    context, sql, duration = handle_llm(test_text, executor, 'mysql')
     assert context == ""
     assert sql == "SELECT 42;"
     assert isinstance(duration, float)
@@ -181,7 +181,7 @@ def fetchone(self):
     sql_text = "SELECT 1, 'abc';"
     fenced = f"Note\n```sql\n{sql_text}\n```"
     mock_run_cmd.return_value = (0, fenced)
-    result, sql = sql_using_llm(dummy_cur, question="dummy")
+    result, sql = sql_using_llm(dummy_cur, question="dummy", dbname='mysql')
 
     assert result == fenced
     assert sql == sql_text
@@ -194,5 +194,5 @@ def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
     monkeypatch.setattr(llm_module, "llm", object())
 
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(prefix, executor)
+        handle_llm(prefix, executor, 'mysql')
     assert exc_info.value.args[0] == [(None, None, None, USAGE)]
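The patch updates the existing `handle_llm`/`sql_using_llm` tests for the new `dbname` argument but adds no coverage for the truncation helpers themselves. A hypothetical sketch of such tests (not part of this diff; it assumes `truncate_list_elements` and `truncate_table_lines` are importable as defined above) could look like:

```python
# Hypothetical tests, not part of this diff: they only illustrate the
# ~100 KB budget the new helpers enforce.
import sys

from mycli.packages.special.llm import truncate_list_elements, truncate_table_lines


def test_truncate_list_elements_caps_large_values():
    row = ["ok", "x" * 500_000, 7]
    truncated = truncate_list_elements(row)
    assert truncated[0] == "ok" and truncated[2] == 7   # non-string values pass through
    assert len(truncated[1]) <= 1024                    # oversized string is clipped
    assert sum(sys.getsizeof(x) for x in truncated) <= 100_000


def test_truncate_table_lines_keeps_only_a_leading_slice():
    table = [f"row {i} " * 1_000 for i in range(1_000)]
    kept = truncate_table_lines(list(table))            # copy: the helper consumes its input
    assert kept == table[: len(kept)]                   # only a leading slice survives
    assert len(kept) < len(table)
```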