Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

## Highlights

* Concurrent edits.
* Concurrent edits. By default `git-draft` does not touch the working directory.
* Customizable prompt templates.
* Extensible bot API.

Expand All @@ -16,3 +16,15 @@
```sh
pipx install git-draft[openai]
```


## Next steps

* Mechanism for reporting feedback from a bot, and possibly allowing user to
interactively respond.
* Add configuration option to auto sync and `--no-sync` flag. Similar to reset.
* Add "amend" commit when finalizing. This could be useful training data,
showing what the bot did not get right.
* Convenience functionality for simple cases: checkout option which applies the
changes, and finalizes the draft if specified multiple times. For example `git
draft -cc add-test symbol=foo`
3 changes: 1 addition & 2 deletions docs/git-draft.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ IMPORTANT: `git-draft` is WIP.
== Synopsis

[verse]
git draft [options] [--generate] [--bot BOT] [--edit] [--reset | --no-reset]
[--sync] [TEMPLATE [VARIABLE...]]
git draft [options] [--generate] [--bot BOT] [--edit] [--reset | --no-reset] [--sync] [TEMPLATE [VARIABLE...]]
git draft [options] --finalize [--clean | --revert] [--delete]
git draft [options] --show-drafts [--json]
git draft [options] --show-prompts [--json] [PROMPT]
Expand Down
25 changes: 16 additions & 9 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,20 @@ def on_delete_file(self, path: PurePosixPath, _reason: str | None) -> None:
print(f"Deleted {path}.")


def edit(text: str | None, path: Path | None) -> str | None:
def edit(path: Path, text: str | None = None) -> str | None:
if sys.stdin.isatty():
return open_editor(text or "", path)
else:
if path and text is not None:
if text is not None:
with open(path, "w") as f:
f.write(text)
print(path)
return None


_PROMPT_PLACEHOLDER = "Enter your prompt here..."


def main() -> None:
config = Config.load()
(opts, args) = new_parser().parse_args()
Expand All @@ -177,23 +180,27 @@ def main() -> None:
bot = load_bot(bot_config)

prompt: str | TemplatedPrompt
editable = opts.edit
if args:
prompt = TemplatedPrompt.parse(args[0], *args[1:])
elif opts.edit:
editable = False
prompt = open_editor(
drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER
)
else:
if sys.stdin.isatty():
prompt = open_editor("Enter your prompt here...")
else:
prompt = sys.stdin.read()
prompt = sys.stdin.read()

name = drafter.generate_draft(
prompt,
bot,
bot_name=opts.bot,
prompt_transform=open_editor if editable else None,
tool_visitors=[ToolPrinter()],
reset=config.auto_reset if opts.reset is None else opts.reset,
sync=opts.sync,
)
print(f"Generated {name}.")
print(f"Refined {name}.")
elif command == "finalize":
name = drafter.exit_draft(
revert=opts.revert, clean=opts.clean, delete=opts.delete
Expand All @@ -212,9 +219,9 @@ def main() -> None:
tpl = Template.find(name)
if opts.edit:
if tpl:
edit(tpl.source, tpl.local_path())
edit(tpl.local_path(), text=tpl.source)
else:
edit("", Template.local_path_for(name))
edit(Template.local_path_for(name))
else:
if not tpl:
raise ValueError(f"No template named {name!r}")
Expand Down
5 changes: 5 additions & 0 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]:

You should stop when and ONLY WHEN all the files you need to change have
been updated.

