Skip to content

Commit 0c45d13

Browse files
authored
Merge pull request #1387 from dbcli/RW/refine-llm-py-typehinting
Refine typehints in `special/llm.py`
2 parents 0d4d631 + e687b6b commit 0c45d13

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Internal
1717
* Enable flake8-bugbear lint rules.
1818
* Fix flaky editor-command tests in CI.
1919
* Require release format of `changelog.md` when making a release.
20+
* Improve type annotations on LLM driver.
2021

2122

2223
1.40.0 (2025/10/14)

mycli/packages/special/llm.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import shlex
99
import sys
1010
from time import time
11-
from typing import Optional, Tuple
11+
from typing import Any
1212

1313
import click
1414

@@ -30,6 +30,7 @@
3030
LLM_CLI_IMPORTED = False
3131
except ImportError:
3232
LLM_CLI_IMPORTED = False
33+
from pymysql.cursors import Cursor
3334

3435
from mycli.packages.special.main import Verbosity, parse_special_command
3536

@@ -38,15 +39,22 @@
3839
LLM_TEMPLATE_NAME = "mycli-llm-template"
3940

4041

41-
def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True):
42+
def run_external_cmd(
43+
cmd: str,
44+
*args,
45+
capture_output=False,
46+
restart_cli=False,
47+
raise_exception=True,
48+
) -> tuple[int, str]:
4249
original_exe = sys.executable
4350
original_args = sys.argv
4451
try:
4552
sys.argv = [cmd] + list(args)
4653
code = 0
4754
if capture_output:
4855
buffer = io.StringIO()
49-
redirect = contextlib.ExitStack()
56+
redirect: contextlib.ExitStack[bool | None] | contextlib.nullcontext[None] = contextlib.ExitStack()
57+
assert isinstance(redirect, contextlib.ExitStack)
5058
redirect.enter_context(contextlib.redirect_stdout(buffer))
5159
redirect.enter_context(contextlib.redirect_stderr(buffer))
5260
else:
@@ -55,7 +63,7 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_
5563
try:
5664
run_module(cmd, run_name="__main__")
5765
except SystemExit as e:
58-
code = e.code
66+
code = int(e.code or 0)
5967
if code != 0 and raise_exception:
6068
if capture_output:
6169
raise RuntimeError(buffer.getvalue()) from e
@@ -76,24 +84,33 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_
7684
sys.argv = original_args
7785

7886

79-
def build_command_tree(cmd):
80-
tree = {}
87+
def _build_command_tree(cmd) -> dict[str, Any] | None:
88+
tree: dict[str, Any] | None = {}
89+
assert isinstance(tree, dict)
8190
if isinstance(cmd, click.Group):
8291
for name, subcmd in cmd.commands.items():
8392
if cmd.name == "models" and name == "default":
8493
tree[name] = {x.model_id: None for x in llm.get_models()}
8594
else:
86-
tree[name] = build_command_tree(subcmd)
95+
tree[name] = _build_command_tree(subcmd)
8796
else:
8897
tree = None
8998
return tree
9099

91100

101+
def build_command_tree(cmd) -> dict[str, Any]:
102+
return _build_command_tree(cmd) or {}
103+
104+
92105
# Generate the command tree for autocompletion
93106
COMMAND_TREE = build_command_tree(cli) if LLM_CLI_IMPORTED is True else {}
94107

95108

96-
def get_completions(tokens, tree=COMMAND_TREE):
109+
def get_completions(
110+
tokens: list[str],
111+
tree: dict[str, Any] | None = None,
112+
) -> list[str]:
113+
tree = tree or COMMAND_TREE
97114
for token in tokens:
98115
if token.startswith("-"):
99116
continue
@@ -182,21 +199,20 @@ def __init__(self, results=None):
182199
"""
183200

184201

185-
def ensure_mycli_template(replace=False):
202+
def ensure_mycli_template(replace: bool = False) -> None:
186203
if not replace:
187204
code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False)
188205
if code == 0:
189206
return
190207
run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME)
191-
return
192208

193209

194210
@functools.cache
195211
def cli_commands() -> list[str]:
196212
return list(cli.commands.keys())
197213

198214

199-
def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
215+
def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
200216
_, verbosity, arg = parse_special_command(text)
201217
if not LLM_IMPORTED:
202218
output = [(None, None, None, NEED_DEPENDENCIES)]
@@ -254,12 +270,15 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
254270
raise RuntimeError(e) from e
255271

256272

257-
def is_llm_command(command) -> bool:
273+
def is_llm_command(command: str) -> bool:
258274
cmd, _, _ = parse_special_command(command)
259275
return cmd in ("\\llm", "\\ai")
260276

261277

262-
def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]:
278+
def sql_using_llm(
279+
cur: Cursor | None,
280+
question: str | None = None,
281+
) -> tuple[str, str | None]:
263282
if cur is None:
264283
raise RuntimeError("Connect to a database and try again.")
265284
schema_query = """

0 commit comments

Comments
 (0)