Skip to content

Commit e9acfb8

Browse files
authored
feat: support prompt editing (#49)
1 parent 698fe8e commit e9acfb8

File tree

9 files changed

+98
-29
lines changed

9 files changed

+98
-29
lines changed

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
## Highlights
88

9-
* Concurrent edits.
9+
* Concurrent edits. By default `git-draft` does not touch the working directory.
1010
* Customizable prompt templates.
1111
* Extensible bot API.
1212

@@ -16,3 +16,15 @@
1616
```sh
1717
pipx install git-draft[openai]
1818
```
19+
20+
21+
## Next steps
22+
23+
* Mechanism for reporting feedback from a bot, and possibly allowing user to
24+
interactively respond.
25+
* Add configuration option to auto sync and `--no-sync` flag. Similar to reset.
26+
* Add "amend" commit when finalizing. This could be useful training data,
27+
showing what the bot did not get right.
28+
* Convenience functionality for simple cases: checkout option which applies the
29+
changes, and finalizes the draft if specified multiple times. For example `git
30+
draft -cc add-test symbol=foo`

docs/git-draft.adoc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ IMPORTANT: `git-draft` is WIP.
1818
== Synopsis
1919

2020
[verse]
21-
git draft [options] [--generate] [--bot BOT] [--edit] [--reset | --no-reset]
22-
[--sync] [TEMPLATE [VARIABLE...]]
21+
git draft [options] [--generate] [--bot BOT] [--edit] [--reset | --no-reset] [--sync] [TEMPLATE [VARIABLE...]]
2322
git draft [options] --finalize [--clean | --revert] [--delete]
2423
git draft [options] --show-drafts [--json]
2524
git draft [options] --show-prompts [--json] [PROMPT]

src/git_draft/__main__.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,20 @@ def on_delete_file(self, path: PurePosixPath, _reason: str | None) -> None:
142142
print(f"Deleted {path}.")
143143

144144

145-
def edit(text: str | None, path: Path | None) -> str | None:
145+
def edit(path: Path, text: str | None = None) -> str | None:
146146
if sys.stdin.isatty():
147147
return open_editor(text or "", path)
148148
else:
149-
if path and text is not None:
149+
if text is not None:
150150
with open(path, "w") as f:
151151
f.write(text)
152152
print(path)
153153
return None
154154

155155

156+
_PROMPT_PLACEHOLDER = "Enter your prompt here..."
157+
158+
156159
def main() -> None:
157160
config = Config.load()
158161
(opts, args) = new_parser().parse_args()
@@ -177,23 +180,27 @@ def main() -> None:
177180
bot = load_bot(bot_config)
178181

179182
prompt: str | TemplatedPrompt
183+
editable = opts.edit
180184
if args:
181185
prompt = TemplatedPrompt.parse(args[0], *args[1:])
186+
elif opts.edit:
187+
editable = False
188+
prompt = open_editor(
189+
drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER
190+
)
182191
else:
183-
if sys.stdin.isatty():
184-
prompt = open_editor("Enter your prompt here...")
185-
else:
186-
prompt = sys.stdin.read()
192+
prompt = sys.stdin.read()
187193

188194
name = drafter.generate_draft(
189195
prompt,
190196
bot,
191197
bot_name=opts.bot,
198+
prompt_transform=open_editor if editable else None,
192199
tool_visitors=[ToolPrinter()],
193200
reset=config.auto_reset if opts.reset is None else opts.reset,
194201
sync=opts.sync,
195202
)
196-
print(f"Generated {name}.")
203+
print(f"Refined {name}.")
197204
elif command == "finalize":
198205
name = drafter.exit_draft(
199206
revert=opts.revert, clean=opts.clean, delete=opts.delete
@@ -212,9 +219,9 @@ def main() -> None:
212219
tpl = Template.find(name)
213220
if opts.edit:
214221
if tpl:
215-
edit(tpl.source, tpl.local_path())
222+
edit(tpl.local_path(), text=tpl.source)
216223
else:
217-
edit("", Template.local_path_for(name))
224+
edit(Template.local_path_for(name))
218225
else:
219226
if not tpl:
220227
raise ValueError(f"No template named {name!r}")

src/git_draft/bots/openai.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]:
126126
127127
You should stop when and ONLY WHEN all the files you need to change have
128128
been updated.
129+
130+
If you stop for any reason before completing your task, explain why by
131+
updating a REASON file before stopping. For example if you are missing some
132+
information or noticed something inconsistent with the instructions, say so
133+
there. DO NOT STOP without updating at least this file.
129134
"""
130135

131136

src/git_draft/drafter.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import re
1111
import textwrap
1212
import time
13-
from typing import Match, Sequence
13+
from typing import Callable, Match, Sequence
1414

1515
import git
1616

@@ -77,17 +77,17 @@ def generate_draft(
7777
bot: Bot,
7878
bot_name: str | None = None,
7979
tool_visitors: Sequence[ToolVisitor] | None = None,
80+
prompt_transform: Callable[[str], str] | None = None,
8081
reset: bool = False,
8182
sync: bool = False,
8283
timeout: float | None = None,
8384
) -> str:
84-
if isinstance(prompt, str) and not prompt.strip():
85-
raise ValueError("Empty prompt")
8685
if self._repo.is_dirty(working_tree=False):
8786
if not reset:
8887
raise ValueError("Please commit or reset any staged changes")
8988
self._repo.index.reset()
9089

90+
# Ensure that we are on a draft branch.
9191
branch = _Branch.active(self._repo)
9292
if branch:
9393
self._stage_changes(sync)
@@ -96,17 +96,18 @@ def generate_draft(
9696
branch = self._create_branch(sync)
9797
_logger.debug("Created branch %s.", branch)
9898

99-
operation_recorder = _OperationRecorder()
100-
tool_visitors = [operation_recorder] + list(tool_visitors or [])
101-
toolbox = StagingToolbox(self._repo, tool_visitors)
99+
# Handle prompt templating and editing.
102100
if isinstance(prompt, TemplatedPrompt):
103101
template: str | None = prompt.template
104-
renderer = PromptRenderer.for_toolbox(toolbox)
102+
renderer = PromptRenderer.for_toolbox(StagingToolbox(self._repo))
105103
prompt_contents = renderer.render(prompt)
106104
else:
107105
template = None
108106
prompt_contents = prompt
109-
107+
if prompt_transform:
108+
prompt_contents = prompt_transform(prompt_contents)
109+
if not prompt_contents.strip():
110+
raise ValueError("Aborting: empty prompt")
110111
with self._store.cursor() as cursor:
111112
[(prompt_id,)] = cursor.execute(
112113
sql("add-prompt"),
@@ -117,14 +118,19 @@ def generate_draft(
117118
},
118119
)
119120

121+
# Trigger code generation.
120122
_logger.debug("Running bot... [bot=%s]", bot)
123+
operation_recorder = _OperationRecorder()
124+
tool_visitors = [operation_recorder] + list(tool_visitors or [])
125+
toolbox = StagingToolbox(self._repo, tool_visitors)
121126
start_time = time.perf_counter()
122127
goal = Goal(prompt_contents, timeout)
123128
action = bot.act(goal, toolbox)
124129
end_time = time.perf_counter()
125130
walltime = end_time - start_time
126131
_logger.info("Completed bot action. [action=%s]", action)
127132

133+
# Generate an appropriate commit and update our database.
128134
toolbox.trim_index()
129135
title = action.title
130136
if not title:
@@ -133,7 +139,6 @@ def generate_draft(
133139
f"draft! {title}\n\n{prompt_contents}",
134140
skip_hooks=True,
135141
)
136-
137142
with self._store.cursor() as cursor:
138143
cursor.execute(
139144
sql("add-action"),
@@ -159,7 +164,7 @@ def generate_draft(
159164
],
160165
)
161166

162-
_logger.info("Generated %s.", branch)
167+
_logger.info("Completed generation for %s.", branch)
163168
return str(branch)
164169

165170
def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
@@ -232,22 +237,35 @@ def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
232237
def history_table(self, branch_name: str | None = None) -> Table:
233238
path = self._repo.working_dir
234239
branch = _Branch.active(self._repo, branch_name)
235-
if branch:
236-
with self._store.cursor() as cursor:
240+
with self._store.cursor() as cursor:
241+
if branch:
237242
results = cursor.execute(
238243
sql("list-prompts"),
239244
{
240245
"repo_path": path,
241246
"branch_suffix": branch.suffix,
242247
},
243248
)
244-
return Table.from_cursor(results)
245-
else:
246-
with self._store.cursor() as cursor:
249+
else:
247250
results = cursor.execute(
248251
sql("list-drafts"), {"repo_path": path}
249252
)
250-
return Table.from_cursor(results)
253+
return Table.from_cursor(results)
254+
255+
def latest_draft_prompt(self) -> str | None:
256+
"""Returns the latest prompt for the current draft"""
257+
branch = _Branch.active(self._repo)
258+
if not branch:
259+
return None
260+
with self._store.cursor() as cursor:
261+
result = cursor.execute(
262+
sql("get-latest-prompt"),
263+
{
264+
"repo_path": self._repo.working_dir,
265+
"branch_suffix": branch.suffix,
266+
},
267+
).fetchone()
268+
return result[0] if result else None
251269

252270
def _create_branch(self, sync: bool) -> _Branch:
253271
if self._repo.head.is_detached:

src/git_draft/prompt.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def for_toolbox(cls, toolbox: Toolbox) -> Self:
5656

5757
def render(self, prompt: TemplatedPrompt) -> str:
5858
tpl = self._environment.get_template(f"{prompt.template}.{_extension}")
59-
return tpl.render(prompt.context)
59+
try:
60+
return tpl.render(prompt.context)
61+
except jinja2.UndefinedError as err:
62+
raise ValueError(f"Unable to render template: {err}")
6063

6164

6265
def templates_table() -> Table:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
select p.contents
2+
from prompts as p
3+
join branches as b on p.branch_suffix = b.suffix
4+
where b.repo_path = :repo_path and b.suffix = :branch_suffix
5+
order by p.id desc
6+
limit 1;

tests/git_draft/drafter_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,17 @@ def test_history_table_active_draft(self) -> None:
243243
self._drafter.generate_draft("hello", FakeBot())
244244
table = self._drafter.history_table()
245245
assert table
246+
247+
def test_latest_draft_prompt(self) -> None:
248+
bot = FakeBot()
249+
250+
prompt1 = "First prompt"
251+
self._drafter.generate_draft(prompt1, bot)
252+
assert self._drafter.latest_draft_prompt() == prompt1
253+
254+
prompt2 = "Second prompt"
255+
self._drafter.generate_draft(prompt2, bot)
256+
assert self._drafter.latest_draft_prompt() == prompt2
257+
258+
def test_latest_draft_prompt_no_active_branch(self) -> None:
259+
assert self._drafter.latest_draft_prompt() is None

tests/git_draft/prompt_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def test_ok(self) -> None:
1515
rendered = self._renderer.render(prompt)
1616
assert "foo" in rendered
1717

18+
def test_missing_variable(self) -> None:
19+
prompt = sut.TemplatedPrompt.parse("add-test")
20+
with pytest.raises(ValueError):
21+
self._renderer.render(prompt)
22+
1823

1924
class TestTemplate:
2025
@pytest.fixture(autouse=True)

0 commit comments

Comments
 (0)