Skip to content

Commit 12f467f

Browse files
authored
feat: implement run-based OpenAI assistant (#16)
1 parent 6eb7b6e commit 12f467f

File tree

3 files changed

+144
-69
lines changed

3 files changed

+144
-69
lines changed

src/git_draft/assistants/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class Toolbox(Protocol):
99
def list_files(self) -> Sequence[PurePosixPath]: ...
1010
def read_file(self, path: PurePosixPath) -> str: ...
11-
def write_file(self, path: PurePosixPath, data: str) -> None: ...
11+
def write_file(self, path: PurePosixPath, contents: str) -> None: ...
1212

1313

1414
@dataclasses.dataclass(frozen=True)

src/git_draft/assistants/openai.py

Lines changed: 131 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,85 @@
1+
import json
2+
import logging
13
import openai
24
from pathlib import PurePosixPath
5+
import textwrap
6+
from typing import Any, Mapping, Self, Sequence, override
37

48
from .common import Assistant, Session, Toolbox
59

610

7-
# https://aider.chat/docs/more-info.html
8-
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
9-
_INSTRUCTIONS = """\
10-
You are an expert software engineer, who writes correct and concise code.
11-
"""
11+
_logger = logging.getLogger(__name__)
1212

13-
_tools = [ # TODO
14-
{
13+
14+
def _function_tool_param(
15+
name: str,
16+
description: str,
17+
inputs: Mapping[str, Any] | None = None,
18+
required_inputs: Sequence[str] | None = None,
19+
) -> openai.types.beta.FunctionToolParam:
20+
return {
1521
"type": "function",
1622
"function": {
17-
"name": "read_file",
18-
"description": "Get a file's contents",
23+
"name": name,
24+
"description": textwrap.dedent(description),
1925
"parameters": {
2026
"type": "object",
21-
"properties": {
22-
"path": {
23-
"type": "string",
24-
"description": "Path of the file to be read",
25-
},
26-
},
27-
"required": ["path"],
27+
"additionalProperties": False,
28+
"properties": inputs or {},
29+
"required": required_inputs or [],
2830
},
31+
"strict": True,
2932
},
30-
},
31-
{
32-
"type": "function",
33-
"function": {
34-
"name": "write_file",
35-
"description": "Update a file's contents",
36-
"parameters": {
37-
"type": "object",
38-
"properties": {
39-
"path": {
40-
"type": "string",
41-
"description": "Path of the file to be updated",
42-
},
43-
"contents": {
44-
"type": "string",
45-
"description": "New contents of the file",
46-
},
47-
},
48-
"required": ["path", "contents"],
33+
}
34+
35+
36+
_tools = [
37+
_function_tool_param(
38+
name="list_files",
39+
description="List all available files",
40+
),
41+
_function_tool_param(
42+
name="read_file",
43+
description="Get a file's contents",
44+
inputs={
45+
"path": {
46+
"type": "string",
47+
"description": "Path of the file to be read",
48+
},
49+
},
50+
required_inputs=["path"],
51+
),
52+
_function_tool_param(
53+
name="write_file",
54+
description="""\
55+
Set a file's contents
56+
57+
The file will be created if it does not already exist.
58+
""",
59+
inputs={
60+
"path": {
61+
"type": "string",
62+
"description": "Path of the file to be updated",
63+
},
64+
"contents": {
65+
"type": "string",
66+
"description": "New contents of the file",
4967
},
5068
},
51-
},
69+
required_inputs=["path", "contents"],
70+
),
5271
]
5372

5473

74+
# https://aider.chat/docs/more-info.html
75+
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
76+
_INSTRUCTIONS = """\
77+
You are an expert software engineer, who writes correct and concise code.
78+
Use the provided functions to find the filesyou need to answer the query,
79+
read the content of the relevant ones, and save the changes you suggest.
80+
"""
81+
82+
5583
class OpenAIAssistant(Assistant):
5684
"""An OpenAI-backed assistant
5785
@@ -66,26 +94,73 @@ def __init__(self) -> None:
6694
self._client = openai.OpenAI()
6795

6896
def run(self, prompt: str, toolbox: Toolbox) -> Session:
69-
# TODO: Switch to the thread run API, using tools to leverage toolbox
70-
# methods.
71-
# assistant = client.beta.assistants.create(
72-
# instructions=_INSTRUCTIONS,
73-
# model="gpt-4o",
74-
# tools=_tools,
75-
# )
76-
# thread = client.beta.threads.create()
77-
# message = client.beta.threads.messages.create(
78-
# thread_id=thread.id,
79-
# role="user",
80-
# content="What's the weather in San Francisco today and the likelihood it'll rain?",
81-
# )
82-
completion = self._client.chat.completions.create(
83-
messages=[
84-
{"role": "system", "content": _INSTRUCTIONS},
85-
{"role": "user", "content": prompt},
86-
],
97+
# TODO: Reuse assistant.
98+
assistant = self._client.beta.assistants.create(
99+
instructions=_INSTRUCTIONS,
87100
model="gpt-4o",
101+
tools=_tools,
102+
)
103+
thread = self._client.beta.threads.create()
104+
105+
message = self._client.beta.threads.messages.create(
106+
thread_id=thread.id,
107+
role="user",
108+
content=prompt,
88109
)
89-
content = completion.choices[0].message.content or ""
90-
toolbox.write_file(PurePosixPath(f"{completion.id}.txt"), content)
110+
print(message)
111+
112+
with self._client.beta.threads.runs.stream(
113+
thread_id=thread.id,
114+
assistant_id=assistant.id,
115+
event_handler=_EventHandler(self._client, toolbox),
116+
) as stream:
117+
stream.until_done()
118+
91119
return Session(0)
120+
121+
122+
class _EventHandler(openai.AssistantEventHandler):
123+
def __init__(self, client: openai.Client, toolbox: Toolbox) -> None:
124+
super().__init__()
125+
self._client = client
126+
self._toolbox = toolbox
127+
128+
def clone(self) -> Self:
129+
return self.__class__(self._client, self._toolbox)
130+
131+
@override
132+
def on_event(self, event: Any) -> None:
133+
_logger.debug("Event: %s", event)
134+
if event.event == "thread.run.requires_action":
135+
run_id = event.data.id # Retrieve the run ID from the event data
136+
self._handle_action(run_id, event.data)
137+
# TODO: Handle (log?) other events.
138+
139+
def _handle_action(self, run_id: str, data: Any) -> None:
140+
tool_outputs = list[Any]()
141+
for tool in data.required_action.submit_tool_outputs.tool_calls:
142+
name = tool.function.name
143+
inputs = json.loads(tool.function.arguments)
144+
_logger.info("Requested tool: %s", tool)
145+
if name == "read_file":
146+
path = PurePosixPath(inputs["path"])
147+
output = self._toolbox.read_file(path)
148+
elif name == "write_file":
149+
path = PurePosixPath(inputs["path"])
150+
contents = inputs["contents"]
151+
self._toolbox.write_file(path, contents)
152+
output = "OK"
153+
elif name == "list_files":
154+
assert not inputs
155+
output = "\n".join(str(p) for p in self._toolbox.list_files())
156+
tool_outputs.append({"tool_call_id": tool.id, "output": output})
157+
158+
run = self.current_run
159+
assert run, "No ongoing run"
160+
with self._client.beta.threads.runs.submit_tool_outputs_stream(
161+
thread_id=run.thread_id,
162+
run_id=run.id,
163+
tool_outputs=tool_outputs,
164+
event_handler=self.clone(),
165+
) as stream:
166+
stream.until_done()

src/git_draft/manager.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ def read_file(self, path: PurePosixPath) -> str:
133133
# Read the file from the index.
134134
return self._repo.git.show(f":{path}")
135135

136-
def write_file(self, path: PurePosixPath, data: str) -> None:
136+
def write_file(self, path: PurePosixPath, contents: str) -> None:
137137
# Update the index without touching the worktree.
138138
# https://stackoverflow.com/a/25352119
139139
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
140-
temp.write(data.encode("utf8"))
140+
temp.write(contents.encode("utf8"))
141141
temp.close()
142142
sha = self._repo.git.hash_object("-w", "--path", path, temp.name)
143143
mode = 644 # TODO: Read from original file if it exists.
@@ -204,18 +204,18 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
204204
raise RuntimeError("Not currently on a draft branch")
205205
if not apply and branch.needs_rebase(self._repo):
206206
raise ValueError("Parent branch has moved, please rebase")
207-
208207
note = branch.init_note
209-
# https://stackoverflow.com/a/15993574
208+
209+
# We do a small dance to move back to the original branch, keeping the
210+
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
211+
# the inspiration.
210212
self._repo.git.checkout("--detach")
211-
if apply:
212-
# We discard index (internal) changes.
213-
self._repo.git.reset(note.origin_branch)
214-
self._repo.git.checkout(note.origin_branch)
215-
else:
216-
self._repo.git.reset("--hard", note.origin_branch)
217-
if note.sync_sha:
218-
self._repo.git.checkout(note.sync_sha, "--", ".")
213+
self._repo.git.reset(
214+
"--mixed" if apply else "--hard", note.origin_branch
215+
)
216+
self._repo.git.checkout(note.origin_branch)
219217

218+
if not apply and note.sync_sha:
219+
self._repo.git.checkout(note.sync_sha, "--", ".")
220220
if delete:
221221
self._repo.git.branch("-D", branch.name)

0 commit comments

Comments
 (0)