Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/git_draft/assistants/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class Toolbox(Protocol):
def list_files(self) -> Sequence[PurePosixPath]: ...
def read_file(self, path: PurePosixPath) -> str: ...
def write_file(self, path: PurePosixPath, data: str) -> None: ...
def write_file(self, path: PurePosixPath, contents: str) -> None: ...


@dataclasses.dataclass(frozen=True)
Expand Down
187 changes: 131 additions & 56 deletions src/git_draft/assistants/openai.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,85 @@
import json
import logging
import openai
from pathlib import PurePosixPath
import textwrap
from typing import Any, Mapping, Self, Sequence, override

from .common import Assistant, Session, Toolbox


# https://aider.chat/docs/more-info.html
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
_INSTRUCTIONS = """\
You are an expert software engineer, who writes correct and concise code.
"""
_logger = logging.getLogger(__name__)

_tools = [ # TODO
{

def _function_tool_param(
name: str,
description: str,
inputs: Mapping[str, Any] | None = None,
required_inputs: Sequence[str] | None = None,
) -> openai.types.beta.FunctionToolParam:
return {
"type": "function",
"function": {
"name": "read_file",
"description": "Get a file's contents",
"name": name,
"description": textwrap.dedent(description),
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path of the file to be read",
},
},
"required": ["path"],
"additionalProperties": False,
"properties": inputs or {},
"required": required_inputs or [],
},
"strict": True,
},
},
{
"type": "function",
"function": {
"name": "write_file",
"description": "Update a file's contents",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path of the file to be updated",
},
"contents": {
"type": "string",
"description": "New contents of the file",
},
},
"required": ["path", "contents"],
}


_tools = [
_function_tool_param(
name="list_files",
description="List all available files",
),
_function_tool_param(
name="read_file",
description="Get a file's contents",
inputs={
"path": {
"type": "string",
"description": "Path of the file to be read",
},
},
required_inputs=["path"],
),
_function_tool_param(
name="write_file",
description="""\
Set a file's contents

The file will be created if it does not already exist.
""",
inputs={
"path": {
"type": "string",
"description": "Path of the file to be updated",
},
"contents": {
"type": "string",
"description": "New contents of the file",
},
},
},
required_inputs=["path", "contents"],
),
]


# https://aider.chat/docs/more-info.html
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
_INSTRUCTIONS = """\
You are an expert software engineer, who writes correct and concise code.
Use the provided functions to find the filesyou need to answer the query,
read the content of the relevant ones, and save the changes you suggest.
"""


class OpenAIAssistant(Assistant):
"""An OpenAI-backed assistant

Expand All @@ -66,26 +94,73 @@ def __init__(self) -> None:
self._client = openai.OpenAI()

def run(self, prompt: str, toolbox: Toolbox) -> Session:
# TODO: Switch to the thread run API, using tools to leverage toolbox
# methods.
# assistant = client.beta.assistants.create(
# instructions=_INSTRUCTIONS,
# model="gpt-4o",
# tools=_tools,
# )
# thread = client.beta.threads.create()
# message = client.beta.threads.messages.create(
# thread_id=thread.id,
# role="user",
# content="What's the weather in San Francisco today and the likelihood it'll rain?",
# )
completion = self._client.chat.completions.create(
messages=[
{"role": "system", "content": _INSTRUCTIONS},
{"role": "user", "content": prompt},
],
# TODO: Reuse assistant.
assistant = self._client.beta.assistants.create(
instructions=_INSTRUCTIONS,
model="gpt-4o",
tools=_tools,
)
thread = self._client.beta.threads.create()

message = self._client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=prompt,
)
content = completion.choices[0].message.content or ""
toolbox.write_file(PurePosixPath(f"{completion.id}.txt"), content)
print(message)

with self._client.beta.threads.runs.stream(
thread_id=thread.id,
assistant_id=assistant.id,
event_handler=_EventHandler(self._client, toolbox),
) as stream:
stream.until_done()

return Session(0)


class _EventHandler(openai.AssistantEventHandler):
def __init__(self, client: openai.Client, toolbox: Toolbox) -> None:
super().__init__()
self._client = client
self._toolbox = toolbox

def clone(self) -> Self:
return self.__class__(self._client, self._toolbox)

@override
def on_event(self, event: Any) -> None:
_logger.debug("Event: %s", event)
if event.event == "thread.run.requires_action":
run_id = event.data.id # Retrieve the run ID from the event data
self._handle_action(run_id, event.data)
# TODO: Handle (log?) other events.

def _handle_action(self, run_id: str, data: Any) -> None:
tool_outputs = list[Any]()
for tool in data.required_action.submit_tool_outputs.tool_calls:
name = tool.function.name
inputs = json.loads(tool.function.arguments)
_logger.info("Requested tool: %s", tool)
if name == "read_file":
path = PurePosixPath(inputs["path"])
output = self._toolbox.read_file(path)
elif name == "write_file":
path = PurePosixPath(inputs["path"])
contents = inputs["contents"]
self._toolbox.write_file(path, contents)
output = "OK"
elif name == "list_files":
assert not inputs
output = "\n".join(str(p) for p in self._toolbox.list_files())
tool_outputs.append({"tool_call_id": tool.id, "output": output})

run = self.current_run
assert run, "No ongoing run"
with self._client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=tool_outputs,
event_handler=self.clone(),
) as stream:
stream.until_done()
24 changes: 12 additions & 12 deletions src/git_draft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def read_file(self, path: PurePosixPath) -> str:
# Read the file from the index.
return self._repo.git.show(f":{path}")

def write_file(self, path: PurePosixPath, data: str) -> None:
def write_file(self, path: PurePosixPath, contents: str) -> None:
# Update the index without touching the worktree.
# https://stackoverflow.com/a/25352119
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
temp.write(data.encode("utf8"))
temp.write(contents.encode("utf8"))
temp.close()
sha = self._repo.git.hash_object("-w", "--path", path, temp.name)
mode = 644 # TODO: Read from original file if it exists.
Expand Down Expand Up @@ -204,18 +204,18 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
raise RuntimeError("Not currently on a draft branch")
if not apply and branch.needs_rebase(self._repo):
raise ValueError("Parent branch has moved, please rebase")

note = branch.init_note
# https://stackoverflow.com/a/15993574

# We do a small dance to move back to the original branch, keeping the
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
# the inspiration.
self._repo.git.checkout("--detach")
if apply:
# We discard index (internal) changes.
self._repo.git.reset(note.origin_branch)
self._repo.git.checkout(note.origin_branch)
else:
self._repo.git.reset("--hard", note.origin_branch)
if note.sync_sha:
self._repo.git.checkout(note.sync_sha, "--", ".")
self._repo.git.reset(
"--mixed" if apply else "--hard", note.origin_branch
)
self._repo.git.checkout(note.origin_branch)

if not apply and note.sync_sha:
self._repo.git.checkout(note.sync_sha, "--", ".")
if delete:
self._repo.git.branch("-D", branch.name)