diff --git a/pyproject.toml b/pyproject.toml index 8a763ce..b41b59d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,8 +94,33 @@ log_level = "DEBUG" line-length = 79 [tool.ruff.lint] -select = ["D", "E", "F", "I", "N", "PL", "RUF", "SIM"] -ignore = ["D102", "D103", "D105", "D107", "D415", "PLR2004"] +select = [ + "A", + "ANN", + "ARG", + "D", + "E", + "ERA", + "F", + "I", + "INP", + "LOG", + "N", + "PL", + "RUF", + "SIM", + "SLF", + "T20", + "TD", + "UP", + "W" +] +ignore = [ + "ANN003", "ANN401", + "D102", "D103", "D105", "D107", "D415", + "PLR2004", + "TD002", "TD003", +] [tool.ruff.lint.isort] force-sort-within-sections = true @@ -105,4 +130,5 @@ lines-after-imports = 2 convention = "google" [tool.ruff.lint.per-file-ignores] -"tests/**" = ["D"] +"__main__.py" = ["T20"] +"tests/**" = ["ANN", "D", "SLF"] diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 29b56bf..167c0c8 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -2,18 +2,18 @@ from __future__ import annotations +from collections.abc import Sequence import importlib.metadata import logging import optparse from pathlib import Path, PurePosixPath import sys -from typing import Sequence from .bots import load_bot from .common import PROGRAM, Config, UnreachableError, ensure_state_home from .drafter import Drafter from .editor import open_editor -from .prompt import Template, TemplatedPrompt, templates_table +from .prompt import Template, TemplatedPrompt, find_template, templates_table from .store import Store from .toolbox import ToolVisitor @@ -41,7 +41,13 @@ def new_parser() -> optparse.OptionParser: ) def add_command(name: str, short: str | None = None, **kwargs) -> None: - def callback(_option, _opt, _value, parser) -> None: + def callback( + _option: object, + _opt: object, + _value: object, + parser: optparse.OptionParser, + ) -> None: + assert parser.values parser.values.command = name parser.add_option( @@ -222,11 +228,11 @@ def main() -> None: # noqa: PLR0912 PLR0915 if table: print(table.to_json() if opts.json else table) elif command == "show-prompts": - raise NotImplementedError() # TODO + raise NotImplementedError() # TODO: Implement elif command == "show-templates": if args: name = args[0] - tpl = Template.find(name) + tpl = find_template(name) if opts.edit: if tpl: edit(path=tpl.local_path(), text=tpl.source) diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index 2da0b81..e1da2ff 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -29,10 +29,10 @@ class Action: request_count: int | None = None token_count: int | None = None - def increment_request_count(self, n=1, init=False) -> None: + def increment_request_count(self, n: int = 1, init: bool = False) -> None: self._increment("request_count", n, init) - def increment_token_count(self, n, init=False) -> None: + def increment_token_count(self, n: int, init: bool = False) -> None: self._increment("token_count", n, init) def _increment(self, attr: str, count: int, init: bool) -> None: @@ -48,7 +48,7 @@ class Bot: """Code assistant bot""" @classmethod - def state_folder_path(cls, ensure_exists=False) -> Path: + def state_folder_path(cls, ensure_exists: bool = False) -> Path: """Returns a path unique to this bot class The path can be used to store data specific to this bot implementation. diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index b69b413..b0c9919 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -12,11 +12,12 @@ * https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py """ +from collections.abc import Mapping, Sequence import json import logging import os from pathlib import PurePosixPath -from typing import Any, Mapping, Self, Sequence, TypedDict, override +from typing import Any, Self, TypedDict, override import openai @@ -61,7 +62,7 @@ def _param( name: str, description: str, inputs: Mapping[str, Any] | None = None, - required_inputs: Sequence[str] | None = None, + _required_inputs: Sequence[str] | None = None, ) -> openai.types.beta.FunctionToolParam: param: openai.types.beta.FunctionToolParam = { "type": "function", @@ -225,10 +226,10 @@ def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str: return f"`{path}` does not exist." return f"The contents of `{path}` are:\n\n```\n{contents}\n```\n" - def _on_write_file(self, path: PurePosixPath) -> None: + def _on_write_file(self, _path: PurePosixPath) -> None: return None - def _on_delete_file(self, path: PurePosixPath) -> None: + def _on_delete_file(self, _path: PurePosixPath) -> None: return None def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str: @@ -316,7 +317,7 @@ def on_run_step_done( else: _logger.warning("Missing usage in threads run step") - def _handle_action(self, run_id: str, data: Any) -> None: + def _handle_action(self, _run_id: str, data: Any) -> None: tool_outputs = list[Any]() for tool in data.required_action.submit_tool_outputs.tool_calls: handler = _ThreadToolHandler(self._toolbox, tool.id) @@ -347,15 +348,15 @@ def _wrap(self, output: str) -> _ToolOutput: return _ToolOutput(tool_call_id=self._call_id, output=output) def _on_read_file( - self, path: PurePosixPath, contents: str | None + self, _path: PurePosixPath, contents: str | None ) -> _ToolOutput: return self._wrap(contents or "") - def _on_write_file(self, path: PurePosixPath) -> _ToolOutput: + def _on_write_file(self, _path: PurePosixPath) -> _ToolOutput: return self._wrap("OK") - def _on_delete_file(self, path: PurePosixPath) -> _ToolOutput: + def _on_delete_file(self, _path: PurePosixPath) -> _ToolOutput: return self._wrap("OK") def _on_list_files(self, paths: Sequence[PurePosixPath]) -> _ToolOutput: - return self._wrap("\n".join((str(p) for p in paths))) + return self._wrap("\n".join(str(p) for p in paths)) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index c3673a1..1535a8d 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence import dataclasses import itertools import logging @@ -11,7 +12,7 @@ import string import textwrap import tomllib -from typing import Any, ClassVar, Mapping, Self, Sequence, Type +from typing import Any, ClassVar, Self import prettytable import xdg_base_dirs @@ -84,7 +85,7 @@ class UnreachableError(RuntimeError): """Indicates unreachable code was unexpectedly executed""" -def reindent(s: str, width=0) -> str: +def reindent(s: str, width: int = 0) -> str: """Reindents text by dedenting and optionally wrapping paragraphs""" paragraphs = ( " ".join(textwrap.dedent("\n".join(g)).splitlines()) @@ -96,7 +97,7 @@ def reindent(s: str, width=0) -> str: ) -def qualified_class_name(cls: Type) -> str: +def qualified_class_name(cls: type) -> str: name = cls.__qualname__ return f"{cls.__module__}.{name}" if cls.__module__ else name diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 9da8b67..20b63b3 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import dataclasses from datetime import datetime import json @@ -10,9 +11,9 @@ import os.path as osp from pathlib import PurePosixPath import re +from re import Match import textwrap import time -from typing import Callable, Match, Sequence import git @@ -53,7 +54,7 @@ def active(cls, repo: git.Repo, name: str | None = None) -> _Branch | None: return _Branch(match[1]) @staticmethod - def new_suffix(): + def new_suffix() -> str: return random_id(9) @@ -85,7 +86,7 @@ def generate_draft( # noqa: PLR0913 timeout: float | None = None, ) -> str: if timeout is not None: - raise NotImplementedError() # TODO + raise NotImplementedError() # TODO: Implement if self._repo.is_dirty(working_tree=False): if not reset: @@ -174,7 +175,9 @@ def generate_draft( # noqa: PLR0913 _logger.info("Completed generation for %s.", branch) return str(branch) - def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str: + def exit_draft( + self, *, revert: bool, clean: bool = False, delete: bool = False + ) -> str: branch = _Branch.active(self._repo) if not branch: raise RuntimeError("Not currently on a draft branch") @@ -312,7 +315,7 @@ def _untracked(self) -> frozenset[str]: text = self._repo.git.ls_files(exclude_standard=True, others=True) return frozenset(text.splitlines()) - def _delta(self, spec) -> _Delta: + def _delta(self, spec: str) -> _Delta: changed = list[str]() deleted = list[str]() for line in self._repo.git.diff(spec, name_status=True).splitlines(): diff --git a/src/git_draft/editor.py b/src/git_draft/editor.py index b605835..6927677 100644 --- a/src/git_draft/editor.py +++ b/src/git_draft/editor.py @@ -22,11 +22,16 @@ def _guess_editor_binpath() -> str: return "" -def _get_tty_filename(): +def _get_tty_filename() -> str: return "CON:" if sys.platform == "win32" else "/dev/tty" -def open_editor(text="", path: Path | None = None, *, _open_tty=open) -> str: +def open_editor( + text: str = "", + path: Path | None = None, + *, + _open_tty=open, # noqa +) -> str: """Open an editor to edit a file and return its contents The method returns once the editor is closed. It respects the `$EDITOR` @@ -46,7 +51,7 @@ def edit(path: str) -> str: proc = subprocess.Popen([binpath, path], close_fds=True, stdout=stdout) proc.communicate() - with open(path, mode="r") as reader: + with open(path) as reader: return reader.read() if path: diff --git a/src/git_draft/prompt.py b/src/git_draft/prompt.py index fb3ed03..3e4c1fc 100644 --- a/src/git_draft/prompt.py +++ b/src/git_draft/prompt.py @@ -2,12 +2,13 @@ from __future__ import annotations +from collections.abc import Mapping import dataclasses import enum import itertools import os from pathlib import Path -from typing import Mapping, Self +from typing import Self import jinja2 import jinja2.meta @@ -72,7 +73,7 @@ def templates_table() -> Table: for rel_path in env.list_templates(extensions=[_extension]): if any(p.startswith(".") for p in rel_path.split(os.sep)): continue - tpl = Template._load(rel_path, env) + tpl = _load_template(rel_path, env) local = "y" if tpl.is_local() else "n" table.data.add_row([tpl.name, local, tpl.preamble or "-"]) return table @@ -95,6 +96,22 @@ def _extract_preamble(source: str, env: jinja2.Environment) -> str | None: return None +def _load_template(rel_path: str, env: jinja2.Environment) -> Template: + assert env.loader, "No loader in environment" + source, abs_path, _uptodate = env.loader.get_source(env, rel_path) + assert abs_path, "Missing template path" + preamble = _extract_preamble(source, env) + return Template(Path(rel_path), Path(abs_path), source, preamble) + + +def find_template(name: str) -> Template | None: + env = _jinja_environment() + try: + return _load_template(f"{name}.{_extension}", env) + except jinja2.TemplateNotFound: + return None + + @dataclasses.dataclass(frozen=True) class Template: """An available template""" @@ -126,22 +143,6 @@ def extract_variables(self, env: jinja2.Environment) -> frozenset[str]: ast = env.parse(self.source) return frozenset(jinja2.meta.find_undeclared_variables(ast)) - @classmethod - def _load(cls, rel_path: str, env: jinja2.Environment) -> Self: - assert env.loader, "No loader in environment" - source, abs_path, _uptodate = env.loader.get_source(env, rel_path) - assert abs_path, "Missing template path" - preamble = _extract_preamble(source, env) - return cls(Path(rel_path), Path(abs_path), source, preamble) - - @classmethod - def find(cls, name: str) -> Self | None: - env = _jinja_environment() - try: - return cls._load(f"{name}.{_extension}", env) - except jinja2.TemplateNotFound: - return None - @staticmethod def local_path_for(name: str) -> Path: return _PromptFolder.LOCAL.path / Path(f"{name}.{_extension}") diff --git a/src/git_draft/store.py b/src/git_draft/store.py index 5b50924..49c7b66 100644 --- a/src/git_draft/store.py +++ b/src/git_draft/store.py @@ -1,10 +1,11 @@ """Persistent state storage""" +from collections.abc import Iterator import contextlib from datetime import datetime import functools import sqlite3 -from typing import Iterator, Self +from typing import Self from .common import ensure_state_home, package_root diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index 26f2d6d..304f632 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -2,10 +2,11 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import logging from pathlib import PurePosixPath import tempfile -from typing import Callable, Protocol, Sequence, override +from typing import Protocol, override import git diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/git_draft/__init__.py b/tests/git_draft/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/git_draft/conftest.py b/tests/git_draft/conftest.py index bb2e43d..3a5f230 100644 --- a/tests/git_draft/conftest.py +++ b/tests/git_draft/conftest.py @@ -1,5 +1,5 @@ +from collections.abc import Iterator from pathlib import Path -from typing import Iterator import git import pytest diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 7d67fb3..975729d 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence import os from pathlib import Path, PurePosixPath -from typing import Sequence import git import pytest diff --git a/tests/git_draft/prompt_test.py b/tests/git_draft/prompt_test.py index 09f05cb..1d5357c 100644 --- a/tests/git_draft/prompt_test.py +++ b/tests/git_draft/prompt_test.py @@ -27,32 +27,32 @@ def setup(self) -> None: self._env = sut._jinja_environment() def test_fields(self): - tpl = sut.Template._load("includes/.file-list.jinja", self._env) + tpl = sut._load_template("includes/.file-list.jinja", self._env) assert not tpl.is_local() assert tpl.name == "includes/.file-list" assert tpl.local_path() != tpl.abs_path def test_preamble_ok(self): - tpl = sut.Template._load("add-test.jinja", self._env) + tpl = sut._load_template("add-test.jinja", self._env) assert "symbol" in tpl.preamble def test_preamble_missing(self): - tpl = sut.Template._load("includes/.file-list.jinja", self._env) + tpl = sut._load_template("includes/.file-list.jinja", self._env) assert tpl.preamble is None def test_extract_variables(self): - tpl = sut.Template._load("add-test.jinja", self._env) + tpl = sut._load_template("add-test.jinja", self._env) variables = tpl.extract_variables(self._env) assert "symbol" in variables assert "repo" not in variables def test_find_ok(self) -> None: - tpl = sut.Template.find("add-test") + tpl = sut.find_template("add-test") assert tpl assert "symbol" in tpl.source def test_find_missing(self) -> None: - assert sut.Template.find("foo") is None + assert sut.find_template("foo") is None def test_templates_table() -> None: