diff --git a/.github/actions/setup/action.yaml b/.github/actions/setup/action.yaml index 56362a2..4e2bbed 100644 --- a/.github/actions/setup/action.yaml +++ b/.github/actions/setup/action.yaml @@ -19,3 +19,8 @@ runs: - name: Lint shell: bash run: poetry run poe lint + - name: Set up git config + shell: bash + run: | + git config --global user.email test+git-draft@mtth.io + git config --global user.name tester diff --git a/README.md b/README.md index 68b17f7..a3514d5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,9 @@ # `git-draft(1)` -WIP +> [!NOTE] +> WIP: Not quite functional yet. + +## Highlights + +* Concurrent editing. Continue editing while the assistant runs, without any + risks of interference. diff --git a/docs/git-draft.adoc b/docs/git-draft.adoc index 51b3915..f5b6e06 100644 --- a/docs/git-draft.adoc +++ b/docs/git-draft.adoc @@ -12,14 +12,17 @@ v{manversion} git-draft - git-friendly code assistant +IMPORTANT: _git-draft_ is WIP. +Options documented below may not be implemented yet. + == Synopsis -*git-draft* _-C_ +*git-draft* _[--generate]_ _[--prompt PROMPT]_ _[--reset]_ _[TEMPLATE [...]]_ -*git-draft* _-E_ +*git-draft* _--finalize_ _[--delete]_ -*git-draft* _-A_ +*git-draft* _--discard_ _[--delete]_ == Description @@ -28,15 +31,21 @@ _git-draft_ is a git-centric way to edit code using AI. === How it works -When you create a new draft with `git draft -C $name`, a new branch called `$branch/drafts/$name-$hash` is created (`$hash` is a random suffix used to guarantee uniqueness of branch names) and checked out. -Additionally, any uncommitted changes are automatically committed (`draft! sync`). -Once the draft is created, we can use AI to edit our code using `git draft -E`. -It expects the prompt as standard input, for example `echo "Add a test for compute_offset in chart.py" | git draft -E`. -The prompt will automatically get augmented with information about the files in the repository, and give the AI access to tools for reading and writing files. -Once the response has been received and changes, applied a commit is created (`draft! prompt: a short summary of the change`). +The workhorse command is `git draft --generate` which leverages AI to edit our code. +A prompt can be specified as standard input, for example `echo "Add a test for compute_offset in chart.py" | git draft --generate`. +If no prompt is specified and stdin is a TTY, `$EDITOR` will be opened to enter the prompt. + +If not on a draft branch, a new draft branch called `drafts/$parent/$hash` will be created (`$hash` is a random suffix used to guarantee uniqueness of branch names) and checked out. +By default any unstaged changes are then automatically added and committed (`draft! sync`). +This behavior can be disabled by passing in `--stash`, which will instead add them to the stash. +Staged changes are always committed. + +The prompt automatically gets augmented with information about the files in the repository, and give the AI access to tools for reading and writing files. +Once the response has been received and changes applied, a commit is created (`draft! prompt: a short summary of the change`). -The prompt step can be repeated as many times as needed. Once you are satisfied with the changes, run `git draft -A` to apply them. +The `--generate` step can be repeated as many times as needed. +Once you are satisfied with the changes, run `git draft --finalize` to apply them. This will check out the branch used when creating the draft, adding the final state of the draft to the worktree. Note that you can come back to an existing draft anytime (by checking its branch out), but you will not be able to apply it if its origin branch has moved since the draft was created. diff --git a/poetry.lock b/poetry.lock index a5bcb4a..d68c14c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6,7 +6,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -18,7 +18,7 @@ version = "4.8.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.9" -groups = ["dev"] +groups = ["main"] files = [ {file = "anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a"}, {file = "anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a"}, @@ -85,7 +85,7 @@ version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" -groups = ["dev"] +groups = ["main"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -112,12 +112,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["dev"] -markers = "platform_system == \"Windows\" or sys_platform == \"win32\"" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "distro" @@ -125,7 +125,7 @@ version = "1.9.0" description = "Distro - an OS platform information API" optional = false python-versions = ">=3.6" -groups = ["dev"] +groups = ["main"] files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, @@ -171,7 +171,7 @@ version = "4.0.12" description = "Git Object Database" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main"] files = [ {file = "gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf"}, {file = "gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571"}, @@ -186,7 +186,7 @@ version = "3.1.44" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main"] files = [ {file = "GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110"}, {file = "gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269"}, @@ -205,7 +205,7 @@ version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main"] files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, @@ -217,7 +217,7 @@ version = "1.0.7" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd"}, {file = "httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c"}, @@ -239,7 +239,7 @@ version = "0.28.1" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, @@ -264,7 +264,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["dev"] +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -291,7 +291,7 @@ version = "0.8.2" description = "Fast iterable JSON parser." optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"}, {file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"}, @@ -454,7 +454,7 @@ version = "1.64.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "openai-1.64.0-py3-none-any.whl", hash = "sha256:20f85cde9e95e9fbb416e3cb5a6d3119c0b28308afd6e3cc47bf100623dac623"}, {file = "openai-1.64.0.tar.gz", hash = "sha256:2861053538704d61340da56e2f176853d19f1dc5704bc306b7597155f850d57a"}, @@ -580,7 +580,7 @@ version = "2.10.6" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, @@ -601,7 +601,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -747,7 +747,7 @@ version = "5.0.2" description = "A pure Python implementation of a sliding window memory map manager" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main"] files = [ {file = "smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e"}, {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"}, @@ -759,7 +759,7 @@ version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -813,7 +813,7 @@ version = "4.67.1" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main"] files = [ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, @@ -835,7 +835,7 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, @@ -844,4 +844,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.12,<4" -content-hash = "9988f8dd414e11d19add92c87144ad96951eca1523049033dd4bc9738845cc15" +content-hash = "98ba96ddfc998c9c99546fd9c3e3ce358527be48efa24a1228b19b4658fdc195" diff --git a/pyproject.toml b/pyproject.toml index 23ce771..0f1b3e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,15 +15,15 @@ packages = [{include = 'git_draft', from = 'src'}] git-draft = 'git_draft.__main__:main' [tool.poetry.dependencies] +gitpython = '^3.1.44' +openai = '^1.64.0' python = '>=3.12,<4' [tool.poetry.group.dev.dependencies] black = '^25.1.0' flake8 = '^7.0.0' flake8-pyproject = '^1.2.3' -gitpython = '^3.1.44' mypy = '^1.2.0' -openai = '^1.64.0' poethepoet = '^0.25.0' pytest = '^7.1.2' diff --git a/src/git_draft/__init__.py b/src/git_draft/__init__.py index 41aa2ea..ad90e65 100644 --- a/src/git_draft/__init__.py +++ b/src/git_draft/__init__.py @@ -1,7 +1,11 @@ -from .actions import apply_draft, create_draft, extend_draft +from .assistant import Assistant, OpenAIAssistant +from .common import open_editor +from .manager import Manager, enclosing_repo __all__ = [ - "apply_draft", - "create_draft", - "extend_draft", + "Assistant", + "OpenAIAssistant", + "Manager", + "enclosing_repo", + "open_editor", ] diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 43193fc..71d73d0 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -1,94 +1,90 @@ from __future__ import annotations +import importlib.metadata import optparse import sys +import textwrap -from . import apply_draft, create_draft, extend_draft +from . import Manager, OpenAIAssistant, enclosing_repo, open_editor -parser = optparse.OptionParser(prog="git-draft") +EPILOG = """\ + More information via `man git-draft` and https://mtth.github.io/git-draft. +""" -parser.disable_interspersed_args() -command_group = optparse.OptionGroup( - parser, "Commands", "exactly one command must be specified" +parser = optparse.OptionParser( + prog="git-draft", + epilog=textwrap.dedent(EPILOG), + version=importlib.metadata.version("git_draft"), ) -parser.add_option_group(command_group) - - -class Command: - @classmethod - def register(cls, name: str, **kwargs) -> Command: - command = cls(name) - command_group.add_option( - command.flag, - action="callback", - callback=command, - callback_args=(name,), - **kwargs, - ) - return command - - def __init__(self, name: str) -> None: - self.name = name - self._option_group: optparse.OptionGroup | None = None - - @property - def flag(self): - return f"-{self.name[0].upper()}" - - def option_group(self) -> optparse.OptionGroup: - if not self._option_group: - self._option_group = optparse.OptionGroup( - parser, f"Optional {self.flag} flags" - ) - parser.add_option_group(self._option_group) - return self._option_group - - def __call__(self, _option, _opt, _value, parser, name) -> None: + +parser.disable_interspersed_args() + + +def add_command(name: str, **kwargs) -> None: + def callback(_option, _opt, _value, parser) -> None: parser.values.command = name + parser.add_option( + f"-{name[0].upper()}", + f"--{name}", + action="callback", + callback=callback, + **kwargs, + ) -Command.register("create", help="create a draft") -Command.register( - "extend", help="read a prompt from stdin to add to the current draft" -) +add_command("discard", help="discard all drafts associated with a branch") +add_command("finalize", help="apply the current draft to the original branch") +add_command("generate", help="draft a new change from a prompt") -apply_command = Command.register( - "apply", help="apply the current draft to the original branch" -) -apply_command.option_group().add_option( +parser.add_option( "-d", - help="delete the draft after applying", + "--delete", + help="delete the draft after finalizing or discarding", action="store_true", ) - -delete_command = Command.register( - "delete", help="delete all drafts associated with a branch" +parser.add_option( + "-p", + "--prompt", + dest="prompt", + help="draft generation prompt, read from stdin if unset", ) -delete_command.option_group().add_option( - "-b", - help="draft source branch [default: active branch]", - type="string", - metavar="BRANCH", +parser.add_option( + "-r", + "--reset", + help="reset index before generating a new draft", + action="store_true", ) +EDITOR_PLACEHOLDER = """\ + Enter your prompt here... +""" + + def main() -> None: (opts, args) = parser.parse_args() - command = getattr(opts, "command", None) - if command == "create": - create_draft() - elif command == "extend": - prompt = sys.stdin.read() - extend_draft(prompt) - elif command == "apply": - apply_draft() - elif command == "delete": - print("Deleting draft...") + + repo = enclosing_repo() + manager = Manager(repo) + + command = getattr(opts, "command", "generate") + if command == "generate": + prompt = opts.prompt + if not prompt: + if sys.stdin.isatty(): + prompt = open_editor(textwrap.dedent(EDITOR_PLACEHOLDER)) + else: + prompt = sys.stdin.read() + manager.generate_draft(prompt, OpenAIAssistant(), reset=opts.reset) + elif command == "finalize": + manager.finalize_draft(delete=opts.delete) + elif command == "discard": + manager.discard_draft(delete=opts.delete) else: - parser.error("missing command") + assert False, "unreachable" if __name__ == "__main__": diff --git a/src/git_draft/actions.py b/src/git_draft/actions.py deleted file mode 100644 index 87ac772..0000000 --- a/src/git_draft/actions.py +++ /dev/null @@ -1,117 +0,0 @@ -from __future__ import annotations - -import dataclasses -import git -import random -from pathlib import PurePosixPath -import re -import string -import tempfile -from typing import Match, Sequence - -from .backend import NewFileBackend - - -def _enclosing_repo() -> git.Repo: - return git.Repo(search_parent_directories=True) - - -_random = random.Random() - -_SUFFIX_LENGTH = 8 - -_branch_name_pattern = re.compile(r"drafts/(.+)/(\w+)") - - -@dataclasses.dataclass(frozen=True) -class _DraftBranch: - parent: str - suffix: str - repo: git.Repo - - def __str__(self) -> str: - return f"drafts/{self.parent}/{self.suffix}" - - @classmethod - def create(cls, repo: git.Repo) -> _DraftBranch: - if not repo.active_branch: - raise RuntimeError("No currently active branch") - suffix = "".join( - _random.choice(string.ascii_lowercase + string.digits) - for _ in range(_SUFFIX_LENGTH) - ) - return cls(repo.active_branch.name, suffix, repo) - - @classmethod - def active(cls, repo: git.Repo) -> _DraftBranch: - match: Match | None = None - if repo.active_branch: - match = _branch_name_pattern.fullmatch(repo.active_branch.name) - if not match: - raise RuntimeError("Not currently on a draft branch") - return _DraftBranch(match[1], match[2], repo) - - -@dataclasses.dataclass(frozen=True) -class _CommitNotes: - # https://stackoverflow.com/a/40496777 - pass - - -def create_draft() -> None: - repo = _enclosing_repo() - draft_branch = _DraftBranch.create(repo) - ref = repo.create_head(str(draft_branch)) - repo.git.checkout(ref) - - -class _Toolbox: - def __init__(self, repo: git.Repo) -> None: - self._repo = repo - - def list_files(self) -> Sequence[PurePosixPath]: - # Show staged files. - return self._repo.git.ls_files() - - def read_file(self, path: PurePosixPath) -> str: - # Read the file from the index. - return self._repo.git.show(f":{path}") - - def write_file(self, path: PurePosixPath, data: str) -> None: - # Update the index without touching the worktree. - # https://stackoverflow.com/a/25352119 - with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: - temp.write(data.encode("utf8")) - temp.close() - sha = self._repo.git.hash_object("-w", "--path", path, temp.name) - mode = 644 # TODO: Read from original file if it exists. - self._repo.git.update_index( - "--add", "--cacheinfo", f"{mode},{sha},{path}" - ) - - -def extend_draft(prompt: str) -> None: - repo = _enclosing_repo() - _ = _DraftBranch.active(repo) - - if repo.is_dirty(): - repo.git.add(all=True) - repo.index.commit("draft! sync") - - NewFileBackend().run(_Toolbox(repo)) - - -def apply_draft(delete=False) -> None: - repo = _enclosing_repo() - branch = _DraftBranch.active(repo) - - # TODO: Check that parent has not moved. We could do this for example by - # adding a note to the draft branch with the original branch's commit ref. - - # https://stackoverflow.com/a/15993574 - repo.git.checkout("--detach") - repo.git.reset("--soft", branch.parent) - repo.git.checkout(branch.parent) - - if delete: - repo.git.branch("-D", str(branch)) diff --git a/src/git_draft/assistant.py b/src/git_draft/assistant.py new file mode 100644 index 0000000..3557984 --- /dev/null +++ b/src/git_draft/assistant.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import dataclasses +import openai +from pathlib import PurePosixPath +import textwrap +from typing import Protocol, Sequence + + +class Toolbox(Protocol): + def list_files(self) -> Sequence[PurePosixPath]: ... + def read_file(self, path: PurePosixPath) -> str: ... + def write_file(self, path: PurePosixPath, data: str) -> None: ... + + +@dataclasses.dataclass(frozen=True) +class Session: + token_count: int + calls: list[Call] + + +@dataclasses.dataclass(frozen=True) +class Call: + usage: openai.types.CompletionUsage | None + + +class Assistant: + def run(self, prompt: str, toolbox: Toolbox) -> Session: + raise NotImplementedError() + + +# https://aider.chat/docs/more-info.html +# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py +_SYSTEM_PROMPT = textwrap.dedent( + """ + You are an expert software engineer, who writes correct and concise code. +""" +) + + +class OpenAIAssistant(Assistant): + def __init__(self) -> None: + self._client = openai.OpenAI() + + def run(self, prompt: str, toolbox: Toolbox) -> Session: + # TODO: Switch to the thread run API, using tools to leverage toolbox + # methods. + # https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps + # https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py + completion = self._client.chat.completions.create( + messages=[ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ], + model="gpt-4o", + ) + content = completion.choices[0].message.content or "" + toolbox.write_file(PurePosixPath(f"{completion.id}.txt"), content) + return Session(0, calls=[Call(completion.usage)]) diff --git a/src/git_draft/backend.py b/src/git_draft/backend.py deleted file mode 100644 index f98b7f8..0000000 --- a/src/git_draft/backend.py +++ /dev/null @@ -1,26 +0,0 @@ -from pathlib import PurePosixPath -from typing import Protocol, Sequence - - -class Toolbox(Protocol): - def list_files(self) -> Sequence[PurePosixPath]: ... - def read_file(self, path: PurePosixPath) -> str: ... - def write_file(self, path: PurePosixPath, data: str) -> None: ... - - -class Backend: - def run(self, toolbox: Toolbox) -> None: ... - - -class NewFileBackend(Backend): - def run(self, toolbox: Toolbox) -> None: - # send request to backend... - import time - - time.sleep(2) - - # Add files to index. - import random - - name = f"foo-{random.randint(1, 100)}" - toolbox.write_file(PurePosixPath(name), "hello!\n") diff --git a/src/git_draft/common.py b/src/git_draft/common.py new file mode 100644 index 0000000..3ece5cf --- /dev/null +++ b/src/git_draft/common.py @@ -0,0 +1,45 @@ +import os +import shutil +import subprocess +import sys +import tempfile + + +_default_editors = ["vim", "emacs", "nano"] + + +def _guess_editor_binpath() -> str: + editor = os.environ.get("EDITOR") + if editor: + return shutil.which(editor) or "" + for editor in _default_editors: + binpath = shutil.which(editor) + if binpath: + return binpath + return "" + + +def _get_tty_filename(): + if sys.platform == "win32": + return "CON:" + return "/dev/tty" + + +def open_editor(placeholder="") -> str: + with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: + binpath = _guess_editor_binpath() + if not binpath: + raise ValueError("Editor unavailable") + + if placeholder: + with open(temp.name, "w") as writer: + writer.write(placeholder) + + stdout = open(_get_tty_filename(), "wb") + proc = subprocess.Popen( + [binpath, temp.name], close_fds=True, stdout=stdout + ) + proc.communicate() + + with open(temp.name, mode="r") as reader: + return reader.read() diff --git a/src/git_draft/manager.py b/src/git_draft/manager.py new file mode 100644 index 0000000..41c2835 --- /dev/null +++ b/src/git_draft/manager.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import dataclasses +import git +import json +from pathlib import PurePosixPath +import re +import tempfile +from typing import Callable, ClassVar, Match, Self, Sequence + +from .assistant import Assistant + + +def enclosing_repo(path: str | None = None) -> git.Repo: + """Returns the repository to which the given path belongs""" + return git.Repo(path, search_parent_directories=True) + + +class _Note: + """Structured metadata attached to a commit""" + + # https://stackoverflow.com/a/40496777 + + __prefix: ClassVar[str] + + def __init_subclass__(cls, name) -> None: + cls.__prefix = f"{name}: " + + @classmethod + def read(cls, repo: git.Repo, ref: str) -> Self | None: + for line in repo.git.notes("show", ref).splitlines(): + if line.startswith(cls.__prefix): + data = json.loads(line[len(cls.__prefix) :]) + return cls(**data) + return None + + def write(self, repo: git.Repo, ref: str) -> None: + assert dataclasses.is_dataclass(self) + data = dataclasses.asdict(self) + value = json.dumps(data, separators=(",", ":")) + repo.git.notes( + "append", "--no-separator", "-m", f"{self.__prefix}{value}", ref + ) + + +@dataclasses.dataclass(frozen=True) +class _InitNote(_Note, name="draft-init"): + """Information about the current draft's branch""" + + origin_branch: str + sync_sha: str | None + + +@dataclasses.dataclass(frozen=True) +class _SessionNote(_Note, name="draft-session"): + """Information about a commit's underlying assistant session""" + + token_count: int + walltime: float + + +@dataclasses.dataclass(frozen=True) +class _Branch: + """Draft branch""" + + _name_pattern = re.compile(r"drafts/(.+)") + + init_shortsha: str + init_note: _InitNote + + @property + def name(self) -> str: + return f"drafts/{self.init_shortsha}" + + def needs_rebase(self, repo: git.Repo) -> bool: + if not self.init_note.sync_sha: + return False + init_commit = repo.commit(self.init_shortsha) + (origin_commit,) = init_commit.parents + head_commit = repo.commit(self.init_note.origin_branch) + return origin_commit == head_commit + + @classmethod + def create(cls, repo: git.Repo, sync: Callable[[], str | None]) -> _Branch: + if not repo.active_branch: + raise RuntimeError("No currently active branch") + origin_branch = repo.active_branch.name + + repo.git.checkout("--detach") + commit = repo.index.commit("draft! init") + init_shortsha = commit.hexsha[:7] + init_note = _InitNote(origin_branch, sync()) + init_note.write(repo, init_shortsha) + + branch = _Branch(init_shortsha, init_note) + branch_ref = repo.create_head(branch.name) + repo.git.checkout(branch_ref) + return branch + + @classmethod + def active(cls, repo: git.Repo) -> _Branch | None: + match: Match | None = None + if repo.active_branch: + match = cls._name_pattern.fullmatch(repo.active_branch.name) + if not match: + return None + init_shortsha = match[1] + init_note = _InitNote.read(repo, init_shortsha) + assert init_note + return _Branch(init_shortsha, init_note) + + +class _Toolbox: + def __init__(self, repo: git.Repo) -> None: + self._repo = repo + + def list_files(self) -> Sequence[PurePosixPath]: + # Show staged files. + return self._repo.git.ls_files() + + def read_file(self, path: PurePosixPath) -> str: + # Read the file from the index. + return self._repo.git.show(f":{path}") + + def write_file(self, path: PurePosixPath, data: str) -> None: + # Update the index without touching the worktree. + # https://stackoverflow.com/a/25352119 + with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: + temp.write(data.encode("utf8")) + temp.close() + sha = self._repo.git.hash_object("-w", "--path", path, temp.name) + mode = 644 # TODO: Read from original file if it exists. + self._repo.git.update_index( + "--add", "--cacheinfo", f"{mode},{sha},{path}" + ) + + +class Manager: + def __init__(self, repo: git.Repo) -> None: + self._repo = repo + + def generate_draft( + self, prompt: str, assistant: Assistant, reset=False + ) -> None: + if not prompt.strip(): + raise ValueError("Empty prompt") + if self._repo.is_dirty(working_tree=False): + if not reset: + raise ValueError("Please commit or reset any staged changes") + self._repo.index.reset() + + branch = _Branch.active(self._repo) + if branch: + self._sync() + else: + branch = _Branch.create(self._repo, self._sync) + + assistant.run(prompt, _Toolbox(self._repo)) + self._repo.index.commit(f"draft! prompt\n\n{prompt}") + + def finalize_draft(self, delete=False) -> None: + self._exit_draft(True, delete=delete) + + def discard_draft(self, delete=False) -> None: + self._exit_draft(False, delete=delete) + + def _sync(self) -> str | None: + if not self._repo.is_dirty(untracked_files=True): + return None + self._repo.git.add(all=True) + ref = self._repo.index.commit("draft! sync") + return ref.hexsha + + def _exit_draft(self, apply: bool, delete=False) -> None: + branch = _Branch.active(self._repo) + if not branch: + raise RuntimeError("Not currently on a draft branch") + if not apply and branch.needs_rebase(self._repo): + raise ValueError("Parent branch has moved, please rebase") + + # https://stackoverflow.com/a/15993574 + note = branch.init_note + self._repo.git.checkout("--detach") + if apply: + # We discard index (internal) changes + self._repo.git.reset(note.origin_branch) + else: + self._repo.git.reset("--hard", note.sync_sha or note.origin_branch) + self._repo.git.checkout(note.origin_branch) + + if delete: + self._repo.git.branch("-D", branch.name) diff --git a/tests/git_draft/assistant_test.py b/tests/git_draft/assistant_test.py new file mode 100644 index 0000000..2526843 --- /dev/null +++ b/tests/git_draft/assistant_test.py @@ -0,0 +1,5 @@ +import git_draft.assistant as sut + + +def test_assistant(): + assert sut.Assistant() diff --git a/tests/git_draft/backend_test.py b/tests/git_draft/backend_test.py deleted file mode 100644 index 72360a7..0000000 --- a/tests/git_draft/backend_test.py +++ /dev/null @@ -1,5 +0,0 @@ -import git_draft.backend as sut - - -def test_backend(): - assert sut.Backend() diff --git a/tests/git_draft/manager_test.py b/tests/git_draft/manager_test.py new file mode 100644 index 0000000..373fb9c --- /dev/null +++ b/tests/git_draft/manager_test.py @@ -0,0 +1,66 @@ +import dataclasses +import git +from pathlib import PurePosixPath +import pytest +import tempfile +from typing import Iterator + +from git_draft.assistant import Assistant, Session, Toolbox +import git_draft.manager as sut + + +@pytest.fixture +def repo() -> Iterator[git.Repo]: + with tempfile.TemporaryDirectory() as name: + repo = git.Repo.init(name, initial_branch="main") + repo.index.commit("init") + yield repo + + +@dataclasses.dataclass(frozen=True) +class _FakeNote(sut._Note, name="draft-test"): + value: int + + +class TestNote: + def test_write_one(self, repo: git.Repo) -> None: + note = _FakeNote(2) + note.write(repo, "main") + data = repo.git.notes("show", "main") + assert data == 'draft-test: {"value":2}' + + def test_write_read_one(self, repo: git.Repo) -> None: + note = _FakeNote(1) + note.write(repo, "main") + assert note == _FakeNote.read(repo, "main") + + def test_write_multiple(self, repo: git.Repo) -> None: + _FakeNote(1).write(repo, "main") + _FakeNote(2).write(repo, "main") + data = repo.git.notes("show", "main") + assert data == "\n".join( + [ + 'draft-test: {"value":1}', + 'draft-test: {"value":2}', + ] + ) + + +class _FakeAssistant(Assistant): + def run(self, prompt: str, toolbox: Toolbox) -> Session: + toolbox.write_file(PurePosixPath("PROMPT"), prompt) + return Session(len(prompt), []) + + +class TestManager: + def test_generate_draft(self, repo: git.Repo) -> None: + manager = sut.Manager(repo) + manager.generate_draft("hello", _FakeAssistant()) + commits = list(repo.iter_commits()) + assert len(commits) == 3 + + def test_generate_then_discard_draft(self, repo: git.Repo) -> None: + manager = sut.Manager(repo) + manager.generate_draft("hello", _FakeAssistant()) + manager.discard_draft() + assert len(list(repo.iter_commits())) == 1