Skip to content

Commit a6e2fa4

Browse files
authored
feat: support async bot actions (#82)
This will be useful for the MCP bot, since it appears MCP clients use an async API [1]. [1]: https://modelcontextprotocol.io/quickstart/client
1 parent d2e6ae6 commit a6e2fa4

File tree

7 files changed

+82
-46
lines changed

7 files changed

+82
-46
lines changed

poetry.lock

Lines changed: 26 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ python = ">=3.12,<4"
4141
coverage = "^7.4.4"
4242
mypy = "^1.2.1"
4343
poethepoet = "^0.25.0"
44-
pytest = "^7.1.0"
44+
pytest = "^8.2.0"
45+
pytest-asyncio = "^0.26.0"
4546
ruff = "^0.11.0"
4647

4748
# Poe

src/git_draft/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import enum
67
import importlib.metadata
78
import logging
@@ -151,7 +152,7 @@ def edit(*, path: Path | None = None, text: str | None = None) -> str:
151152
_PROMPT_PLACEHOLDER = "Enter your prompt here..."
152153

153154

154-
def main() -> None: # noqa: PLR0912 PLR0915
155+
async def main() -> None: # noqa: PLR0912 PLR0915
155156
config = Config.load()
156157
(opts, args) = new_parser().parse_args()
157158

@@ -198,7 +199,7 @@ def main() -> None: # noqa: PLR0912 PLR0915
198199
editable = False # We already edited the prompt
199200

200201
accept = Accept(opts.accept or 0)
201-
drafter.generate_draft(
202+
await drafter.generate_draft(
202203
prompt,
203204
bot,
204205
prompt_transform=open_editor if editable else None,
@@ -231,7 +232,7 @@ def main() -> None: # noqa: PLR0912 PLR0915
231232

232233
if __name__ == "__main__":
233234
try:
234-
main()
235+
asyncio.run(main())
235236
except Exception as err:
236237
_logger.exception("Program failed.")
237238
print(f"Error: {err}", file=sys.stderr)

src/git_draft/bots/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,6 @@ def state_folder_path(cls, ensure_exists: bool = False) -> Path:
6666
path.mkdir(parents=True, exist_ok=True)
6767
return path
6868

69-
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
69+
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
7070
"""Runs the bot, striving to achieve the goal with the given toolbox"""
7171
raise NotImplementedError()

src/git_draft/bots/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def __init__(self, client: openai.OpenAI, model: str) -> None:
234234
self._client = client
235235
self._model = model
236236

237-
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
237+
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
238238
tools = _ToolsFactory(strict=False).params()
239239
tool_handler = _CompletionsToolHandler(toolbox)
240240

@@ -318,7 +318,7 @@ def _load_assistant_id(self) -> str:
318318
f.write(assistant_id)
319319
return assistant_id
320320

321-
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
321+
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
322322
assistant_id = self._load_assistant_id()
323323

324324
thread = self._client.beta.threads.create()

src/git_draft/drafter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def create(cls, repo: Repo, store: Store, feedback: Feedback) -> Drafter:
100100
cursor.executescript(sql("create-tables"))
101101
return cls(store, repo, feedback)
102102

103-
def generate_draft(
103+
async def generate_draft(
104104
self,
105105
prompt: str | TemplatedPrompt,
106106
bot: Bot,
@@ -143,7 +143,7 @@ def generate_draft(
143143
# Run the bot to generate the change.
144144
operation_recorder = _OperationRecorder(self._feedback)
145145
with self._feedback.spinner("Running bot...") as spinner:
146-
change = self._generate_change(
146+
change = await self._generate_change(
147147
bot,
148148
Goal(prompt_contents),
149149
toolbox.with_visitors(
@@ -342,7 +342,7 @@ def _prepare_prompt(
342342
raise ValueError("Missing or empty prompt")
343343
return contents
344344

345-
def _generate_change(
345+
async def _generate_change(
346346
self,
347347
bot: Bot,
348348
goal: Goal,
@@ -352,7 +352,7 @@ def _generate_change(
352352

353353
start_time = time.perf_counter()
354354
_logger.debug("Running bot... [bot=%s]", bot)
355-
action = bot.act(goal, toolbox)
355+
action = await bot.act(goal, toolbox)
356356
_logger.info("Completed bot action. [action=%s]", action)
357357
end_time = time.perf_counter()
358358

tests/git_draft/drafter_test.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def noop(cls) -> Self:
2929
def prompt(cls) -> Self:
3030
return cls({"PROMPT": lambda goal: goal.prompt})
3131

32-
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
32+
async def act(self, goal: Goal, toolbox: Toolbox) -> Action:
3333
for key, value in self._contents.items():
3434
path = PurePosixPath(key)
3535
if value is None:
@@ -62,37 +62,41 @@ def _commit_files(self, ref: str) -> frozenset[str]:
6262
def _checkout(self) -> None:
6363
self._repo.git("checkout", "--", ".")
6464

65-
def test_generate_draft(self) -> None:
65+
@pytest.mark.asyncio
66+
async def test_generate_draft(self) -> None:
6667
self._fs.write("p1", "a")
67-
self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"}))
68+
await self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"}))
6869
assert len(self._commits()) == 1
6970
assert len(self._commits("@{u}")) == 3
7071
assert self._fs.read("p1") == "a"
7172

72-
def test_generate_empty_draft(self) -> None:
73-
self._drafter.generate_draft("hello", _SimpleBot.noop())
73+
@pytest.mark.asyncio
74+
async def test_generate_empty_draft(self) -> None:
75+
await self._drafter.generate_draft("hello", _SimpleBot.noop())
7476
assert len(self._commits()) == 1
7577
assert len(self._commits("@{u}")) == 2
7678

77-
def test_generate_draft_merge(self) -> None:
79+
@pytest.mark.asyncio
80+
async def test_generate_draft_merge(self) -> None:
7881
self._fs.write("p1", "a")
7982

80-
self._drafter.generate_draft(
83+
await self._drafter.generate_draft(
8184
"hello", _SimpleBot({"p2": "b"}), merge_strategy="ignore-all-space"
8285
)
8386
# No sync(merge) commit since no changes happened between.
8487
assert len(self._commits()) == 4 # init, sync(prompt), prompt, merge
8588
assert self._fs.read("p1") == "a"
8689
assert self._fs.read("p2") == "b"
8790

88-
def test_generate_draft_merge_no_conflict(self) -> None:
91+
@pytest.mark.asyncio
92+
async def test_generate_draft_merge_no_conflict(self) -> None:
8993
self._fs.write("p1", "a")
9094

9195
def update(_goal: Goal) -> str:
9296
self._fs.write("p2", "b")
9397
return "A"
9498

95-
self._drafter.generate_draft(
99+
await self._drafter.generate_draft(
96100
"hello",
97101
_SimpleBot({"p1": update}),
98102
merge_strategy="ignore-all-space",
@@ -101,70 +105,81 @@ def update(_goal: Goal) -> str:
101105
assert self._fs.read("p1") == "A"
102106
assert self._fs.read("p2") == "b"
103107

104-
def test_generate_draft_merge_theirs(self) -> None:
108+
@pytest.mark.asyncio
109+
async def test_generate_draft_merge_theirs(self) -> None:
105110
self._fs.write("p1", "a")
106111

107112
def update(_goal: Goal) -> str:
108113
self._fs.write("p1", "b")
109114
return "A"
110115

111-
self._drafter.generate_draft(
116+
await self._drafter.generate_draft(
112117
"hello", _SimpleBot({"p1": update}), merge_strategy="theirs"
113118
)
114119
# sync(merge) commit here since p1 was updated separately.
115120
assert len(self._commits()) == 5 # init, sync, prompt, sync, merge
116121
assert self._fs.read("p1") == "A"
117122

118-
def test_generate_draft_merge_conflict(self) -> None:
123+
@pytest.mark.asyncio
124+
async def test_generate_draft_merge_conflict(self) -> None:
119125
self._fs.write("p1", "a")
120126

121127
def update(_goal: Goal) -> str:
122128
self._fs.write("p1", "b")
123129
return "A"
124130

125131
with pytest.raises(GitError):
126-
self._drafter.generate_draft(
132+
await self._drafter.generate_draft(
127133
"hello",
128134
_SimpleBot({"p1": update}),
129135
merge_strategy="ignore-all-space",
130136
)
131137

132-
def test_generate_outside_branch(self) -> None:
138+
@pytest.mark.asyncio
139+
async def test_generate_outside_branch(self) -> None:
133140
self._repo.git("checkout", "--detach")
134141
with pytest.raises(RuntimeError):
135-
self._drafter.generate_draft("ok", _SimpleBot.noop())
142+
await self._drafter.generate_draft("ok", _SimpleBot.noop())
136143

137-
def test_generate_empty_prompt(self) -> None:
144+
@pytest.mark.asyncio
145+
async def test_generate_empty_prompt(self) -> None:
138146
with pytest.raises(ValueError):
139-
self._drafter.generate_draft("", _SimpleBot.noop())
147+
await self._drafter.generate_draft("", _SimpleBot.noop())
140148

141-
def test_generate_reuse_branch(self) -> None:
149+
@pytest.mark.asyncio
150+
async def test_generate_reuse_branch(self) -> None:
142151
bot = _SimpleBot({"prompt": lambda goal: goal.prompt})
143-
self._drafter.generate_draft("prompt1", bot, "theirs")
144-
self._drafter.generate_draft("prompt2", bot, "theirs")
152+
await self._drafter.generate_draft("prompt1", bot, "theirs")
153+
await self._drafter.generate_draft("prompt2", bot, "theirs")
145154
assert self._fs.read("prompt") == "prompt2"
146155

147-
def test_delete_unknown_file(self) -> None:
148-
self._drafter.generate_draft("hello", _SimpleBot({"p1": None}))
156+
@pytest.mark.asyncio
157+
async def test_delete_unknown_file(self) -> None:
158+
await self._drafter.generate_draft("hello", _SimpleBot({"p1": None}))
149159

150-
def test_quit_keeps_changes(self) -> None:
160+
@pytest.mark.asyncio
161+
async def test_quit_keeps_changes(self) -> None:
151162
self._fs.write("p1.txt", "a1")
152-
self._drafter.generate_draft("hello", _SimpleBot.prompt(), "theirs")
163+
await self._drafter.generate_draft(
164+
"hello", _SimpleBot.prompt(), "theirs"
165+
)
153166
self._fs.write("p1.txt", "a2")
154167
self._drafter.quit_folio()
155168
assert self._fs.read("p1.txt") == "a2"
156169
assert self._fs.read("PROMPT") == "hello"
157170

158-
def test_latest_draft_prompt(self) -> None:
171+
@pytest.mark.asyncio
172+
async def test_latest_draft_prompt(self) -> None:
159173
bot = _SimpleBot.noop()
160174

161175
prompt1 = "First prompt"
162-
self._drafter.generate_draft(prompt1, bot)
176+
await self._drafter.generate_draft(prompt1, bot)
163177
assert self._drafter.latest_draft_prompt() == prompt1
164178

165179
prompt2 = "Second prompt"
166-
self._drafter.generate_draft(prompt2, bot)
180+
await self._drafter.generate_draft(prompt2, bot)
167181
assert self._drafter.latest_draft_prompt() == prompt2
168182

169-
def test_latest_draft_prompt_no_active_branch(self) -> None:
183+
@pytest.mark.asyncio
184+
async def test_latest_draft_prompt_no_active_branch(self) -> None:
170185
assert self._drafter.latest_draft_prompt() is None

0 commit comments

Comments
 (0)