Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,33 @@ log_level = "DEBUG"
line-length = 79

[tool.ruff.lint]
select = ["D", "E", "F", "I", "N", "PL", "RUF", "SIM"]
ignore = ["D102", "D103", "D105", "D107", "D415", "PLR2004"]
select = [
"A",
"ANN",
"ARG",
"D",
"E",
"ERA",
"F",
"I",
"INP",
"LOG",
"N",
"PL",
"RUF",
"SIM",
"SLF",
"T20",
"TD",
"UP",
"W"
]
ignore = [
"ANN003", "ANN401",
"D102", "D103", "D105", "D107", "D415",
"PLR2004",
"TD002", "TD003",
]

[tool.ruff.lint.isort]
force-sort-within-sections = true
Expand All @@ -105,4 +130,5 @@ lines-after-imports = 2
convention = "google"

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["D"]
"__main__.py" = ["T20"]
"tests/**" = ["ANN", "D", "SLF"]
16 changes: 11 additions & 5 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

from __future__ import annotations

from collections.abc import Sequence
import importlib.metadata
import logging
import optparse
from pathlib import Path, PurePosixPath
import sys
from typing import Sequence

from .bots import load_bot
from .common import PROGRAM, Config, UnreachableError, ensure_state_home
from .drafter import Drafter
from .editor import open_editor
from .prompt import Template, TemplatedPrompt, templates_table
from .prompt import Template, TemplatedPrompt, find_template, templates_table
from .store import Store
from .toolbox import ToolVisitor

Expand Down Expand Up @@ -41,7 +41,13 @@ def new_parser() -> optparse.OptionParser:
)

def add_command(name: str, short: str | None = None, **kwargs) -> None:
def callback(_option, _opt, _value, parser) -> None:
def callback(
_option: object,
_opt: object,
_value: object,
parser: optparse.OptionParser,
) -> None:
assert parser.values
parser.values.command = name