If you stop for any reason before completing your task, explain why by
updating a REASON file before stopping. For example if you are missing some
information or noticed something inconsistent with the instructions, say so
there. DO NOT STOP without updating at least this file.
"""


Expand Down
50 changes: 34 additions & 16 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
import textwrap
import time
from typing import Match, Sequence
from typing import Callable, Match, Sequence

import git

Expand Down Expand Up @@ -77,17 +77,17 @@ def generate_draft(
bot: Bot,
bot_name: str | None = None,
tool_visitors: Sequence[ToolVisitor] | None = None,
prompt_transform: Callable[[str], str] | None = None,
reset: bool = False,
sync: bool = False,
timeout: float | None = None,
) -> str:
if isinstance(prompt, str) and not prompt.strip():
raise ValueError("Empty prompt")
if self._repo.is_dirty(working_tree=False):
if not reset:
raise ValueError("Please commit or reset any staged changes")
self._repo.index.reset()

# Ensure that we are on a draft branch.
branch = _Branch.active(self._repo)
if branch:
self._stage_changes(sync)
Expand All @@ -96,17 +96,18 @@ def generate_draft(
branch = self._create_branch(sync)
_logger.debug("Created branch %s.", branch)

operation_recorder = _OperationRecorder()
tool_visitors = [operation_recorder] + list(tool_visitors or [])
toolbox = StagingToolbox(self._repo, tool_visitors)
# Handle prompt templating and editing.
if isinstance(prompt, TemplatedPrompt):
template: str | None = prompt.template
renderer = PromptRenderer.for_toolbox(toolbox)
renderer = PromptRenderer.for_toolbox(StagingToolbox(self._repo))
prompt_contents = renderer.render(prompt)
else:
template = None
prompt_contents = prompt

if prompt_transform:
prompt_contents = prompt_transform(prompt_contents)
if not prompt_contents.strip():
raise ValueError("Aborting: empty prompt")
with self._store.cursor() as cursor:
[(prompt_id,)] = cursor.execute(
sql("add-prompt"),
Expand All @@ -117,14 +118,19 @@ def generate_draft(
},
)

# Trigger code generation.
_logger.debug("Running bot... [bot=%s]", bot)
operation_recorder = _OperationRecorder()
tool_visitors = [operation_recorder] + list(tool_visitors or [])
toolbox = StagingToolbox(self._repo, tool_visitors)
start_time = time.perf_counter()
goal = Goal(prompt_contents, timeout)
action = bot.act(goal, toolbox)
end_time = time.perf_counter()
walltime = end_time - start_time
_logger.info("Completed bot action. [action=%s]", action)

# Generate an appropriate commit and update our database.
toolbox.trim_index()
title = action.title
if not title:
Expand All @@ -133,7 +139,6 @@ def generate_draft(
f"draft! {title}\n\n{prompt_contents}",
skip_hooks=True,
)

with self._store.cursor() as cursor:
cursor.execute(
sql("add-action"),
Expand All @@ -159,7 +164,7 @@ def generate_draft(
],
)

_logger.info("Generated %s.", branch)
_logger.info("Completed generation for %s.", branch)
return str(branch)

def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
Expand Down Expand Up @@ -232,22 +237,35 @@ def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
def history_table(self, branch_name: str | None = None) -> Table:
path = self._repo.working_dir
branch = _Branch.active(self._repo, branch_name)
if branch:
with self._store.cursor() as cursor:
with self._store.cursor() as cursor:
if branch:
results = cursor.execute(
sql("list-prompts"),
{
"repo_path": path,
"branch_suffix": branch.suffix,
},
)
return Table.from_cursor(results)
else:
with self._store.cursor() as cursor:
else:
results = cursor.execute(
sql("list-drafts"), {"repo_path": path}
)
return Table.from_cursor(results)
return Table.from_cursor(results)

def latest_draft_prompt(self) -> str | None:
"""Returns the latest prompt for the current draft"""
branch = _Branch.active(self._repo)
if not branch:
return None
with self._store.cursor() as cursor:
result = cursor.execute(
sql("get-latest-prompt"),
{
"repo_path": self._repo.working_dir,
"branch_suffix": branch.suffix,
},
).fetchone()
return result[0] if result else None

def _create_branch(self, sync: bool) -> _Branch:
if self._repo.head.is_detached:
Expand Down
5 changes: 4 additions & 1 deletion src/git_draft/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def for_toolbox(cls, toolbox: Toolbox) -> Self:

def render(self, prompt: TemplatedPrompt) -> str:
tpl = self._environment.get_template(f"{prompt.template}.{_extension}")
return tpl.render(prompt.context)
try:
return tpl.render(prompt.context)
except jinja2.UndefinedError as err:
raise ValueError(f"Unable to render template: {err}")


def templates_table() -> Table:
Expand Down
6 changes: 6 additions & 0 deletions src/git_draft/queries/get-latest-prompt.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
select p.contents
from prompts as p
join branches as b on p.branch_suffix = b.suffix
where b.repo_path = :repo_path and b.suffix = :branch_suffix
order by p.id desc
limit 1;
14 changes: 14 additions & 0 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,17 @@ def test_history_table_active_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
table = self._drafter.history_table()
assert table

def test_latest_draft_prompt(self) -> None:
bot = FakeBot()

prompt1 = "First prompt"
self._drafter.generate_draft(prompt1, bot)
assert self._drafter.latest_draft_prompt() == prompt1

prompt2 = "Second prompt"
self._drafter.generate_draft(prompt2, bot)
assert self._drafter.latest_draft_prompt() == prompt2

def test_latest_draft_prompt_no_active_branch(self) -> None:
assert self._drafter.latest_draft_prompt() is None
5 changes: 5 additions & 0 deletions tests/git_draft/prompt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def test_ok(self) -> None:
rendered = self._renderer.render(prompt)
assert "foo" in rendered

def test_missing_variable(self) -> None:
prompt = sut.TemplatedPrompt.parse("add-test")
with pytest.raises(ValueError):
self._renderer.render(prompt)


class TestTemplate:
@pytest.fixture(autouse=True)
Expand Down