diff --git a/docs/git-draft.1.adoc b/docs/git-draft.1.adoc index a0470d8..9b2d5d4 100644 --- a/docs/git-draft.1.adoc +++ b/docs/git-draft.1.adoc @@ -21,6 +21,7 @@ IMPORTANT: `git-draft` is WIP. git draft [options] [--new] [--accept... | --no-accept] [--bot BOT] [--edit] [TEMPLATE [VARIABLE...] | -] git draft [options] --quit +git draft [options] --events [REF] git draft [options] --templates [--json | [--edit] TEMPLATE] @@ -123,7 +124,7 @@ o draft! sync(prompt) o (main, draft/123) ---- -If merging is enabled, it have both the LLM-generated changes and manual edits as parents. +If merging is enabled, the merge commit will have both the LLM-generated changes and manual edits as parents. [source] ---- diff --git a/pyproject.toml b/pyproject.toml index 406fd46..1d0f1cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,9 @@ lines-after-imports = 2 [tool.ruff.lint.pydocstyle] convention = "google" +[tool.ruff.lint.pylint] +max-returns = 20 + [tool.ruff.lint.per-file-ignores] "__main__.py" = ["T20"] "tests/**" = ["ANN", "D", "SLF"] diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 08dd443..d6396b4 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -72,6 +72,7 @@ def callback( add_command("new", help="create a new draft from a prompt") add_command("quit", help="return to original branch") + add_command("events", help="list events") add_command("templates", short="T", help="show template information") parser.add_option( @@ -214,6 +215,10 @@ async def run() -> None: # noqa: PLR0912 PLR0915 drafter.quit_folio() case "quit": drafter.quit_folio() + case "events": + draft_id = args[0] if args else None + for elem in drafter.list_draft_events(draft_id): + print(elem) case "templates": if args: name = args[0] diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 1603281..45f7f71 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence import dataclasses -import datetime +from datetime import datetime import itertools import logging import os @@ -103,13 +103,22 @@ def reindent(s: str, prefix: str = "", width: int = 0) -> str: ) +def tagged(text: str, /, **kwargs) -> str: + if kwargs: + tags = [ + f"{key}={val}" for key, val in kwargs.items() if val is not None + ] + text = f"{text} [{', '.join(tags)}]" if tags else text + return reindent(text) + + def qualified_class_name(cls: type) -> str: name = cls.__qualname__ return f"{cls.__module__}.{name}" if cls.__module__ else name -def now() -> datetime.datetime: - return datetime.datetime.now().astimezone() +def now() -> datetime: + return datetime.now().astimezone() class Table: diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 8c76fa2..62fe134 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -2,9 +2,9 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence import dataclasses -from datetime import timedelta +from datetime import datetime, timedelta import logging import re import textwrap @@ -12,10 +12,17 @@ from typing import Literal from .bots import ActionSummary, Bot, Goal -from .common import qualified_class_name, reindent +from .common import ( + UnreachableError, + now, + qualified_class_name, + reindent, + tagged, +) from .events import ( Event, EventConsumer, + event_decoders, event_encoder, feedback_events, worktree_events, @@ -46,8 +53,17 @@ def ref(self) -> str: return _draft_ref(self.folio.id, self.seqno) +_DRAFT_REF_PREFIX = "refs/drafts/" + + def _draft_ref(folio_id: int, suffix: int | str) -> str: - return f"refs/drafts/{folio_id}/{suffix}" + return f"{_DRAFT_REF_PREFIX}{folio_id}/{suffix}" + + +def _parse_draft_ref(ref: str) -> tuple[int, int | None]: + ref = ref.removeprefix(_DRAFT_REF_PREFIX) + parts = ref.split("/") + return int(parts[0]), int(parts[1]) if len(parts) > 1 else None _FOLIO_BRANCH_NAMESPACE = "draft" @@ -70,7 +86,7 @@ def upstream_branch_name(self) -> str: return self.branch_name() + _FOLIO_UPSTREAM_BRANCH_SUFFIX -def _active_folio(repo: Repo) -> Folio | None: +def _maybe_active_folio(repo: Repo) -> Folio | None: active_branch = repo.active_branch() if not active_branch: return None @@ -80,6 +96,13 @@ def _active_folio(repo: Repo) -> Folio | None: return Folio(int(match[1])) +def _active_folio(repo: Repo) -> Folio: + folio = _maybe_active_folio(repo) + if not folio: + raise RuntimeError("Not currently on a draft branch") + return folio + + #: Select ort strategies. DraftMergeStrategy = Literal[ "ours", @@ -133,7 +156,7 @@ async def generate_draft( ) # Ensure that we are in a folio. - folio = _active_folio(self._repo) + folio = _maybe_active_folio(self._repo) if not folio: folio = self._create_folio() with self._store.cursor() as cursor: @@ -149,7 +172,7 @@ async def generate_draft( # Run the bot to generate the change. event_recorder = _EventRecorder(self._progress) with self._progress.spinner("Running bot...") as spinner: - feedback = spinner.feedback() + feedback = spinner.feedback(event_recorder) change = await self._generate_change( bot, Goal(prompt_contents), @@ -206,11 +229,11 @@ async def generate_draft( [ { "prompt_id": prompt_id, - "occurred_at": e.at, + "occurred_at": dt, "class": e.__class__.__name__, "data": encoder.encode(e), } - for e in event_recorder.events + for (dt, e) in event_recorder.events() ], ) spinner.update("Created draft commit.", ref=draft.ref) @@ -244,9 +267,6 @@ async def generate_draft( def quit_folio(self) -> None: folio = _active_folio(self._repo) - if not folio: - raise RuntimeError("Not currently on a draft branch") - with self._store.cursor() as cursor: rows = cursor.execute(sql("get-folio-by-id"), {"id": folio.id}) if not rows: @@ -404,7 +424,7 @@ def _commit_tree( def latest_draft_prompt(self) -> str | None: """Returns the latest prompt for the current draft""" - folio = _active_folio(self._repo) + folio = _maybe_active_folio(self._repo) if not folio: return None with self._store.cursor() as cursor: @@ -422,6 +442,27 @@ def latest_draft_prompt(self) -> str | None: prompt = "\n\n".join([prompt, reindent(question, prefix="> ")]) return prompt + def list_draft_events(self, draft_ref: str | None = None) -> Sequence[str]: + if draft_ref: + folio_id, seqno = _parse_draft_ref(draft_ref) + else: + folio = _active_folio(self._repo) + folio_id = folio.id + seqno = None + elems = [] + with self._store.cursor() as cursor: + rows = cursor.execute( + sql("list-action-events"), + {"folio_id": folio_id, "seqno": seqno}, + ) + decoders = event_decoders() + for row in rows: + occurred_at, class_name, data = row + event = decoders[class_name].decode(data) + description = _format_event(event) + elems.append(f"{occurred_at}\t{class_name}\t{description}") + return elems + @dataclasses.dataclass(frozen=True) class _Change: @@ -442,34 +483,50 @@ class _EventRecorder(EventConsumer): """ def __init__(self, progress: Progress) -> None: - self.events = list[Event]() + self._events = list[tuple[datetime, Event]]() self._progress = progress + def events(self) -> Sequence[tuple[datetime, Event]]: + return sorted(list(self._events)) + def on_event(self, event: Event) -> None: - self.events.append(event) - match event: - case worktree_events.ListFiles(_, paths): - self._progress.report("Listed files.", count=len(paths)) - case worktree_events.ReadFile(_, path, contents): - size = -1 if contents is None else len(contents) - self._progress.report(f"Read {path}.", length=size) - case worktree_events.WriteFile(_, path, contents): - size = len(contents) - self._progress.report(f"Wrote {path}.", length=size) - case worktree_events.DeleteFile(_, path): - self._progress.report(f"Deleted {path}.") - case worktree_events.RenameFile(_, src_path, dst_path): - self._progress.report(f"Renamed {src_path} to {dst_path}.") - case worktree_events.StartEditingFiles(_): - self._progress.report("Started editing files...") - case worktree_events.StopEditingFiles(_): - self._progress.report("Stopped editing files.") - case ( - feedback_events.NotifyUser(_, _) - | feedback_events.RequestUserGuidance(_, _) - | feedback_events.ReceiveUserGuidance(_, _) - ): - pass + self._events.append((now(), event)) + if formatted := _format_internal_event(event): + self._progress.report(formatted) + + +def _format_internal_event(event: Event) -> str: + match event: + case worktree_events.ListFiles(path_count): + return f"Listed {path_count} files." + case worktree_events.ReadFile(path, char_count): + return tagged(f"Read {path}.", length=char_count) + case worktree_events.WriteFile(path, char_count): + return tagged(f"Wrote {path}.", length=char_count) + case worktree_events.DeleteFile(path): + return f"Deleted {path}." + case worktree_events.RenameFile(src_path, dst_path): + return f"Renamed {src_path} to {dst_path}." + case worktree_events.StartEditingFiles(): + return "Started editing files..." + case worktree_events.StopEditingFiles(): + return "Stopped editing files." + case _: + return "" + + +def _format_event(event: Event) -> str: + if formatted := _format_internal_event(event): + return formatted + match event: + case feedback_events.NotifyUser(update): + return update + case feedback_events.RequestUserGuidance(question): + return question + case feedback_events.ReceiveUserGuidance(answer): + return answer + case _: + raise UnreachableError() def _default_title(prompt: str) -> str: diff --git a/src/git_draft/events/__init__.py b/src/git_draft/events/__init__.py index 4b5b2d5..2db1959 100644 --- a/src/git_draft/events/__init__.py +++ b/src/git_draft/events/__init__.py @@ -1,20 +1,21 @@ -"""Event package""" +"""Event definitions and (de)serializers""" +import collections +from collections.abc import Mapping from pathlib import PurePosixPath from typing import Any, Protocol import msgspec from . import feedback_events, worktree_events -from .common import events +from .common import all_events __all__ = [ "Event", "EventConsumer", - "event_decoder", + "event_decoders", "event_encoder", - "events", "feedback_events", "worktree_events", ] @@ -42,6 +43,7 @@ def on_event(self, event: Event) -> None: def event_encoder() -> msgspec.json.Encoder: + """Returns a JSON encoder for event instances""" return msgspec.json.Encoder(enc_hook=_enc_hook) @@ -50,14 +52,15 @@ def _enc_hook(obj: Any) -> Any: return str(obj) -def event_decoder() -> msgspec.json.Decoder: - """Returns a decoder for event instances +def event_decoders() -> Mapping[str, msgspec.json.Decoder]: + """Returns JSON decoders for event instances, keyed by event class name""" + return _Decoders() - It should be used as follows to get typed values: - decoder.decode(data, type=events[class_name]) - """ - return msgspec.json.Decoder(dec_hook=_dec_hook) +class _Decoders(collections.defaultdict[str, msgspec.json.Decoder]): + def __missing__(self, key: str) -> msgspec.json.Decoder: + event_class = getattr(all_events, key) + return msgspec.json.Decoder(dec_hook=_dec_hook, type=event_class) def _dec_hook(tp: type, obj: Any) -> Any: diff --git a/src/git_draft/events/common.py b/src/git_draft/events/common.py index e376d8f..ff49e1b 100644 --- a/src/git_draft/events/common.py +++ b/src/git_draft/events/common.py @@ -1,20 +1,19 @@ """Common event utilities""" -import datetime import types -from typing import Any +from typing import Any, dataclass_transform import msgspec -events = types.SimpleNamespace() +all_events = types.SimpleNamespace() +# https://discuss.python.org/t/cannot-inherit-non-frozen-dataclass-from-a-frozen-one/79273 +@dataclass_transform(field_specifiers=(msgspec.field,), frozen_default=True) class EventStruct(msgspec.Struct, frozen=True): """Base immutable structure for all event types""" - at: datetime.datetime - def __init_subclass__(cls, *args: Any, **kwargs) -> None: super().__init_subclass__(*args, **kwargs) - setattr(events, cls.__name__, cls) + setattr(all_events, cls.__name__, cls) diff --git a/src/git_draft/events/feedback_events.py b/src/git_draft/events/feedback_events.py index 2cf1247..292ef5a 100644 --- a/src/git_draft/events/feedback_events.py +++ b/src/git_draft/events/feedback_events.py @@ -3,19 +3,19 @@ from .common import EventStruct -class NotifyUser(EventStruct, frozen=True): +class NotifyUser(EventStruct): """Generic user notification""" - contents: str + update: str -class RequestUserGuidance(EventStruct, frozen=True): +class RequestUserGuidance(EventStruct): """Additional information is requested from the user""" question: str -class ReceiveUserGuidance(EventStruct, frozen=True): +class ReceiveUserGuidance(EventStruct): """Response provided by the user""" answer: str diff --git a/src/git_draft/events/worktree_events.py b/src/git_draft/events/worktree_events.py index 61486ef..b551e76 100644 --- a/src/git_draft/events/worktree_events.py +++ b/src/git_draft/events/worktree_events.py @@ -1,47 +1,46 @@ """Event types related to worktree file operations""" -from collections.abc import Sequence from pathlib import PurePosixPath from .common import EventStruct -class ListFiles(EventStruct, frozen=True): +class ListFiles(EventStruct): """All files were listed""" - paths: Sequence[PurePosixPath] + path_count: int -class ReadFile(EventStruct, frozen=True): +class ReadFile(EventStruct): """A file was read""" path: PurePosixPath - contents: str | None + char_count: int | None -class WriteFile(EventStruct, frozen=True): +class WriteFile(EventStruct): """A file was written""" path: PurePosixPath - contents: str + char_count: int -class DeleteFile(EventStruct, frozen=True): +class DeleteFile(EventStruct): """A file was deleted""" path: PurePosixPath -class RenameFile(EventStruct, frozen=True): +class RenameFile(EventStruct): """A file was renamed""" src_path: PurePosixPath dst_path: PurePosixPath -class StartEditingFiles(EventStruct, frozen=True): +class StartEditingFiles(EventStruct): """A temporary editable copy of all files was opened""" -class StopEditingFiles(EventStruct, frozen=True): +class StopEditingFiles(EventStruct): """The editable copy was closed""" diff --git a/src/git_draft/progress.py b/src/git_draft/progress.py index da469ef..42db6aa 100644 --- a/src/git_draft/progress.py +++ b/src/git_draft/progress.py @@ -9,7 +9,8 @@ import yaspin.core from .bots import UserFeedback -from .common import reindent +from .common import reindent, tagged +from .events import EventConsumer, feedback_events class Progress: @@ -46,16 +47,43 @@ def hidden(self) -> Iterator[None]: def update(self, text: str, **tags) -> None: # pragma: no cover raise NotImplementedError() - def feedback(self) -> ProgressFeedback: + def feedback(self, event_consumer: EventConsumer) -> ProgressFeedback: raise NotImplementedError() class ProgressFeedback(UserFeedback): """User feedback interface""" - def __init__(self) -> None: + def __init__(self, event_consumer: EventConsumer) -> None: + self._event_consumer = event_consumer self.pending_question: str | None = None + @override + def notify(self, update: str) -> None: + self._event_consumer.on_event(feedback_events.NotifyUser(update)) + self._notify(update) + + def _notify(self, update: str) -> None: + raise NotImplementedError() + + @override + def ask(self, question: str) -> str: + assert not self.pending_question + self._event_consumer.on_event( + feedback_events.RequestUserGuidance(question) + ) + answer = self._ask(question) + if answer is None: + self.pending_question = question + answer = _offline_answer + self._event_consumer.on_event( + feedback_events.ReceiveUserGuidance(answer) + ) + return _offline_answer + + def _ask(self, question: str) -> str | None: + raise NotImplementedError() + _offline_answer = reindent(""" I'm unable to provide feedback at this time. Perform any final changes and @@ -68,7 +96,7 @@ def __init__(self) -> None: self._spinner: _DynamicProgressSpinner | None = None def report(self, text: str, **tags) -> None: - message = f"☞ {_tagged(text, **tags)}" + message = f"☞ {tagged(text, **tags)}" if self._spinner: self._spinner.yaspin.write(message) else: @@ -77,7 +105,7 @@ def report(self, text: str, **tags) -> None: @contextlib.contextmanager def spinner(self, text: str, **tags) -> Iterator[ProgressSpinner]: assert not self._spinner - with yaspin.yaspin(text=_tagged(text, **tags)) as spinner: + with yaspin.yaspin(text=tagged(text, **tags)) as spinner: self._spinner = _DynamicProgressSpinner(spinner) try: yield self._spinner @@ -100,35 +128,35 @@ def hidden(self) -> Iterator[None]: yield def update(self, text: str, **tags) -> None: - self.yaspin.text = _tagged(text, **tags) + self.yaspin.text = tagged(text, **tags) - def feedback(self) -> ProgressFeedback: - return _DynamicProgressFeedback(self) + def feedback(self, event_consumer: EventConsumer) -> ProgressFeedback: + return _DynamicProgressFeedback(event_consumer, self) class _DynamicProgressFeedback(ProgressFeedback): - def __init__(self, spinner: _DynamicProgressSpinner) -> None: - super().__init__() + def __init__( + self, + event_consumer: EventConsumer, + spinner: _DynamicProgressSpinner, + ) -> None: + super().__init__(event_consumer) self._spinner = spinner @override - def notify(self, update: str) -> None: + def _notify(self, update: str) -> None: self._spinner.update(update) @override - def ask(self, question: str) -> str: - assert not self.pending_question + def _ask(self, question: str) -> str | None: with self._spinner.hidden(): - answer = input(question) - if answer: - return answer - self.pending_question = question - return _offline_answer + answer = input(question + " ") + return answer or None class _StaticProgress(Progress): def report(self, text: str, **tags) -> None: - print(_tagged(text, **tags)) # noqa + print(tagged(text, **tags)) # noqa @contextlib.contextmanager def spinner(self, text: str, **tags) -> Iterator[ProgressSpinner]: @@ -143,31 +171,24 @@ def __init__(self, progress: _StaticProgress) -> None: def update(self, text: str, **tags) -> None: self._progress.report(text, **tags) - def feedback(self) -> ProgressFeedback: - return _StaticProgressFeedback(self._progress) + def feedback(self, event_consumer: EventConsumer) -> ProgressFeedback: + return _StaticProgressFeedback(event_consumer, self._progress) class _StaticProgressFeedback(ProgressFeedback): - def __init__(self, progress: _StaticProgress) -> None: - super().__init__() + def __init__( + self, + event_consumer: EventConsumer, + progress: _StaticProgress, + ) -> None: + super().__init__(event_consumer) self._progress = progress @override - def notify(self, update: str) -> None: + def _notify(self, update: str) -> None: self._progress.report(update) @override - def ask(self, question: str) -> str: - assert not self.pending_question + def _ask(self, question: str) -> str | None: self._progress.report(f"Feedback requested: {question}") - self.pending_question = question return _offline_answer - - -def _tagged(text: str, /, **kwargs) -> str: - if kwargs: - tags = [ - f"{key}={val}" for key, val in kwargs.items() if val is not None - ] - text = f"{text} [{', '.join(tags)}]" if tags else text - return reindent(text) diff --git a/src/git_draft/queries/list-action-events.sql b/src/git_draft/queries/list-action-events.sql new file mode 100644 index 0000000..66db45e --- /dev/null +++ b/src/git_draft/queries/list-action-events.sql @@ -0,0 +1,7 @@ +select occurred_at, class, data + from action_events as e + join prompts as p on e.prompt_id = p.id + where + p.folio_id = :folio_id and + p.seqno = coalesce(:seqno, (select max(seqno) from prompts where folio_id = :folio_id)) + order by occurred_at; diff --git a/src/git_draft/worktrees.py b/src/git_draft/worktrees.py index 5272e99..849bd36 100644 --- a/src/git_draft/worktrees.py +++ b/src/git_draft/worktrees.py @@ -10,7 +10,7 @@ from typing import Self, override from .bots import Worktree -from .common import UnreachableError, now +from .common import UnreachableError from .events import Event, EventConsumer, worktree_events from .git import SHA, GitError, Repo, null_delimited @@ -124,7 +124,7 @@ def _dispatch(self, event: Event) -> None: @override def list_files(self) -> Sequence[PurePosixPath]: paths = self._list() - self._dispatch(worktree_events.ListFiles(now(), paths)) + self._dispatch(worktree_events.ListFiles(len(paths))) return paths @override @@ -133,17 +133,19 @@ def read_file(self, path: PurePosixPath) -> str | None: contents = self._read(path) except FileNotFoundError: contents = None - self._dispatch(worktree_events.ReadFile(now(), path, contents)) + self._dispatch( + worktree_events.ReadFile(path, len(contents) if contents else None) + ) return contents @override def write_file(self, path: PurePosixPath, contents: str) -> None: - self._dispatch(worktree_events.WriteFile(now(), path, contents)) + self._dispatch(worktree_events.WriteFile(path, len(contents))) return self._write(path, contents) @override def delete_file(self, path: PurePosixPath) -> None: - self._dispatch(worktree_events.DeleteFile(now(), path)) + self._dispatch(worktree_events.DeleteFile(path)) self._delete(path) @override @@ -153,7 +155,7 @@ def rename_file( dst_path: PurePosixPath, ) -> None: """Rename a single file""" - self._dispatch(worktree_events.RenameFile(now(), src_path, dst_path)) + self._dispatch(worktree_events.RenameFile(src_path, dst_path)) contents = self._read(src_path) self._write(dst_path, contents) self._delete(src_path) @@ -166,11 +168,11 @@ def edit_files(self) -> Iterator[Path]: All updates are synced back afterwards. Other operations should not be performed concurrently as they may be stale or lost. """ - self._dispatch(worktree_events.StartEditingFiles(now())) + self._dispatch(worktree_events.StartEditingFiles()) with self._edit() as path: yield path # TODO: Expose updated files to hook? - self._dispatch(worktree_events.StopEditingFiles(now())) + self._dispatch(worktree_events.StopEditingFiles()) def _list(self) -> Sequence[PurePosixPath]: call = self._repo.git("ls-tree", "-rz", "--name-only", self.sha()) @@ -245,18 +247,22 @@ def _update_tree(sha: SHA, updates: Sequence[_Update], repo: Repo) -> SHA: return sha blob_shas = collections.defaultdict[PurePosixPath, dict[str, str]](dict) + trees = collections.defaultdict[PurePosixPath, set[str]](set) for update in updates: match update: case _WriteBlob(path, blob_sha): blob_shas[path.parent][path.name] = blob_sha + for parent in path.parents[:-1]: + trees[parent.parent].add(parent.name) case _DeleteBlob(path): blob_shas[path.parent][path.name] = "" case _: raise UnreachableError(f"Unexpected update: {update}") - def visit_tree(sha: SHA, path: PurePosixPath) -> SHA: + def visit_old_tree(sha: SHA, path: PurePosixPath) -> SHA: old_lines = null_delimited(repo.git("ls-tree", "-z", sha).stdout) - new_blob_shas = blob_shas[path] + new_blob_shas = blob_shas.pop(path, dict()) + new_trees = trees.pop(path, set()) new_lines = list[str]() for line in old_lines: @@ -264,17 +270,23 @@ def visit_tree(sha: SHA, path: PurePosixPath) -> SHA: mode, otype, old_sha = old_prefix.split(" ") match otype: case "blob": + if name in new_trees: + raise RuntimeError(f"Not a folder: {path / name}") new_sha = new_blob_shas.pop(name, old_sha) if new_sha: new_lines.append(f"{mode} blob {new_sha}\t{name}") case "tree": - new_sha = visit_tree(old_sha, path / name) + new_trees.discard(name) + new_sha = visit_old_tree(old_sha, path / name) new_lines.append(f"040000 tree {new_sha}\t{name}") case "commit": # Submodule new_lines.append(line) case _: raise UnreachableError(f"Unexpected line: {line}") + for name in new_trees: + sha = visit_new_tree(path / name) + new_lines.append(f"040000 tree {sha}\t{name}") for name, blob_sha in new_blob_shas.items(): if blob_sha: new_lines.append(f"100644 blob {blob_sha}\t{name}") @@ -286,4 +298,16 @@ def visit_tree(sha: SHA, path: PurePosixPath) -> SHA: return repo.git("mktree", "-z", stdin="\x00".join(new_lines)).stdout - return visit_tree(sha, PurePosixPath(".")) + def visit_new_tree(path: PurePosixPath) -> SHA: + lines = list[str]() + for name in trees.pop(path, set()): + tree_sha = visit_new_tree(path / name) + lines.append(f"040000 tree {tree_sha}\t{name}") + for name, blob_sha in blob_shas.pop(path, dict()).items(): + lines.append(f"100644 blob {blob_sha}\t{name}") + return repo.git("mktree", "-z", stdin="\x00".join(lines)).stdout + + new_sha = visit_old_tree(sha, PurePosixPath(".")) + assert not blob_shas, "unprocessed blobs" + assert not trees, "unprocessed trees" + return new_sha diff --git a/tests/git_draft/conftest.py b/tests/git_draft/conftest.py index 1229f48..30bbf22 100644 --- a/tests/git_draft/conftest.py +++ b/tests/git_draft/conftest.py @@ -38,7 +38,9 @@ def read(self, name: str) -> str | None: return None def write(self, name: str, contents="") -> None: - with open(self.path(name), "w") as f: + path = self.path(name) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: f.write(contents) def delete(self, name: str) -> None: @@ -46,7 +48,7 @@ def delete(self, name: str) -> None: def flush(self, message: str = "flush") -> str: self._repo.git("add", "-A") - self._repo.git("commit", "-m", message) + self._repo.git("commit", "--allow-empty", "-m", message) return self._repo.git("rev-parse", "HEAD").stdout diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 9b1dea0..a03e78c 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -83,12 +83,14 @@ async def test_generate_draft_merge(self) -> None: self._fs.write("p1", "a") await self._drafter.generate_draft( - "hello", _SimpleBot({"p2": "b"}), merge_strategy="ignore-all-space" + "hello", + _SimpleBot({"a/p2": "b"}), + merge_strategy="ignore-all-space", ) # No sync(merge) commit since no changes happened between. assert len(self._commits()) == 4 # init, sync(prompt), prompt, merge assert self._fs.read("p1") == "a" - assert self._fs.read("p2") == "b" + assert self._fs.read("a/p2") == "b" @pytest.mark.asyncio async def test_generate_draft_merge_no_conflict(self) -> None: @@ -185,3 +187,10 @@ async def test_latest_draft_prompt(self) -> None: @pytest.mark.asyncio async def test_latest_draft_prompt_no_active_branch(self) -> None: assert self._drafter.latest_draft_prompt() is None + + @pytest.mark.asyncio + async def test_list_draft_events(self) -> None: + bot = _SimpleBot({"prompt": lambda goal: goal.prompt}) + await self._drafter.generate_draft("prompt1", bot, "theirs") + elems = self._drafter.list_draft_events() + assert len(elems) == 1 diff --git a/tests/git_draft/events/__init__.py b/tests/git_draft/events/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/git_draft/events/__init___test.py b/tests/git_draft/events/__init___test.py new file mode 100644 index 0000000..84c49d7 --- /dev/null +++ b/tests/git_draft/events/__init___test.py @@ -0,0 +1,35 @@ +from pathlib import PurePosixPath + +from msgspec import json +import pytest + +import git_draft.events as sut + + +class TestEventEncoder: + @pytest.fixture + def encoder(self): + return sut.event_encoder() + + def test_encode_event(self, encoder): + path = PurePosixPath("/some/path") + event = sut.worktree_events.DeleteFile(path) + result = encoder.encode(event) + assert result == json.encode({"path": str(path)}) + + +class TestEventDecoders: + @pytest.fixture + def decoders(self): + return sut.event_decoders() + + def test_decoder_for_known_event(self, decoders): + decoder = decoders["DeleteFile"] + path_str = "/some/path" + event = decoder.decode(json.encode({"path": path_str})) + assert isinstance(event, sut.worktree_events.DeleteFile) + assert event.path == PurePosixPath(path_str) + + def test_decoder_for_unknown_event_raises_keyerror(self, decoders): + with pytest.raises(AttributeError): + _ = decoders["NonExistentEvent"] diff --git a/tests/git_draft/work_trees_test.py b/tests/git_draft/worktrees_test.py similarity index 78% rename from tests/git_draft/work_trees_test.py rename to tests/git_draft/worktrees_test.py index be46889..21694c8 100644 --- a/tests/git_draft/work_trees_test.py +++ b/tests/git_draft/worktrees_test.py @@ -54,6 +54,27 @@ def test_write_file(self) -> None: assert self._fs.read("f1") == "aa" assert self._fs.read("f3") is None + def test_write_file_in_new_folder(self) -> None: + self._fs.write("d1/f1", "a") + sha = self._fs.flush() + + tree = sut.GitWorktree(self._repo, sha) + tree.write_file(PPP("d1/f2"), "b") # In existing directory + tree.write_file(PPP("d1/d2/f3"), "c") # In new directory + tree.write_file(PPP("d1/d2/d3/f4"), "d") # In new directory + assert tree.read_file(PPP("d1/f2")) == "b" + assert tree.read_file(PPP("d1/d2/f3")) == "c" + assert tree.read_file(PPP("d1/d2/d3/f4")) == "d" + + def test_write_folder_conflict(self) -> None: + self._fs.write("f1", "a") + sha = self._fs.flush() + + tree = sut.GitWorktree(self._repo, sha) + tree.write_file(PPP("f1/f2"), "b") + with pytest.raises(RuntimeError): + _ = tree.sha() + def test_for_working_dir_dirty(self) -> None: self._fs.write("f1", "a") self._fs.write("f2", "b")