parser.add_option(
Expand Down Expand Up @@ -222,11 +228,11 @@ def main() -> None: # noqa: PLR0912 PLR0915
if table:
print(table.to_json() if opts.json else table)
elif command == "show-prompts":
raise NotImplementedError() # TODO
raise NotImplementedError() # TODO: Implement
elif command == "show-templates":
if args:
name = args[0]
tpl = Template.find(name)
tpl = find_template(name)
if opts.edit:
if tpl:
edit(path=tpl.local_path(), text=tpl.source)
Expand Down
6 changes: 3 additions & 3 deletions src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ class Action:
request_count: int | None = None
token_count: int | None = None

def increment_request_count(self, n=1, init=False) -> None:
def increment_request_count(self, n: int = 1, init: bool = False) -> None:
self._increment("request_count", n, init)

def increment_token_count(self, n, init=False) -> None:
def increment_token_count(self, n: int, init: bool = False) -> None:
self._increment("token_count", n, init)

def _increment(self, attr: str, count: int, init: bool) -> None:
Expand All @@ -48,7 +48,7 @@ class Bot:
"""Code assistant bot"""

@classmethod
def state_folder_path(cls, ensure_exists=False) -> Path:
def state_folder_path(cls, ensure_exists: bool = False) -> Path:
"""Returns a path unique to this bot class

The path can be used to store data specific to this bot implementation.
Expand Down
19 changes: 10 additions & 9 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py
"""

from collections.abc import Mapping, Sequence
import json
import logging
import os
from pathlib import PurePosixPath
from typing import Any, Mapping, Self, Sequence, TypedDict, override
from typing import Any, Self, TypedDict, override

import openai

Expand Down Expand Up @@ -61,7 +62,7 @@ def _param(
name: str,
description: str,
inputs: Mapping[str, Any] | None = None,
required_inputs: Sequence[str] | None = None,
_required_inputs: Sequence[str] | None = None,
) -> openai.types.beta.FunctionToolParam:
param: openai.types.beta.FunctionToolParam = {
"type": "function",
Expand Down Expand Up @@ -225,10 +226,10 @@ def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str:
return f"`{path}` does not exist."
return f"The contents of `{path}` are:\n\n```\n{contents}\n```\n"

def _on_write_file(self, path: PurePosixPath) -> None:
def _on_write_file(self, _path: PurePosixPath) -> None:
return None

def _on_delete_file(self, path: PurePosixPath) -> None:
def _on_delete_file(self, _path: PurePosixPath) -> None:
return None

def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str:
Expand Down Expand Up @@ -316,7 +317,7 @@ def on_run_step_done(
else:
_logger.warning("Missing usage in threads run step")

def _handle_action(self, run_id: str, data: Any) -> None:
def _handle_action(self, _run_id: str, data: Any) -> None:
tool_outputs = list[Any]()
for tool in data.required_action.submit_tool_outputs.tool_calls:
handler = _ThreadToolHandler(self._toolbox, tool.id)
Expand Down Expand Up @@ -347,15 +348,15 @@ def _wrap(self, output: str) -> _ToolOutput:
return _ToolOutput(tool_call_id=self._call_id, output=output)

def _on_read_file(
self, path: PurePosixPath, contents: str | None
self, _path: PurePosixPath, contents: str | None
) -> _ToolOutput:
return self._wrap(contents or "")

def _on_write_file(self, path: PurePosixPath) -> _ToolOutput:
def _on_write_file(self, _path: PurePosixPath) -> _ToolOutput:
return self._wrap("OK")

def _on_delete_file(self, path: PurePosixPath) -> _ToolOutput:
def _on_delete_file(self, _path: PurePosixPath) -> _ToolOutput:
return self._wrap("OK")

def _on_list_files(self, paths: Sequence[PurePosixPath]) -> _ToolOutput:
return self._wrap("\n".join((str(p) for p in paths)))
return self._wrap("\n".join(str(p) for p in paths))
7 changes: 4 additions & 3 deletions src/git_draft/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
import dataclasses
import itertools
import logging
Expand All @@ -11,7 +12,7 @@
import string
import textwrap
import tomllib
from typing import Any, ClassVar, Mapping, Self, Sequence, Type
from typing import Any, ClassVar, Self

import prettytable
import xdg_base_dirs
Expand Down Expand Up @@ -84,7 +85,7 @@ class UnreachableError(RuntimeError):
"""Indicates unreachable code was unexpectedly executed"""


def reindent(s: str, width=0) -> str:
def reindent(s: str, width: int = 0) -> str:
"""Reindents text by dedenting and optionally wrapping paragraphs"""
paragraphs = (
" ".join(textwrap.dedent("\n".join(g)).splitlines())
Expand All @@ -96,7 +97,7 @@ def reindent(s: str, width=0) -> str:
)


def qualified_class_name(cls: Type) -> str:
def qualified_class_name(cls: type) -> str:
name = cls.__qualname__
return f"{cls.__module__}.{name}" if cls.__module__ else name

Expand Down
13 changes: 8 additions & 5 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
from datetime import datetime
import json
Expand All @@ -10,9 +11,9 @@
import os.path as osp
from pathlib import PurePosixPath
import re
from re import Match
import textwrap
import time
from typing import Callable, Match, Sequence

import git

Expand Down Expand Up @@ -53,7 +54,7 @@ def active(cls, repo: git.Repo, name: str | None = None) -> _Branch | None:
return _Branch(match[1])

@staticmethod
def new_suffix():
def new_suffix() -> str:
return random_id(9)


Expand Down Expand Up @@ -85,7 +86,7 @@ def generate_draft( # noqa: PLR0913
timeout: float | None = None,
) -> str:
if timeout is not None:
raise NotImplementedError() # TODO
raise NotImplementedError() # TODO: Implement

if self._repo.is_dirty(working_tree=False):
if not reset:
Expand Down Expand Up @@ -174,7 +175,9 @@ def generate_draft( # noqa: PLR0913
_logger.info("Completed generation for %s.", branch)
return str(branch)

def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
def exit_draft(
self, *, revert: bool, clean: bool = False, delete: bool = False
) -> str:
branch = _Branch.active(self._repo)
if not branch:
raise RuntimeError("Not currently on a draft branch")
Expand Down Expand Up @@ -312,7 +315,7 @@ def _untracked(self) -> frozenset[str]:
text = self._repo.git.ls_files(exclude_standard=True, others=True)
return frozenset(text.splitlines())

def _delta(self, spec) -> _Delta:
def _delta(self, spec: str) -> _Delta:
changed = list[str]()
deleted = list[str]()
for line in self._repo.git.diff(spec, name_status=True).splitlines():
Expand Down
11 changes: 8 additions & 3 deletions src/git_draft/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@ def _guess_editor_binpath() -> str:
return ""


def _get_tty_filename():
def _get_tty_filename() -> str:
return "CON:" if sys.platform == "win32" else "/dev/tty"


def open_editor(text="", path: Path | None = None, *, _open_tty=open) -> str:
def open_editor(
text: str = "",
path: Path | None = None,
*,
_open_tty=open, # noqa
) -> str:
"""Open an editor to edit a file and return its contents

The method returns once the editor is closed. It respects the `$EDITOR`
Expand All @@ -46,7 +51,7 @@ def edit(path: str) -> str:
proc = subprocess.Popen([binpath, path], close_fds=True, stdout=stdout)
proc.communicate()

with open(path, mode="r") as reader:
with open(path) as reader:
return reader.read()

if path:
Expand Down
37 changes: 19 additions & 18 deletions src/git_draft/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from __future__ import annotations

from collections.abc import Mapping
import dataclasses
import enum
import itertools
import os
from pathlib import Path
from typing import Mapping, Self
from typing import Self

import jinja2
import jinja2.meta
Expand Down Expand Up @@ -72,7 +73,7 @@ def templates_table() -> Table:
for rel_path in env.list_templates(extensions=[_extension]):
if any(p.startswith(".") for p in rel_path.split(os.sep)):
continue
tpl = Template._load(rel_path, env)
tpl = _load_template(rel_path, env)
local = "y" if tpl.is_local() else "n"
table.data.add_row([tpl.name, local, tpl.preamble or "-"])
return table
Expand All @@ -95,6 +96,22 @@ def _extract_preamble(source: str, env: jinja2.Environment) -> str | None:
return None


def _load_template(rel_path: str, env: jinja2.Environment) -> Template:
assert env.loader, "No loader in environment"
source, abs_path, _uptodate = env.loader.get_source(env, rel_path)
assert abs_path, "Missing template path"
preamble = _extract_preamble(source, env)
return Template(Path(rel_path), Path(abs_path), source, preamble)


def find_template(name: str) -> Template | None:
env = _jinja_environment()
try:
return _load_template(f"{name}.{_extension}", env)
except jinja2.TemplateNotFound:
return None


@dataclasses.dataclass(frozen=True)
class Template:
"""An available template"""
Expand Down Expand Up @@ -126,22 +143,6 @@ def extract_variables(self, env: jinja2.Environment) -> frozenset[str]:
ast = env.parse(self.source)
return frozenset(jinja2.meta.find_undeclared_variables(ast))

@classmethod
def _load(cls, rel_path: str, env: jinja2.Environment) -> Self:
assert env.loader, "No loader in environment"
source, abs_path, _uptodate = env.loader.get_source(env, rel_path)
assert abs_path, "Missing template path"
preamble = _extract_preamble(source, env)
return cls(Path(rel_path), Path(abs_path), source, preamble)

@classmethod
def find(cls, name: str) -> Self | None:
env = _jinja_environment()
try:
return cls._load(f"{name}.{_extension}", env)
except jinja2.TemplateNotFound:
return None

@staticmethod
def local_path_for(name: str) -> Path:
return _PromptFolder.LOCAL.path / Path(f"{name}.{_extension}")
Expand Down
3 changes: 2 additions & 1 deletion src/git_draft/store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Persistent state storage"""

from collections.abc import Iterator
import contextlib
from datetime import datetime
import functools
import sqlite3
from typing import Iterator, Self
from typing import Self

from .common import ensure_state_home, package_root

Expand Down
3 changes: 2 additions & 1 deletion src/git_draft/toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from __future__ import annotations

from collections.abc import Callable, Sequence
import logging
from pathlib import PurePosixPath
import tempfile
from typing import Callable, Protocol, Sequence, override
from typing import Protocol, override

import git

Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/git_draft/__init__.py
Empty file.
Loading