Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ python = ">=3.12,<4"
coverage = "^7.4.4"
mypy = "^1.2.1"
poethepoet = "^0.25.0"
pytest = "^7.1.0"
pytest = "^8.2.0"
pytest-asyncio = "^0.26.0"
ruff = "^0.11.0"

# Poe
Expand Down
7 changes: 4 additions & 3 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import enum
import importlib.metadata
import logging
Expand Down Expand Up @@ -151,7 +152,7 @@ def edit(*, path: Path | None = None, text: str | None = None) -> str:
_PROMPT_PLACEHOLDER = "Enter your prompt here..."


def main() -> None: # noqa: PLR0912 PLR0915
async def main() -> None: # noqa: PLR0912 PLR0915
config = Config.load()
(opts, args) = new_parser().parse_args()

Expand Down Expand Up @@ -198,7 +199,7 @@ def main() -> None: # noqa: PLR0912 PLR0915
editable = False # We already edited the prompt

accept = Accept(opts.accept or 0)
drafter.generate_draft(
await drafter.generate_draft(
prompt,
bot,
prompt_transform=open_editor if editable else None,
Expand Down Expand Up @@ -231,7 +232,7 @@ def main() -> None: # noqa: PLR0912 PLR0915

if __name__ == "__main__":
try:
main()
asyncio.run(main())
except Exception as err:
_logger.exception("Program failed.")
print(f"Error: {err}", file=sys.stderr)
Expand Down
2 changes: 1 addition & 1 deletion src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ def state_folder_path(cls, ensure_exists: bool = False) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path

def act(self, goal: Goal, toolbox: Toolbox) -> Action:
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
"""Runs the bot, striving to achieve the goal with the given toolbox"""
raise NotImplementedError()
4 changes: 2 additions & 2 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def __init__(self, client: openai.OpenAI, model: str) -> None:
self._client = client
self._model = model

def act(self, goal: Goal, toolbox: Toolbox) -> Action:
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
tools = _ToolsFactory(strict=False).params()
tool_handler = _CompletionsToolHandler(toolbox)

Expand Down Expand Up @@ -318,7 +318,7 @@ def _load_assistant_id(self) -> str:
f.write(assistant_id)
return assistant_id

def act(self, goal: Goal, toolbox: Toolbox) -> Action:
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
assistant_id = self._load_assistant_id()

thread = self._client.beta.threads.create()
Expand Down
8 changes: 4 additions & 4 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def create(cls, repo: Repo, store: Store, feedback: Feedback) -> Drafter:
cursor.executescript(sql("create-tables"))
return cls(store, repo, feedback)

def generate_draft(
async def generate_draft(
self,
prompt: str | TemplatedPrompt,
bot: Bot,
Expand Down Expand Up @@ -143,7 +143,7 @@ def generate_draft(
# Run the bot to generate the change.
operation_recorder = _OperationRecorder(self._feedback)
with self._feedback.spinner("Running bot...") as spinner:
change = self._generate_change(
change = await self._generate_change(
bot,
Goal(prompt_contents),
toolbox.with_visitors(
Expand Down Expand Up @@ -342,7 +342,7 @@ def _prepare_prompt(
raise ValueError("Missing or empty prompt")
return contents

def _generate_change(
async def _generate_change(
self,
bot: Bot,
goal: Goal,
Expand All @@ -352,7 +352,7 @@ def _generate_change(

start_time = time.perf_counter()
_logger.debug("Running bot... [bot=%s]", bot)
action = bot.act(goal, toolbox)
action = await bot.act(goal, toolbox)
_logger.info("Completed bot action. [action=%s]", action)
end_time = time.perf_counter()

Expand Down
71 changes: 43 additions & 28 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def noop(cls) -> Self:
def prompt(cls) -> Self:
return cls({"PROMPT": lambda goal: goal.prompt})

def act(self, goal: Goal, toolbox: Toolbox) -> Action:
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
for key, value in self._contents.items():
path = PurePosixPath(key)
if value is None:
Expand Down Expand Up @@ -62,37 +62,41 @@ def _commit_files(self, ref: str) -> frozenset[str]:
def _checkout(self) -> None:
self._repo.git("checkout", "--", ".")

def test_generate_draft(self) -> None:
@pytest.mark.asyncio
async def test_generate_draft(self) -> None:
self._fs.write("p1", "a")
self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"}))
await self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"}))
assert len(self._commits()) == 1
assert len(self._commits("@{u}")) == 3
assert self._fs.read("p1") == "a"

def test_generate_empty_draft(self) -> None:
self._drafter.generate_draft("hello", _SimpleBot.noop())
@pytest.mark.asyncio
async def test_generate_empty_draft(self) -> None:
await self._drafter.generate_draft("hello", _SimpleBot.noop())
assert len(self._commits()) == 1
assert len(self._commits("@{u}")) == 2

def test_generate_draft_merge(self) -> None:
@pytest.mark.asyncio
async def test_generate_draft_merge(self) -> None:
self._fs.write("p1", "a")

self._drafter.generate_draft(
await self._drafter.generate_draft(
"hello", _SimpleBot({"p2": "b"}), merge_strategy="ignore-all-space"
)
# No sync(merge) commit since no changes happened between.
assert len(self._commits()) == 4 # init, sync(prompt), prompt, merge
assert self._fs.read("p1") == "a"
assert self._fs.read("p2") == "b"

def test_generate_draft_merge_no_conflict(self) -> None:
@pytest.mark.asyncio
async def test_generate_draft_merge_no_conflict(self) -> None:
self._fs.write("p1", "a")

def update(_goal: Goal) -> str:
self._fs.write("p2", "b")
return "A"

self._drafter.generate_draft(
await self._drafter.generate_draft(
"hello",
_SimpleBot({"p1": update}),
merge_strategy="ignore-all-space",
Expand All @@ -101,70 +105,81 @@ def update(_goal: Goal) -> str:
assert self._fs.read("p1") == "A"
assert self._fs.read("p2") == "b"

def test_generate_draft_merge_theirs(self) -> None:
@pytest.mark.asyncio
async def test_generate_draft_merge_theirs(self) -> None:
self._fs.write("p1", "a")

def update(_goal: Goal) -> str:
self._fs.write("p1", "b")
return "A"

self._drafter.generate_draft(
await self._drafter.generate_draft(
"hello", _SimpleBot({"p1": update}), merge_strategy="theirs"
)
# sync(merge) commit here since p1 was updated separately.
assert len(self._commits()) == 5 # init, sync, prompt, sync, merge
assert self._fs.read("p1") == "A"

def test_generate_draft_merge_conflict(self) -> None:
@pytest.mark.asyncio
async def test_generate_draft_merge_conflict(self) -> None:
self._fs.write("p1", "a")

def update(_goal: Goal) -> str:
self._fs.write("p1", "b")
return "A"

with pytest.raises(GitError):
self._drafter.generate_draft(
await self._drafter.generate_draft(
"hello",
_SimpleBot({"p1": update}),
merge_strategy="ignore-all-space",
)

def test_generate_outside_branch(self) -> None:
@pytest.mark.asyncio
async def test_generate_outside_branch(self) -> None:
self._repo.git("checkout", "--detach")
with pytest.raises(RuntimeError):
self._drafter.generate_draft("ok", _SimpleBot.noop())
await self._drafter.generate_draft("ok", _SimpleBot.noop())

def test_generate_empty_prompt(self) -> None:
@pytest.mark.asyncio
async def test_generate_empty_prompt(self) -> None:
with pytest.raises(ValueError):
self._drafter.generate_draft("", _SimpleBot.noop())
await self._drafter.generate_draft("", _SimpleBot.noop())

def test_generate_reuse_branch(self) -> None:
@pytest.mark.asyncio
async def test_generate_reuse_branch(self) -> None:
bot = _SimpleBot({"prompt": lambda goal: goal.prompt})
self._drafter.generate_draft("prompt1", bot, "theirs")
self._drafter.generate_draft("prompt2", bot, "theirs")
await self._drafter.generate_draft("prompt1", bot, "theirs")
await self._drafter.generate_draft("prompt2", bot, "theirs")
assert self._fs.read("prompt") == "prompt2"

def test_delete_unknown_file(self) -> None:
self._drafter.generate_draft("hello", _SimpleBot({"p1": None}))
@pytest.mark.asyncio
async def test_delete_unknown_file(self) -> None:
await self._drafter.generate_draft("hello", _SimpleBot({"p1": None}))

def test_quit_keeps_changes(self) -> None:
@pytest.mark.asyncio
async def test_quit_keeps_changes(self) -> None:
self._fs.write("p1.txt", "a1")
self._drafter.generate_draft("hello", _SimpleBot.prompt(), "theirs")
await self._drafter.generate_draft(
"hello", _SimpleBot.prompt(), "theirs"
)
self._fs.write("p1.txt", "a2")
self._drafter.quit_folio()
assert self._fs.read("p1.txt") == "a2"
assert self._fs.read("PROMPT") == "hello"

def test_latest_draft_prompt(self) -> None:
@pytest.mark.asyncio
async def test_latest_draft_prompt(self) -> None:
bot = _SimpleBot.noop()

prompt1 = "First prompt"
self._drafter.generate_draft(prompt1, bot)
await self._drafter.generate_draft(prompt1, bot)
assert self._drafter.latest_draft_prompt() == prompt1

prompt2 = "Second prompt"
self._drafter.generate_draft(prompt2, bot)
await self._drafter.generate_draft(prompt2, bot)
assert self._drafter.latest_draft_prompt() == prompt2

def test_latest_draft_prompt_no_active_branch(self) -> None:
@pytest.mark.asyncio
async def test_latest_draft_prompt_no_active_branch(self) -> None:
assert self._drafter.latest_draft_prompt() is None