Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import importlib.metadata
import logging
import optparse
from pathlib import PurePosixPath
import sys
from typing import Sequence

from .bots import Operation, load_bot
from .bots import load_bot
from .common import PROGRAM, Config, UnreachableError, ensure_state_home
from .drafter import Drafter
from .editor import open_editor
from .prompt import TemplatedPrompt
from .store import Store
from .toolbox import ToolVisitor


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,8 +96,24 @@ def callback(_option, _opt, _value, parser) -> None:
return parser


def print_operation(op: Operation) -> None:
print(op)
class _ToolPrinter(ToolVisitor):
def on_list_files(
self, _paths: Sequence[PurePosixPath], _reason: str | None
) -> None:
print("Listing available files...")

def on_read_file(
self, path: PurePosixPath, _contents: str | None, _reason: str | None
) -> None:
print(f"Reading {path}...")

def on_write_file(
self, path: PurePosixPath, _contents: str, _reason: str | None
) -> None:
print(f"Updated {path}.")

def on_delete_file(self, path: PurePosixPath, _reason: str | None) -> None:
print(f"Deleted {path}.")


def main() -> None:
Expand All @@ -110,7 +129,6 @@ def main() -> None:
drafter = Drafter.create(
store=Store.persistent(),
path=opts.root,
operation_hook=print_operation,
)
command = getattr(opts, "command", "generate")
if command == "generate":
Expand All @@ -133,15 +151,20 @@ def main() -> None:
else:
prompt = sys.stdin.read()

drafter.generate_draft(
prompt, bot, checkout=opts.checkout, reset=opts.reset
name = drafter.generate_draft(
prompt,
bot,
tool_visitors=[_ToolPrinter()],
checkout=opts.checkout,
reset=opts.reset,
)
print(f"Generated {name}.")
elif command == "finalize":
name = drafter.finalize_draft(delete=opts.delete)
print(f"Finalized {name}")
print(f"Finalized {name}.")
elif command == "revert":
name = drafter.revert_draft(delete=opts.delete)
print(f"Reverted {name}")
print(f"Reverted {name}.")
else:
raise UnreachableError()

Expand Down
4 changes: 1 addition & 3 deletions src/git_draft/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
import sys

from ..common import BotConfig, reindent
from ..toolbox import Operation, OperationHook, Toolbox
from ..toolbox import Toolbox
from .common import Action, Bot, Goal


__all__ = [
"Action",
"Bot",
"Goal",
"Operation",
"OperationHook",
"Toolbox",
]

Expand Down
84 changes: 61 additions & 23 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from __future__ import annotations

import dataclasses
from datetime import datetime
import json
import logging
from pathlib import PurePosixPath
import re
import textwrap
import time
from typing import Match, Sequence

import git

from .bots import Bot, Goal, OperationHook
from .common import random_id
from .bots import Bot, Goal
from .common import JSONObject, random_id
from .prompt import PromptRenderer, TemplatedPrompt
from .store import Store, sql
from .toolbox import StagingToolbox
from .toolbox import StagingToolbox, ToolVisitor


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,37 +54,26 @@ def new_suffix():
class Drafter:
"""Draft state orchestrator"""

def __init__(
self, store: Store, repo: git.Repo, hook: OperationHook | None = None
) -> None:
def __init__(self, store: Store, repo: git.Repo) -> None:
with store.cursor() as cursor:
cursor.executescript(sql("create-tables"))
self._store = store
self._repo = repo
self._operation_hook = hook

@classmethod
def create(
cls,
store: Store,
path: str | None = None,
operation_hook: OperationHook | None = None,
) -> Drafter:
return cls(
store,
git.Repo(path, search_parent_directories=True),
operation_hook,
)
def create(cls, store: Store, path: str | None = None) -> Drafter:
return cls(store, git.Repo(path, search_parent_directories=True))

def generate_draft(
self,
prompt: str | TemplatedPrompt,
bot: Bot,
tool_visitors: Sequence[ToolVisitor] | None = None,
checkout: bool = False,
reset: bool = False,
sync: bool = False,
timeout: float | None = None,
) -> None:
) -> str:
if isinstance(prompt, str) and not prompt.strip():
raise ValueError("Empty prompt")
if self._repo.is_dirty(working_tree=False):
Expand All @@ -98,7 +89,9 @@ def generate_draft(
branch = self._create_branch(sync)
_logger.debug("Created branch %s.", branch)

toolbox = StagingToolbox(self._repo, self._operation_hook)
operation_recorder = _OperationRecorder()
tool_visitors = [operation_recorder] + list(tool_visitors or [])
toolbox = StagingToolbox(self._repo, tool_visitors)
if isinstance(prompt, TemplatedPrompt):
renderer = PromptRenderer.for_toolbox(toolbox)
prompt_contents = renderer.render(prompt)
Expand All @@ -118,6 +111,7 @@ def generate_draft(
goal = Goal(prompt_contents, timeout)
action = bot.act(goal, toolbox)
end_time = time.perf_counter()
walltime = end_time - start_time

toolbox.trim_index()
title = action.title
Expand All @@ -134,7 +128,7 @@ def generate_draft(
{
"commit_sha": commit.hexsha,
"prompt_id": prompt_id,
"walltime": end_time - start_time,
"walltime": walltime,
},
)
cursor.executemany(
Expand All @@ -147,13 +141,14 @@ def generate_draft(
"details": json.dumps(o.details),
"started_at": o.start,
}
for o in toolbox.operations
for o in operation_recorder.operations
],
)
_logger.info("Generated draft.")

_logger.info("Generated draft.")
if checkout:
self._repo.git.checkout("--", ".")
return str(branch)

def finalize_draft(self, delete=False) -> str:
return self._exit_draft(revert=False, delete=delete)
Expand Down Expand Up @@ -243,5 +238,48 @@ def _changed_files(self, spec) -> Sequence[str]:
return self._repo.git.diff(spec, name_only=True).splitlines()


class _OperationRecorder(ToolVisitor):
def __init__(self) -> None:
self.operations = list[_Operation]()

def on_list_files(
self, paths: Sequence[PurePosixPath], reason: str | None
) -> None:
self._record(reason, "list_files", count=len(paths))

def on_read_file(
self, path: PurePosixPath, contents: str | None, reason: str | None
) -> None:
self._record(
reason,
"read_file",
path=str(path),
size=-1 if contents is None else len(contents),
)

def on_write_file(
self, path: PurePosixPath, contents: str, reason: str | None
) -> None:
self._record(reason, "write_file", path=str(path), size=len(contents))

def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None:
self._record(reason, "delete_file", path=str(path))

def _record(self, reason: str | None, tool: str, **kwargs) -> None:
self.operations.append(
_Operation(
tool=tool, details=kwargs, reason=reason, start=datetime.now()
)
)


@dataclasses.dataclass(frozen=True)
class _Operation:
tool: str
details: JSONObject
reason: str | None
start: datetime


def _default_title(prompt: str) -> str:
return textwrap.shorten(prompt, break_on_hyphens=False, width=72)
Loading