Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ ignore = ["E203", "E501", "E704", "W503"]
[tool.isort]
profile = "black"
force_sort_within_sections = true
lines_after_imports = 2

[tool.mypy]
disable_error_code = "import-untyped"
Expand Down
1 change: 1 addition & 0 deletions src/git_draft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .bots import Action, Bot, Toolbox


__all__ = [
"Action",
"Bot",
Expand Down
16 changes: 13 additions & 3 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from .store import Store


_logger = logging.getLogger(__name__)


def new_parser() -> optparse.OptionParser:
parser = optparse.OptionParser(
prog=PROGRAM,
Expand Down Expand Up @@ -134,12 +137,19 @@ def main() -> None:
prompt, bot, checkout=opts.checkout, reset=opts.reset
)
elif command == "finalize":
drafter.finalize_draft(delete=opts.delete)
name = drafter.finalize_draft(delete=opts.delete)
print(f"Finalized {name}")
elif command == "revert":
drafter.revert_draft(delete=opts.delete)
name = drafter.revert_draft(delete=opts.delete)
print(f"Reverted {name}")
else:
raise UnreachableError()


if __name__ == "__main__":
main()
try:
main()
except Exception as err:
_logger.exception("Program failed.")
print(f"Error: {err}", file=sys.stderr)
sys.exit(1)
4 changes: 3 additions & 1 deletion src/git_draft/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import sys

from ..common import BotConfig, reindent
from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox
from ..toolbox import Operation, OperationHook, Toolbox
from .common import Action, Bot, Goal


__all__ = [
"Action",
Expand Down
90 changes: 3 additions & 87 deletions src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,10 @@
from __future__ import annotations

import dataclasses
from datetime import datetime
from pathlib import Path, PurePosixPath
from typing import Callable, Sequence
from pathlib import Path

from ..common import JSONObject, ensure_state_home


class Toolbox:
"""File-system intermediary

Note that the toolbox is not thread-safe. Concurrent operations should be
serialized by the caller.
"""

# TODO: Something similar to https://aider.chat/docs/repomap.html,
# including inferring the most important files, and allowing returning
# signature-only versions.

# TODO: Support a diff-based edit method.
# https://gist.github.com/noporpoise/16e731849eb1231e86d78f9dfeca3abc

def __init__(self, hook: OperationHook | None = None) -> None:
self.operations = list[Operation]()
self._operation_hook = hook

def _record_operation(
self, reason: str | None, tool: str, **kwargs
) -> None:
op = Operation(
tool=tool, details=kwargs, reason=reason, start=datetime.now()
)
self.operations.append(op)
if self._operation_hook:
self._operation_hook(op)

def list_files(self, reason: str | None = None) -> Sequence[PurePosixPath]:
self._record_operation(reason, "list_files")
return self._list()

def read_file(
self,
path: PurePosixPath,
reason: str | None = None,
) -> str:
self._record_operation(reason, "read_file", path=str(path))
return self._read(path)

def write_file(
self,
path: PurePosixPath,
contents: str,
reason: str | None = None,
) -> None:
self._record_operation(
reason, "write_file", path=str(path), size=len(contents)
)
return self._write(path, contents)

def delete_file(
self,
path: PurePosixPath,
reason: str | None = None,
) -> None:
self._record_operation(reason, "delete_file", path=str(path))
return self._delete(path)

def _list(self) -> Sequence[PurePosixPath]:
raise NotImplementedError()

def _read(self, path: PurePosixPath) -> str:
raise NotImplementedError()

def _write(self, path: PurePosixPath, contents: str) -> None:
raise NotImplementedError()

def _delete(self, path: PurePosixPath) -> None:
raise NotImplementedError()


@dataclasses.dataclass(frozen=True)
class Operation:
tool: str
details: JSONObject
reason: str | None
start: datetime


type OperationHook = Callable[[Operation], None]
from ..common import ensure_state_home
from ..toolbox import Toolbox


@dataclasses.dataclass(frozen=True)
Expand Down
17 changes: 10 additions & 7 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..common import JSONObject, reindent
from .common import Action, Bot, Goal, Toolbox


_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -132,7 +133,7 @@ class _ToolHandler[V]:
def __init__(self, toolbox: Toolbox) -> None:
self._toolbox = toolbox

def _on_read_file(self, path: PurePosixPath, contents: str) -> V:
def _on_read_file(self, path: PurePosixPath, contents: str | None) -> V:
raise NotImplementedError()

def _on_write_file(self, path: PurePosixPath) -> V:
Expand Down Expand Up @@ -196,10 +197,10 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:


class _CompletionsToolHandler(_ToolHandler[str | None]):
def _on_read_file(self, path: PurePosixPath, contents: str) -> str:
return (
f"Here are the contents of `{path}`:\n\n```\n{contents}\n```\n" ""
)
def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str:
if contents is None:
return f"`{path}` does not exist."
return f"The contents of `{path}` are:\n\n```\n{contents}\n```\n"

def _on_write_file(self, path: PurePosixPath) -> None:
return None
Expand Down Expand Up @@ -303,8 +304,10 @@ def __init__(self, toolbox: Toolbox, call_id: str) -> None:
def _wrap(self, output: str) -> _ToolOutput:
return _ToolOutput(tool_call_id=self._call_id, output=output)

def _on_read_file(self, path: PurePosixPath, contents: str) -> _ToolOutput:
return self._wrap(contents)
def _on_read_file(
self, path: PurePosixPath, contents: str | None
) -> _ToolOutput:
return self._wrap(contents or "")

def _on_write_file(self, path: PurePosixPath) -> _ToolOutput:
return self._wrap("OK")
Expand Down
1 change: 1 addition & 0 deletions src/git_draft/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import xdg_base_dirs


PROGRAM = "git-draft"


Expand Down
76 changes: 16 additions & 60 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
import dataclasses
import json
import logging
from pathlib import PurePosixPath
import re
import tempfile
import textwrap
import time
from typing import Match, Sequence, override
from typing import Match, Sequence

import git

from .bots import Bot, Goal, OperationHook, Toolbox
from .bots import Bot, Goal, OperationHook
from .common import random_id
from .prompt import PromptRenderer, TemplatedPrompt
from .store import Store, sql
from .toolbox import StagingToolbox


_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,53 +49,6 @@ def new_suffix():
return random_id(9)


class _Toolbox(Toolbox):
"""Git-index backed toolbox

All files are directly read from and written to the index. This allows
concurrent editing without interference.
"""

def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None:
super().__init__(hook)
self._repo = repo
self._written = set[str]()

@override
def _list(self) -> Sequence[PurePosixPath]:
# Show staged files.
return self._repo.git.ls_files().splitlines()

@override
def _read(self, path: PurePosixPath) -> str:
# Read the file from the index.
return self._repo.git.show(f":{path}")

@override
def _write(self, path: PurePosixPath, contents: str) -> None:
self._written.add(str(path))
# Update the index without touching the worktree.
# https://stackoverflow.com/a/25352119
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
temp.write(contents.encode("utf8"))
temp.close()
sha = self._repo.git.hash_object("-w", temp.name, path=path)
mode = 644 # TODO: Read from original file if it exists.
self._repo.git.update_index(
f"{mode},{sha},{path}", add=True, cacheinfo=True
)

def trim_index(self) -> None:
diff = self._repo.git.diff(name_only=True, cached=True)
untouched = [
path
for path in diff.splitlines()
if path and path not in self._written
]
if untouched:
self._repo.git.reset("--", *untouched)


class Drafter:
"""Draft state orchestrator"""

Expand Down Expand Up @@ -139,17 +92,19 @@ def generate_draft(

branch = _Branch.active(self._repo)
if branch:
_logger.debug("Reusing active branch %s.", branch)
self._stage_changes(sync)
_logger.debug("Reusing active branch %s.", branch)
else:
branch = self._create_branch(sync)
_logger.debug("Created branch %s.", branch)

toolbox = StagingToolbox(self._repo, self._operation_hook)
if isinstance(prompt, TemplatedPrompt):
renderer = PromptRenderer.for_repo(self._repo)
renderer = PromptRenderer.for_toolbox(toolbox)
prompt_contents = renderer.render(prompt)
else:
prompt_contents = prompt

with self._store.cursor() as cursor:
[(prompt_id,)] = cursor.execute(
sql("add-prompt"),
Expand All @@ -161,7 +116,6 @@ def generate_draft(

start_time = time.perf_counter()
goal = Goal(prompt_contents, timeout)
toolbox = _Toolbox(self._repo, self._operation_hook)
action = bot.act(goal, toolbox)
end_time = time.perf_counter()

Expand Down Expand Up @@ -201,11 +155,11 @@ def generate_draft(
if checkout:
self._repo.git.checkout("--", ".")

def finalize_draft(self, delete=False) -> None:
self._exit_draft(revert=False, delete=delete)
def finalize_draft(self, delete=False) -> str:
return self._exit_draft(revert=False, delete=delete)

def revert_draft(self, delete=False) -> None:
self._exit_draft(revert=True, delete=delete)
def revert_draft(self, delete=False) -> str:
return self._exit_draft(revert=True, delete=delete)

def _create_branch(self, sync: bool) -> _Branch:
if self._repo.head.is_detached:
Expand Down Expand Up @@ -241,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None:
ref = self._repo.index.commit("draft! sync")
return ref.hexsha

def _exit_draft(self, *, revert: bool, delete: bool) -> None:
def _exit_draft(self, *, revert: bool, delete: bool) -> str:
branch = _Branch.active(self._repo)
if not branch:
raise RuntimeError("Not currently on a draft branch")
Expand All @@ -268,7 +222,7 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None:
self._repo.git.reset("-N", origin_branch)
self._repo.git.checkout(origin_branch)

# Finally, we revert the relevant files if needed. If a sync commit had
# Next, we revert the relevant files if needed. If a sync commit had
# been created, we simply revert to it. Otherwise we compute which
# files have changed due to draft commits and revert only those.
if revert:
Expand All @@ -283,6 +237,8 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None:
if delete:
self._repo.git.branch("-D", branch.name)

return branch.name

def _changed_files(self, spec) -> Sequence[str]:
return self._repo.git.diff(spec, name_only=True).splitlines()

Expand Down
1 change: 1 addition & 0 deletions src/git_draft/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import tempfile


_default_editors = ["vim", "emacs", "nano"]


Expand Down
Loading