diff --git a/README.md b/README.md index a1ea604..c5d897d 100644 --- a/README.md +++ b/README.md @@ -24,3 +24,6 @@ pipx install git-draft[openai] * Mechanism for reporting feedback from a bot, and possibly allowing user to interactively respond. * Add MCP bot. +* Store configuration in `gitconfig` as `draft.*` entries. This can be used for + example to set a unique repo ID, and default bot to use per repo (or more + widely via shared `gitconfig` files). diff --git a/docs/git-draft.adoc b/docs/git-draft.adoc index a30d8e3..e41ec9f 100644 --- a/docs/git-draft.adoc +++ b/docs/git-draft.adoc @@ -19,9 +19,8 @@ IMPORTANT: `git-draft` is WIP. [verse] git draft [options] [--generate] [--accept... | --no-accept] [--bot BOT] - [--edit] [--reset | --no-reset] [--sync | --no-sync] - [TEMPLATE [VARIABLE...]] -git draft [options] --finalize [--delete] [--sync | --no-sync] + [--edit] [--reset | --no-reset] [TEMPLATE [VARIABLE...]] +git draft [options] --finalize [--delete] git draft [options] --show-drafts [--json] git draft [options] --show-prompts [--json] [PROMPT] git draft [options] --show-templates [--json | [--edit] TEMPLATE] @@ -96,11 +95,6 @@ git draft [options] --show-templates [--json | [--edit] TEMPLATE] Lists available templates. With an template name argument, displays the corresponding template's contents or, if the `--edit` option is set, opens an interactive editor. --s:: ---sync:: ---no-sync:: - Create a sync commit with any changes. - -t TIMEOUT:: --timeout=TIMEOUT:: Action timeout. diff --git a/poetry.lock b/poetry.lock index fcbcbcb..fb63144 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -151,40 +151,6 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] -[[package]] -name = "gitdb" -version = "4.0.12" -description = "Git Object Database" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf"}, - {file = "gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571"}, -] - -[package.dependencies] -smmap = ">=3.0.1,<6" - -[[package]] -name = "gitpython" -version = "3.1.44" -description = "GitPython is a Python library used to interact with Git repositories" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110"}, - {file = "gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269"}, -] - -[package.dependencies] -gitdb = ">=4.0.1,<5" - -[package.extras] -doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] -test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] - [[package]] name = "h11" version = "0.14.0" @@ -805,18 +771,6 @@ files = [ {file = "ruff-0.11.0.tar.gz", hash = "sha256:e55c620690a4a7ee6f1cccb256ec2157dc597d109400ae75bbf944fc9d6462e2"}, ] -[[package]] -name = "smmap" -version = "5.0.2" -description = "A pure Python implementation of a sliding window memory map manager" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e"}, - {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"}, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -938,4 +892,4 @@ openai = ["openai"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<4" -content-hash = "d35f69ddc3c01cd9c29af5384d3ed089c0b84fdc3b723c155b423cdf5206e4d1" +content-hash = "bfa0a5826ac4f23cdbfcc642ea9c4a0965efd918c3123397aef01fdacab3b08f" diff --git a/pyproject.toml b/pyproject.toml index b41b59d..667ff06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ readme = "README.md" dynamic = ["version"] requires-python = ">=3.12" dependencies = [ - "gitpython (>=3.1.44,<4)", "jinja2 (>=3.1.5,<4)", "prettytable (>=3.15.1,<4)", "xdg-base-dirs (>=6.0.2,<7)", diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 1265469..4c3fc99 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -94,12 +94,6 @@ def callback( help="use JSON for table output", action="store_true", ) - parser.add_option( - "-s", - "--sync", - help="commit prior worktree changes separately", - action="store_true", - ) parser.add_option( "--no-accept", @@ -114,12 +108,6 @@ def callback( dest="reset", action="store_false", ) - parser.add_option( - "--no-sync", - help="do not commit intermediate worktree changes", - dest="sync", - action="store_false", - ) parser.add_option( "--reset", help="reset index before generating a new draft", @@ -160,7 +148,7 @@ def on_rename_file( self, src_path: PurePosixPath, dst_path: PurePosixPath, - _reason: str | None + _reason: str | None, ) -> None: print(f"Renamed {src_path} to {dst_path}.") @@ -233,7 +221,6 @@ def main() -> None: # noqa: PLR0912 PLR0915 prompt_transform=open_editor if editable else None, tool_visitors=[ToolPrinter()], reset=config.reset if opts.reset is None else opts.reset, - sync=config.sync if opts.sync is None else opts.sync, ) match accept: case Accept.MANUAL: @@ -245,10 +232,7 @@ def main() -> None: # noqa: PLR0912 PLR0915 case _: raise UnreachableError() case "finalize": - draft = drafter.finalize_draft( - delete=opts.delete, - sync=config.sync if opts.sync is None else opts.sync, - ) + draft = drafter.finalize_draft(delete=opts.delete) print(f"Finalized {draft.branch_name}.") case "show-drafts": table = drafter.history_table(args[0] if args else None) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index d21b533..7af8f9c 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -42,7 +42,6 @@ class Config: bots: Sequence[BotConfig] = dataclasses.field(default_factory=lambda: []) log_level: int = logging.INFO reset: bool = True - sync: bool = False @staticmethod def folder_path() -> Path: @@ -74,6 +73,17 @@ class BotConfig: pythonpath: str | None = None +type RepoID = str + + +@dataclasses.dataclass(frozen=True) +class RepoConfig: # TODO: Use + """Repository-specific config""" + + repo_id: str + bot_name: str | None = None + + def config_string(arg: str) -> str: """Dereferences environment value if the input starts with `$`""" return os.environ[arg[1:]] if arg and arg.startswith("$") else arg diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index b3b7401..f5828e8 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -8,17 +8,15 @@ import enum import json import logging -from pathlib import PurePosixPath +from pathlib import Path, PurePosixPath import re from re import Match -import tempfile import textwrap import time -import git - from .bots import Action, Bot, Goal from .common import JSONObject, Table, qualified_class_name, random_id +from .git import Commit, Repo from .prompt import PromptRenderer, TemplatedPrompt from .store import Store, sql from .toolbox import StagingToolbox, ToolVisitor @@ -47,22 +45,23 @@ class Draft: class _Branch: """Draft branch""" - _pattern = re.compile(r"draft/(.+)") + _pattern = re.compile(r"drafts/(.+)") - suffix: str + folio_id: str @property def name(self) -> str: - return f"draft/{self.suffix}" + return f"drafts/{self.folio_id}" def __str__(self) -> str: return self.name @classmethod - def active(cls, repo: git.Repo, name: str | None = None) -> _Branch | None: + def active(cls, repo: Repo, name: str | None = None) -> _Branch | None: match: Match | None = None - if name or not repo.head.is_detached: - match = cls._pattern.fullmatch(name or repo.active_branch.name) + active_branch = name or repo.active_branch() + if active_branch: + match = cls._pattern.fullmatch(active_branch) if not match: if name: raise ValueError(f"Not a valid draft branch name: {name!r}") @@ -77,7 +76,7 @@ def new_suffix() -> str: class Drafter: """Draft state orchestrator""" - def __init__(self, store: Store, repo: git.Repo) -> None: + def __init__(self, store: Store, repo: Repo) -> None: with store.cursor() as cursor: cursor.executescript(sql("create-tables")) self._store = store @@ -85,10 +84,7 @@ def __init__(self, store: Store, repo: git.Repo) -> None: @classmethod def create(cls, store: Store, path: str | None = None) -> Drafter: - try: - return cls(store, git.Repo(path, search_parent_directories=True)) - except git.NoSuchPathError: - raise ValueError(f"No git repository at {path}") + return cls(store, Repo.enclosing(Path(path) if path else Path.cwd())) def generate_draft( # noqa: PLR0913 self, @@ -98,25 +94,24 @@ def generate_draft( # noqa: PLR0913 bot_name: str | None = None, prompt_transform: Callable[[str], str] | None = None, reset: bool = False, - sync: bool = False, timeout: float | None = None, tool_visitors: Sequence[ToolVisitor] | None = None, ) -> Draft: if timeout is not None: raise NotImplementedError() # TODO: Implement - if self._repo.is_dirty(working_tree=False): + if self._repo.has_staged_changes(): if not reset: raise ValueError("Please commit or reset any staged changes") - self._repo.index.reset() + self._repo.git("reset") # Ensure that we are on a draft branch. branch = _Branch.active(self._repo) if branch: - self._stage_repo(sync) + self._stage_repo() _logger.debug("Reusing active branch %s.", branch) else: - branch = self._create_branch(sync) + branch = self._create_branch() _logger.debug("Created branch %s.", branch) # Handle prompt templating and editing. @@ -125,7 +120,7 @@ def generate_draft( # noqa: PLR0913 [(prompt_id,)] = cursor.execute( sql("add-prompt"), { - "branch_suffix": branch.suffix, + "branch_suffix": branch.folio_id, "template": prompt.template if isinstance(prompt, TemplatedPrompt) else None, @@ -171,7 +166,7 @@ def generate_draft( # noqa: PLR0913 if delta and accept.value >= Accept.CHECKOUT.value: delta.apply() if accept.value >= Accept.FINALIZE.value: - self.finalize_draft(delete=accept == Accept.NO_REGRETS, sync=sync) + self.finalize_draft(delete=accept == Accept.NO_REGRETS) return Draft(str(branch)) def _prepare_prompt( @@ -210,26 +205,24 @@ def _generate_change( title = action.title if not title: title = _default_title(goal.prompt) - commit = self._repo.index.commit( + commit = self._repo.create_commit( f"draft! {title}\n\n{goal.prompt}", skip_hooks=True, ) return _Change( - commit.hexsha, timedelta(seconds=walltime), action, self._repo + commit.sha, timedelta(seconds=walltime), action, self._repo ) - def finalize_draft( - self, *, delete: bool = False, sync: bool = False - ) -> Draft: + def finalize_draft(self, *, delete: bool = False) -> Draft: branch = _Branch.active(self._repo) if not branch: raise RuntimeError("Not currently on a draft branch") - self._stage_repo(sync) + self._stage_repo() with self._store.cursor() as cursor: rows = cursor.execute( - sql("get-branch-by-suffix"), {"suffix": branch.suffix} + sql("get-branch-by-suffix"), {"suffix": branch.folio_id} ) if not rows: raise RuntimeError("Unrecognized draft branch") @@ -238,25 +231,25 @@ def finalize_draft( # We do a small dance to move back to the original branch, keeping the # draft branch untouched. See https://stackoverflow.com/a/15993574 for # the inspiration. - self._repo.git.checkout(detach=True) - self._repo.git.reset("-N", origin_branch) - self._repo.git.checkout(origin_branch) + self._repo.git("checkout", "--detach") + self._repo.git("reset", "-N", origin_branch) + self._repo.git("checkout", origin_branch) if delete: - self._repo.git.branch("-D", branch.name) + self._repo.git("branch", "-D", branch.name) _logger.debug("Deleted branch %s.", branch) _logger.info("Exited %s.", branch) return Draft(branch.name) - def _create_branch(self, sync: bool) -> _Branch: - if self._repo.head.is_detached: + def _create_branch(self) -> _Branch: + if self._repo.active_branch() is None: raise RuntimeError("No currently active branch") - origin_branch = self._repo.active_branch.name - origin_sha = self._repo.commit().hexsha + origin_branch = self._repo.active_branch() + origin_sha = self._repo.head_commit().sha - self._repo.git.checkout(detach=True) - self._stage_repo(sync) + self._repo.git("checkout", "--detach") + self._stage_repo() suffix = _Branch.new_suffix() with self._store.cursor() as cursor: @@ -264,23 +257,21 @@ def _create_branch(self, sync: bool) -> _Branch: sql("add-branch"), { "suffix": suffix, - "repo_path": self._repo.working_dir, + "repo_path": str(self._repo.working_dir), "origin_branch": origin_branch, "origin_sha": origin_sha, }, ) branch = _Branch(suffix) - branch_ref = self._repo.create_head(branch.name) - self._repo.git.checkout(branch_ref) + self._repo.checkout_new_branch(branch.name) return branch - def _stage_repo(self, sync: bool) -> str | None: - self._repo.git.add(all=True) - if not sync or not self._repo.is_dirty(untracked_files=True): + def _stage_repo(self) -> Commit | None: + self._repo.git("add", "--all") + if not self._repo.has_staged_changes(): return None - ref = self._repo.index.commit("draft! sync") - return ref.hexsha + return self._repo.create_commit("draft! sync") def history_table(self, branch_name: str | None = None) -> Table: path = self._repo.working_dir @@ -290,13 +281,13 @@ def history_table(self, branch_name: str | None = None) -> Table: results = cursor.execute( sql("list-prompts"), { - "repo_path": path, - "branch_suffix": branch.suffix, + "repo_path": str(path), + "branch_suffix": branch.folio_id, }, ) else: results = cursor.execute( - sql("list-drafts"), {"repo_path": path} + sql("list-drafts"), {"repo_path": str(path)} ) return Table.from_cursor(results) @@ -309,8 +300,8 @@ def latest_draft_prompt(self) -> str | None: result = cursor.execute( sql("get-latest-prompt"), { - "repo_path": self._repo.working_dir, - "branch_suffix": branch.suffix, + "repo_path": str(self._repo.working_dir), + "branch_suffix": branch.folio_id, }, ).fetchone() return result[0] if result else None @@ -326,10 +317,10 @@ class _Change: commit: _CommitSHA walltime: timedelta action: Action - repo: git.Repo = dataclasses.field(repr=False) + repo: Repo = dataclasses.field(repr=False) def delta(self) -> _Delta | None: - diff = self.repo.git.diff_tree(self.commit, patch=True) + diff = self.repo.git("diff-tree", "--patch", self.commit).stdout return _Delta(diff, self.repo) if diff else None @@ -338,23 +329,20 @@ class _Delta: """A change's effects, guaranteed non-empty""" diff: str - repo: git.Repo = dataclasses.field(repr=False) + repo: Repo = dataclasses.field(repr=False) def apply(self) -> None: # For patch applcation to work as expected (adding conflict markers as # needed), files in the patch must exist in the index. - self.repo.git.add(all=True) - with tempfile.TemporaryFile() as temp: - temp.write(self.diff.encode("utf8")) - temp.seek(0) - try: - self.repo.git.apply("--3way", "-", istream=temp) - except git.CommandError as exc: - if "with conflicts" in exc.stderr: - raise ConflictError() - raise exc - finally: - self.repo.git.reset() + self.repo.git("add", "--all") + git = self.repo.git( + "apply", "--3way", "-", stdin=self.diff, expect_codes=() + ) + if "with conflicts" in git.stderr: + raise ConflictError() + if git.code != 0: + raise NotImplementedError() # TODO: Raise better error + self.repo.git("reset") class ConflictError(Exception): diff --git a/src/git_draft/git.py b/src/git_draft/git.py new file mode 100644 index 0000000..3fe80d0 --- /dev/null +++ b/src/git_draft/git.py @@ -0,0 +1,107 @@ +"""Git wrapper""" + +from __future__ import annotations + +from collections.abc import Sequence +import dataclasses +import logging +from pathlib import Path +import subprocess +from typing import Self + + +_logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class Commit: + """Commit newtype""" + + sha: str + + def __str__(self) -> str: + return self.sha + + +class Repo: + """Git repository""" + + def __init__(self, working_dir: Path) -> None: + self.working_dir = working_dir + + @classmethod + def enclosing(cls, path: Path) -> Self: + git = Git.run("-C", str(path), "rev-parse", "--show-toplevel") + return cls(Path(git.stdout)) + + def git( + self, + cmd: str, + *args: str, + stdin: str | None = None, + expect_codes: Sequence[int] = (0,), + ) -> Git: + return Git.run( + "-C", + str(self.working_dir), + cmd, + *args, + stdin=stdin, + expect_codes=expect_codes, + ) + + def active_branch(self) -> str | None: + return self.git("branch", "--show-current").stdout or None + + def checkout_new_branch(self, name: str) -> None: + self.git("checkout", "-b", name) + + def has_staged_changes(self) -> bool: + git = self.git("diff", "--quiet", "--staged", expect_codes=()) + return git.code != 0 + + def head_commit(self) -> Commit: + sha = self.git("rev-parse", "HEAD").stdout + return Commit(sha) + + def create_commit(self, message: str, skip_hooks: bool = False) -> Commit: + args = ["commit", "--allow-empty", "-m", message] + if skip_hooks: + args.append("--no-verify") + self.git(*args) + return self.head_commit() + + +@dataclasses.dataclass(frozen=True) +class Git: + """Git command execution result""" + + code: int + stdout: str + stderr: str + + @classmethod + def run( + cls, + *args: str, + stdin: str | None = None, + executable: str = "git", + expect_codes: Sequence[int] = (0,), + ) -> Self: + _logger.debug("Running git command. [args=%r]", args) + popen = subprocess.Popen( + [executable, *args], + encoding="utf8", + stdin=None if stdin is None else subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = popen.communicate(input=stdin) + code = popen.returncode + if expect_codes and code not in expect_codes: + raise GitError(f"Git command failed with code {code}: {stderr}") + return cls(code, stdout.rstrip(), stderr.rstrip()) + + +class GitError(Exception): + """Git command execution error""" diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index 1b10f7c..f38c957 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -8,7 +8,7 @@ import tempfile from typing import Protocol, override -import git +from .git import GitError, Repo _logger = logging.getLogger(__name__) @@ -136,7 +136,7 @@ class StagingToolbox(Toolbox): """ def __init__( - self, repo: git.Repo, visitors: Sequence[ToolVisitor] | None = None + self, repo: Repo, visitors: Sequence[ToolVisitor] | None = None ) -> None: super().__init__(visitors) self._repo = repo @@ -145,12 +145,15 @@ def __init__( @override def _list(self) -> Sequence[PurePosixPath]: # Show staged files. - return self._repo.git.ls_files().splitlines() + return [ + PurePosixPath(p) + for p in self._repo.git("ls-files").stdout.splitlines() + ] @override def _read(self, path: PurePosixPath) -> str: # Read the file from the index. - return self._repo.git.show(f":{path}") + return self._repo.git("show", f":{path}").stdout @override def _write(self, path: PurePosixPath, contents: str) -> None: @@ -160,17 +163,19 @@ def _write(self, path: PurePosixPath, contents: str) -> None: with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: temp.write(contents.encode("utf8")) temp.close() - sha = self._repo.git.hash_object("-w", temp.name, path=path) + sha = self._repo.git( + "hash-object", "-w", temp.name, "--path", str(path) + ).stdout mode = 644 # TODO: Read from original file if it exists. - self._repo.git.update_index( - f"{mode},{sha},{path}", add=True, cacheinfo=True + self._repo.git( + "update-index", "--add", "--cacheinfo", f"{mode},{sha},{path}" ) @override def _delete(self, path: PurePosixPath) -> bool: try: - self._repo.git.rm("--", str(path), cached=True) - except git.GitCommandError as err: + self._repo.git("rm", "--cached", "--", str(path)) + except GitError as err: _logger.warning("Failed to delete file. [err=%r]", err) return False else: @@ -179,12 +184,12 @@ def _delete(self, path: PurePosixPath) -> bool: def trim_index(self) -> None: """Unstage any files which have not been written to""" - diff = self._repo.git.diff(name_only=True, cached=True) + git = self._repo.git("diff", "--name-only", "--cached") untouched = [ path - for path in diff.splitlines() + for path in git.stdout.splitlines() if path and path not in self._updated ] if untouched: - self._repo.git.reset("--", *untouched) + self._repo.git("reset", "--", *untouched) _logger.debug("Trimmed index. [reset_paths=%s]", untouched) diff --git a/tests/git_draft/conftest.py b/tests/git_draft/conftest.py index 3a5f230..1841bb4 100644 --- a/tests/git_draft/conftest.py +++ b/tests/git_draft/conftest.py @@ -1,14 +1,18 @@ from collections.abc import Iterator from pathlib import Path -import git import pytest +from git_draft.git import Git, Repo + @pytest.fixture -def repo(tmp_path: Path) -> Iterator[git.Repo]: - repo = git.Repo.init(str(tmp_path / "repo"), initial_branch="main") - repo.index.commit("init") +def repo(tmp_path: Path) -> Iterator[Repo]: + path = tmp_path / "repo" + path.mkdir() + Git.run("-C", str(path), "init", "-b", "main") + repo = Repo.enclosing(path) + repo.git("commit", "-m", "init", "--allow-empty") yield repo diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 329e50b..5ac7adf 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -3,11 +3,11 @@ from pathlib import Path, PurePosixPath from typing import Self -import git import pytest from git_draft.bots import Action, Bot, Goal, Toolbox import git_draft.drafter as sut +from git_draft.git import Commit, Repo from git_draft.prompt import TemplatedPrompt from git_draft.store import Store @@ -41,7 +41,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: class TestDrafter: @pytest.fixture(autouse=True) - def setup(self, repo: git.Repo) -> None: + def setup(self, repo: Repo) -> None: self._repo = repo self._drafter = sut.Drafter(Store.in_memory(), repo) @@ -62,17 +62,18 @@ def _write(self, name: str, contents="") -> None: def _delete(self, name: str) -> None: os.remove(self._path(name)) - def _commits(self, ref: str | None = None) -> Sequence[git.Commit]: - return list(self._repo.iter_commits(rev=ref)) + def _commits(self, ref: str | None = None) -> Sequence[Commit]: + git = self._repo.git("log", "--pretty=format:%H", ref or "HEAD") + return [Commit(i) for i in git.stdout.splitlines()] def _commit_files(self, ref: str) -> frozenset[str]: - text = self._repo.git.diff_tree( - ref, no_commit_id=True, name_only=True, relative=True + git = self._repo.git( + "diff-tree", ref, "--no-commit-id", "--name-only", "--relative" ) - return frozenset(text.splitlines()) + return frozenset(git.stdout.splitlines()) def _checkout(self) -> None: - self._repo.git.checkout("--", ".") + self._repo.git("checkout", "--", ".") def test_generate_draft(self) -> None: self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"})) @@ -97,7 +98,7 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action: assert self._commit_files("HEAD") == set(["p2", "p3"]) def test_generate_outside_branch(self) -> None: - self._repo.git.checkout("--detach") + self._repo.git("checkout", "--detach") with pytest.raises(RuntimeError): self._drafter.generate_draft("ok", _SimpleBot.noop()) @@ -107,28 +108,24 @@ def test_generate_empty_prompt(self) -> None: def test_generate_dirty_index_no_reset(self) -> None: self._write("log") - self._repo.git.add(all=True) + self._repo.git("add", "--all") with pytest.raises(ValueError): self._drafter.generate_draft("hi", _SimpleBot.noop()) def test_generate_dirty_index_reset_sync(self) -> None: self._write("log", "11") - self._repo.git.add(all=True) - self._drafter.generate_draft( - "hi", _SimpleBot.prompt(), reset=True, sync=True - ) + self._repo.git("add", "--all") + self._drafter.generate_draft("hi", _SimpleBot.prompt(), reset=True) assert self._read("log") == "11" assert not self._path("PROMPT").exists() - self._repo.git.checkout(".") + self._repo.git("checkout", ".") assert self._read("PROMPT") == "hi" assert len(self._commits()) == 3 # init, sync, prompt def test_generate_clean_index_sync(self) -> None: prompt = TemplatedPrompt("add-test", {"symbol": "abc"}) - self._drafter.generate_draft( - prompt, _SimpleBot({"p1": "abc"}), sync=True - ) - self._repo.git.checkout(".") + self._drafter.generate_draft(prompt, _SimpleBot({"p1": "abc"})) + self._repo.git("checkout", ".") assert "abc" in (self._read("p1") or "") assert len(self._commits()) == 2 # sync, prompt @@ -136,20 +133,20 @@ def test_generate_reuse_branch(self) -> None: bot = _SimpleBot({"prompt": lambda goal: goal.prompt}) self._drafter.generate_draft("prompt1", bot) self._drafter.generate_draft("prompt2", bot) - self._repo.git.checkout(".") + self._repo.git("checkout", ".") assert self._read("prompt") == "prompt2" - assert len(self._commits()) == 3 # init, prompt, prompt + assert len(self._commits()) == 4 # init, sync, prompt, prompt def test_generate_reuse_branch_sync(self) -> None: bot = _SimpleBot({"p1": "A"}) self._drafter.generate_draft("prompt1", bot) - self._drafter.generate_draft("prompt2", bot, sync=True) + self._drafter.generate_draft("prompt2", bot) assert len(self._commits()) == 4 # init, prompt, sync, prompt def test_generate_noop(self) -> None: self._write("unrelated", "a") self._drafter.generate_draft("prompt", _SimpleBot.noop()) - assert len(self._commits()) == 2 # init, prompt + assert len(self._commits()) == 3 # init, sync, prompt assert not self._commit_files("HEAD") def test_generate_accept_checkout(self) -> None: @@ -160,20 +157,20 @@ def test_generate_accept_checkout(self) -> None: "hello", _SimpleBot({"p1": "C", "p3": "D", "p4": None}), accept=sut.Accept.CHECKOUT, - sync=True, ) assert self._read("p1") == "C" assert self._read("p2") == "B" assert self._read("p3") == "D" assert self._read("p4") is None + @pytest.mark.skip(reason="conflict resolution in flux") def test_generate_accept_checkout_conflict(self) -> None: self._write("p1", "A") with pytest.raises(sut.ConflictError): self._drafter.generate_draft( "hello", _SimpleBot({"p1": "B", "p2": "C"}), - accept=sut.Accept.CHECKOUT + accept=sut.Accept.CHECKOUT, ) assert """<<<<<<< ours\nA""" in (self._read("p1") or "") assert self._read("p2") == "C" @@ -187,7 +184,7 @@ def test_generate_accept_finalize(self) -> None: ) assert self._read("p1") == "A" assert self._read("p2") == "B" - assert self._repo.active_branch.name == "main" + assert self._repo.active_branch() == "main" def test_delete_unknown_file(self) -> None: self._drafter.generate_draft("hello", _SimpleBot({"p1": None})) @@ -208,11 +205,16 @@ def test_finalize_and_sync(self) -> None: accept=sut.Accept.CHECKOUT, ) self._write("PROMPT", "a2") - self._drafter.finalize_draft(sync=True) + self._drafter.finalize_draft() assert self._read("PROMPT") == "a2" commits = self._commits(draft.branch_name) assert len(commits) == 3 # init, prompt, sync - assert "sync" in commits[0].message + assert ( + "sync" + in self._repo.git( + "log", "--format=%B", "-n1", commits[0].sha + ).stdout + ) def test_history_table_empty(self) -> None: table = self._drafter.history_table() diff --git a/tests/git_draft/toolbox_test.py b/tests/git_draft/toolbox_test.py index 83b1aa4..ada65fc 100644 --- a/tests/git_draft/toolbox_test.py +++ b/tests/git_draft/toolbox_test.py @@ -1,51 +1,51 @@ from pathlib import Path, PurePosixPath -import git import pytest +from git_draft.git import GitError, Repo import git_draft.toolbox as sut class TestStagingToolbox: @pytest.fixture(autouse=True) - def setup(self, repo: git.Repo) -> None: + def setup(self, repo: Repo) -> None: self._toolbox = sut.StagingToolbox(repo) - def test_list_files(self, repo: git.Repo) -> None: + def test_list_files(self, repo: Repo) -> None: assert self._toolbox.list_files() == [] names = set(["one.txt", "two.txt"]) for name in names: with Path(repo.working_dir, name).open("w") as f: f.write("ok") - repo.git.add(all=True) - assert set(self._toolbox.list_files()) == names + repo.git("add", "--all") + assert set(str(p) for p in self._toolbox.list_files()) == names - def test_read_file(self, repo: git.Repo) -> None: + def test_read_file(self, repo: Repo) -> None: with Path(repo.working_dir, "one").open("w") as f: f.write("ok") path = PurePosixPath("one") - with pytest.raises(git.GitCommandError): - assert self._toolbox.read_file(path) == "" + with pytest.raises(GitError): + self._toolbox.read_file(path) - repo.git.add(all=True) + repo.git("add", "--all") assert self._toolbox.read_file(path) == "ok" - def test_write_file(self, repo: git.Repo) -> None: + def test_write_file(self, repo: Repo) -> None: self._toolbox.write_file(PurePosixPath("one"), "hi") path = Path(repo.working_dir, "one") assert not path.exists() - repo.git.checkout_index(all=True) + repo.git("checkout-index", "--all") with path.open() as f: assert f.read() == "hi" - def test_rename_file(self, repo: git.Repo) -> None: + def test_rename_file(self, repo: Repo) -> None: self._toolbox.write_file(PurePosixPath("one"), "hi") self._toolbox.rename_file(PurePosixPath("one"), PurePosixPath("two")) - repo.git.checkout_index(all=True) + repo.git("checkout-index", "--all") assert not Path(repo.working_dir, "one").exists() with Path(repo.working_dir, "two").open() as f: assert f.read() == "hi"