Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/git-draft.1.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ IMPORTANT: `git-draft` is WIP.
git draft [options] [--new] [--accept... | --no-accept] [--bot BOT]
[--edit] [TEMPLATE [VARIABLE...] | -]
git draft [options] --quit
git draft [options] --events [REF]
git draft [options] --templates [--json | [--edit] TEMPLATE]


Expand Down Expand Up @@ -123,7 +124,7 @@ o draft! sync(prompt)
o <some commit> (main, draft/123)
----

If merging is enabled, it have both the LLM-generated changes and manual edits as parents.
If merging is enabled, the merge commit will have both the LLM-generated changes and manual edits as parents.

[source]
----
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ lines-after-imports = 2
[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.pylint]
max-returns = 20

[tool.ruff.lint.per-file-ignores]
"__main__.py" = ["T20"]
"tests/**" = ["ANN", "D", "SLF"]
5 changes: 5 additions & 0 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def callback(

add_command("new", help="create a new draft from a prompt")
add_command("quit", help="return to original branch")
add_command("events", help="list events")
add_command("templates", short="T", help="show template information")

parser.add_option(
Expand Down Expand Up @@ -214,6 +215,10 @@ async def run() -> None: # noqa: PLR0912 PLR0915
drafter.quit_folio()
case "quit":
drafter.quit_folio()
case "events":
draft_id = args[0] if args else None
for elem in drafter.list_draft_events(draft_id):
print(elem)
case "templates":
if args:
name = args[0]
Expand Down
15 changes: 12 additions & 3 deletions src/git_draft/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections.abc import Mapping, Sequence
import dataclasses
import datetime
from datetime import datetime
import itertools
import logging
import os
Expand Down Expand Up @@ -103,13 +103,22 @@ def reindent(s: str, prefix: str = "", width: int = 0) -> str:
)


def tagged(text: str, /, **kwargs) -> str:
if kwargs:
tags = [
f"{key}={val}" for key, val in kwargs.items() if val is not None
]
text = f"{text} [{', '.join(tags)}]" if tags else text
return reindent(text)


def qualified_class_name(cls: type) -> str:
name = cls.__qualname__
return f"{cls.__module__}.{name}" if cls.__module__ else name


def now() -> datetime.datetime:
return datetime.datetime.now().astimezone()
def now() -> datetime:
return datetime.now().astimezone()


class Table:
Expand Down
133 changes: 95 additions & 38 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Sequence
import dataclasses
from datetime import timedelta
from datetime import datetime, timedelta
import logging
import re
import textwrap
import time
from typing import Literal

from .bots import ActionSummary, Bot, Goal
from .common import qualified_class_name, reindent
from .common import (
UnreachableError,
now,
qualified_class_name,
reindent,
tagged,
)
from .events import (
Event,
EventConsumer,
event_decoders,
event_encoder,
feedback_events,
worktree_events,
Expand Down Expand Up @@ -46,8 +53,17 @@ def ref(self) -> str:
return _draft_ref(self.folio.id, self.seqno)


_DRAFT_REF_PREFIX = "refs/drafts/"


def _draft_ref(folio_id: int, suffix: int | str) -> str:
return f"refs/drafts/{folio_id}/{suffix}"
return f"{_DRAFT_REF_PREFIX}{folio_id}/{suffix}"


def _parse_draft_ref(ref: str) -> tuple[int, int | None]:
ref = ref.removeprefix(_DRAFT_REF_PREFIX)
parts = ref.split("/")
return int(parts[0]), int(parts[1]) if len(parts) > 1 else None


_FOLIO_BRANCH_NAMESPACE = "draft"
Expand All @@ -70,7 +86,7 @@ def upstream_branch_name(self) -> str:
return self.branch_name() + _FOLIO_UPSTREAM_BRANCH_SUFFIX


def _active_folio(repo: Repo) -> Folio | None:
def _maybe_active_folio(repo: Repo) -> Folio | None:
active_branch = repo.active_branch()
if not active_branch:
return None
Expand All @@ -80,6 +96,13 @@ def _active_folio(repo: Repo) -> Folio | None:
return Folio(int(match[1]))


def _active_folio(repo: Repo) -> Folio:
folio = _maybe_active_folio(repo)
if not folio:
raise RuntimeError("Not currently on a draft branch")
return folio


#: Select ort strategies.
DraftMergeStrategy = Literal[
"ours",
Expand Down Expand Up @@ -133,7 +156,7 @@ async def generate_draft(
)

# Ensure that we are in a folio.
folio = _active_folio(self._repo)
folio = _maybe_active_folio(self._repo)
if not folio:
folio = self._create_folio()
with self._store.cursor() as cursor:
Expand All @@ -149,7 +172,7 @@ async def generate_draft(
# Run the bot to generate the change.
event_recorder = _EventRecorder(self._progress)
with self._progress.spinner("Running bot...") as spinner:
feedback = spinner.feedback()
feedback = spinner.feedback(event_recorder)
change = await self._generate_change(
bot,
Goal(prompt_contents),
Expand Down Expand Up @@ -206,11 +229,11 @@ async def generate_draft(
[
{
"prompt_id": prompt_id,
"occurred_at": e.at,
"occurred_at": dt,
"class": e.__class__.__name__,
"data": encoder.encode(e),
}
for e in event_recorder.events
for (dt, e) in event_recorder.events()
],
)
spinner.update("Created draft commit.", ref=draft.ref)
Expand Down Expand Up @@ -244,9 +267,6 @@ async def generate_draft(

def quit_folio(self) -> None:
folio = _active_folio(self._repo)
if not folio:
raise RuntimeError("Not currently on a draft branch")

with self._store.cursor() as cursor:
rows = cursor.execute(sql("get-folio-by-id"), {"id": folio.id})
if not rows:
Expand Down Expand Up @@ -404,7 +424,7 @@ def _commit_tree(

def latest_draft_prompt(self) -> str | None:
"""Returns the latest prompt for the current draft"""
folio = _active_folio(self._repo)
folio = _maybe_active_folio(self._repo)
if not folio:
return None
with self._store.cursor() as cursor:
Expand All @@ -422,6 +442,27 @@ def latest_draft_prompt(self) -> str | None:
prompt = "\n\n".join([prompt, reindent(question, prefix="> ")])
return prompt

def list_draft_events(self, draft_ref: str | None = None) -> Sequence[str]:
if draft_ref:
folio_id, seqno = _parse_draft_ref(draft_ref)
else:
folio = _active_folio(self._repo)
folio_id = folio.id
seqno = None
elems = []
with self._store.cursor() as cursor:
rows = cursor.execute(
sql("list-action-events"),
{"folio_id": folio_id, "seqno": seqno},
)
decoders = event_decoders()
for row in rows:
occurred_at, class_name, data = row
event = decoders[class_name].decode(data)
description = _format_event(event)
elems.append(f"{occurred_at}\t{class_name}\t{description}")
return elems


@dataclasses.dataclass(frozen=True)
class _Change:
Expand All @@ -442,34 +483,50 @@ class _EventRecorder(EventConsumer):
"""

def __init__(self, progress: Progress) -> None:
self.events = list[Event]()
self._events = list[tuple[datetime, Event]]()
self._progress = progress

def events(self) -> Sequence[tuple[datetime, Event]]:
return sorted(list(self._events))

def on_event(self, event: Event) -> None:
self.events.append(event)
match event:
case worktree_events.ListFiles(_, paths):
self._progress.report("Listed files.", count=len(paths))
case worktree_events.ReadFile(_, path, contents):
size = -1 if contents is None else len(contents)
self._progress.report(f"Read {path}.", length=size)
case worktree_events.WriteFile(_, path, contents):
size = len(contents)
self._progress.report(f"Wrote {path}.", length=size)
case worktree_events.DeleteFile(_, path):
self._progress.report(f"Deleted {path}.")
case worktree_events.RenameFile(_, src_path, dst_path):
self._progress.report(f"Renamed {src_path} to {dst_path}.")
case worktree_events.StartEditingFiles(_):
self._progress.report("Started editing files...")
case worktree_events.StopEditingFiles(_):
self._progress.report("Stopped editing files.")
case (
feedback_events.NotifyUser(_, _)
| feedback_events.RequestUserGuidance(_, _)
| feedback_events.ReceiveUserGuidance(_, _)
):
pass
self._events.append((now(), event))
if formatted := _format_internal_event(event):
self._progress.report(formatted)


def _format_internal_event(event: Event) -> str:
match event:
case worktree_events.ListFiles(path_count):
return f"Listed {path_count} files."
case worktree_events.ReadFile(path, char_count):
return tagged(f"Read {path}.", length=char_count)
case worktree_events.WriteFile(path, char_count):
return tagged(f"Wrote {path}.", length=char_count)
case worktree_events.DeleteFile(path):
return f"Deleted {path}."
case worktree_events.RenameFile(src_path, dst_path):
return f"Renamed {src_path} to {dst_path}."
case worktree_events.StartEditingFiles():
return "Started editing files..."
case worktree_events.StopEditingFiles():
return "Stopped editing files."
case _:
return ""


def _format_event(event: Event) -> str:
if formatted := _format_internal_event(event):
return formatted
match event:
case feedback_events.NotifyUser(update):
return update
case feedback_events.RequestUserGuidance(question):
return question
case feedback_events.ReceiveUserGuidance(answer):
return answer
case _:
raise UnreachableError()


def _default_title(prompt: str) -> str:
Expand Down
23 changes: 13 additions & 10 deletions src/git_draft/events/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
"""Event package"""
"""Event definitions and (de)serializers"""

import collections
from collections.abc import Mapping
from pathlib import PurePosixPath
from typing import Any, Protocol

import msgspec

from . import feedback_events, worktree_events
from .common import events
from .common import all_events


__all__ = [
"Event",
"EventConsumer",
"event_decoder",
"event_decoders",
"event_encoder",
"events",
"feedback_events",
"worktree_events",
]
Expand Down Expand Up @@ -42,6 +43,7 @@ def on_event(self, event: Event) -> None:


def event_encoder() -> msgspec.json.Encoder:
"""Returns a JSON encoder for event instances"""
return msgspec.json.Encoder(enc_hook=_enc_hook)


Expand All @@ -50,14 +52,15 @@ def _enc_hook(obj: Any) -> Any:
return str(obj)


def event_decoder() -> msgspec.json.Decoder:
"""Returns a decoder for event instances
def event_decoders() -> Mapping[str, msgspec.json.Decoder]:
"""Returns JSON decoders for event instances, keyed by event class name"""
return _Decoders()

It should be used as follows to get typed values:

decoder.decode(data, type=events[class_name])
"""
return msgspec.json.Decoder(dec_hook=_dec_hook)
class _Decoders(collections.defaultdict[str, msgspec.json.Decoder]):
def __missing__(self, key: str) -> msgspec.json.Decoder:
event_class = getattr(all_events, key)
return msgspec.json.Decoder(dec_hook=_dec_hook, type=event_class)


def _dec_hook(tp: type, obj: Any) -> Any:
Expand Down
11 changes: 5 additions & 6 deletions src/git_draft/events/common.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
"""Common event utilities"""

import datetime
import types
from typing import Any
from typing import Any, dataclass_transform

import msgspec


events = types.SimpleNamespace()
all_events = types.SimpleNamespace()


# https://discuss.python.org/t/cannot-inherit-non-frozen-dataclass-from-a-frozen-one/79273
@dataclass_transform(field_specifiers=(msgspec.field,), frozen_default=True)
class EventStruct(msgspec.Struct, frozen=True):
"""Base immutable structure for all event types"""

at: datetime.datetime

def __init_subclass__(cls, *args: Any, **kwargs) -> None:
super().__init_subclass__(*args, **kwargs)
setattr(events, cls.__name__, cls)
setattr(all_events, cls.__name__, cls)
Loading