Skip to content

Commit 2ef9e48

Browse files
authored
feat: reuse OpenAI assistant (#35)
1 parent 8078161 commit 2ef9e48

File tree

6 files changed

+106
-77
lines changed

6 files changed

+106
-77
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,5 @@
1616
template name. Otherwise an inline prompt.
1717
* Only include files that the bot has written in draft commits.
1818
* Add `--generate` timeout option.
19-
* Add URL and API key to `openai_bot`. Also add a compatibility version which
20-
does not use threads, so that it can be used with tools only. Gemini only
21-
supports the latter.
19+
* Add a compatibility OpenAI bot version which does not use threads, so that it
20+
can be used with tools only. Gemini only supports the latter.

src/git_draft/bots/__init__.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,14 @@ def load_bot(config: BotConfig) -> Bot:
3636
sys.path.insert(0, config.pythonpath)
3737

3838
parts = config.factory.split(":", 1)
39-
if len(parts) == 1:
40-
module = sys.modules[__name__] # Default to this module
41-
symbol = parts[0]
42-
else:
43-
module_name, symbol = parts
44-
module = importlib.import_module(module_name)
39+
if len(parts) != 2:
40+
raise ValueError(f"Invalid bot factory: {config.factory}")
41+
module_name, symbol = parts
42+
module = importlib.import_module(module_name)
4543

4644
factory = getattr(module, symbol, None)
4745
if not factory:
48-
raise NotImplementedError(f"Unknown factory: {factory}")
46+
raise NotImplementedError(f"Unknown bot factory: {factory}")
4947

5048
kwargs = config.config or {}
5149
return factory(**kwargs)
52-
53-
54-
def openai_bot(**kwargs) -> Bot:
55-
"""Instantiate and return an OpenAIBot with provided keyword arguments.
56-
57-
This function imports the OpenAIBot class from the openai module and
58-
returns an instance configured with the provided arguments.
59-
60-
Args:
61-
**kwargs: Arbitrary keyword arguments used to configure the bot.
62-
"""
63-
from .openai import OpenAIBot
64-
65-
return OpenAIBot(**kwargs)

src/git_draft/bots/common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,14 @@ class Action:
9898

9999
class Bot:
100100
@classmethod
101-
def state_folder_path(cls) -> Path:
101+
def state_folder_path(cls, ensure_exists=False) -> Path:
102102
name = cls.__qualname__
103103
if cls.__module__:
104104
name = f"{cls.__module__}.{name}"
105-
return ensure_state_home() / "bots" / name
105+
path = ensure_state_home() / "bots" / name
106+
if ensure_exists:
107+
path.mkdir(parents=True, exist_ok=True)
108+
return path
106109

107110
def act(self, prompt: str, toolbox: Toolbox) -> Action:
108111
raise NotImplementedError()

src/git_draft/bots/openai.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
import json
23
import logging
34
import openai
@@ -11,6 +12,24 @@
1112
_logger = logging.getLogger(__name__)
1213

1314

15+
def threads_bot(
16+
api_key: str | None = None, base_url: str | None = None
17+
) -> Bot:
18+
client = openai.OpenAI(api_key=api_key, base_url=base_url)
19+
return _ThreadsBot.create(client)
20+
21+
22+
# https://aider.chat/docs/more-info.html
23+
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
24+
_INSTRUCTIONS = """\
25+
You are an expert software engineer, who writes correct and concise code.
26+
Use the provided functions to find the filesyou need to answer the query,
27+
read the content of the relevant ones, and save the changes you suggest.
28+
When writing a file, include a summary description of the changes you have
29+
made.
30+
"""
31+
32+
1433
def _function_tool_param(
1534
name: str,
1635
description: str,
@@ -75,18 +94,21 @@ def _function_tool_param(
7594
]
7695

7796

78-
# https://aider.chat/docs/more-info.html
79-
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
80-
_INSTRUCTIONS = """\
81-
You are an expert software engineer, who writes correct and concise code.
82-
Use the provided functions to find the filesyou need to answer the query,
83-
read the content of the relevant ones, and save the changes you suggest.
84-
When writing a file, include a summary description of the changes you have
85-
made.
86-
"""
97+
@dataclasses.dataclass(frozen=True)
98+
class _AssistantConfig:
99+
instructions: str
100+
model: str
101+
tools: Sequence[openai.types.beta.AssistantToolParam]
102+
87103

104+
_assistant_config = _AssistantConfig(
105+
instructions=_INSTRUCTIONS,
106+
model="gpt-4o",
107+
tools=_tools,
108+
)
88109

89-
class OpenAIBot(Bot):
110+
111+
class _ThreadsBot(Bot):
90112
"""An OpenAI-backed bot
91113
92114
See the following links for resources:
@@ -97,16 +119,27 @@ class OpenAIBot(Bot):
97119
* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py
98120
"""
99121

100-
def __init__(self) -> None:
101-
self._client = openai.OpenAI()
122+
def __init__(self, client: openai.OpenAI, assistant_id: str) -> None:
123+
self._client = client
124+
self._assistant_id = assistant_id
125+
126+
@classmethod
127+
def create(cls, client: openai.OpenAI) -> Self:
128+
path = cls.state_folder_path(ensure_exists=True) / "ASSISTANT_ID"
129+
config = dataclasses.asdict(_assistant_config)
130+
try:
131+
with open(path) as f:
132+
assistant_id = f.read()
133+
except FileNotFoundError:
134+
assistant = client.beta.assistants.create(**config)
135+
assistant_id = assistant.id
136+
with open(path, "w") as f:
137+
f.write(assistant_id)
138+
else:
139+
client.beta.assistants.update(assistant_id, **config)
140+
return cls(client, assistant_id)
102141

103142
def act(self, prompt: str, toolbox: Toolbox) -> Action:
104-
# TODO: Reuse assistant.
105-
assistant = self._client.beta.assistants.create(
106-
instructions=_INSTRUCTIONS,
107-
model="gpt-4o",
108-
tools=_tools,
109-
)
110143
thread = self._client.beta.threads.create()
111144

112145
self._client.beta.threads.messages.create(
@@ -117,7 +150,7 @@ def act(self, prompt: str, toolbox: Toolbox) -> Action:
117150

118151
with self._client.beta.threads.runs.stream(
119152
thread_id=thread.id,
120-
assistant_id=assistant.id,
153+
assistant_id=self._assistant_id,
121154
event_handler=_EventHandler(self._client, toolbox),
122155
) as stream:
123156
stream.until_done()

src/git_draft/drafter.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -109,33 +109,6 @@ def create(
109109
operation_hook,
110110
)
111111

112-
def _create_branch(self, sync: bool) -> _Branch:
113-
if self._repo.head.is_detached:
114-
raise RuntimeError("No currently active branch")
115-
origin_branch = self._repo.active_branch.name
116-
origin_sha = self._repo.commit().hexsha
117-
118-
self._repo.git.checkout("--detach")
119-
sync_sha = self._sync() if sync else None
120-
suffix = _Branch.new_suffix()
121-
122-
with self._store.cursor() as cursor:
123-
cursor.execute(
124-
sql("add-branch"),
125-
{
126-
"suffix": suffix,
127-
"repo_path": self._repo.working_dir,
128-
"origin_branch": origin_branch,
129-
"origin_sha": origin_sha,
130-
"sync_sha": sync_sha,
131-
},
132-
)
133-
134-
branch = _Branch(suffix)
135-
branch_ref = self._repo.create_head(branch.name)
136-
self._repo.git.checkout(branch_ref)
137-
return branch
138-
139112
def generate_draft(
140113
self,
141114
prompt: str | TemplatedPrompt,
@@ -154,8 +127,7 @@ def generate_draft(
154127
branch = _Branch.active(self._repo)
155128
if branch:
156129
_logger.debug("Reusing active branch %s.", branch)
157-
if sync:
158-
self._sync()
130+
self._stage_changes(sync)
159131
else:
160132
branch = self._create_branch(sync)
161133
_logger.debug("Created branch %s.", branch)
@@ -219,10 +191,37 @@ def finalize_draft(self, delete=False) -> None:
219191
def discard_draft(self, delete=False) -> None:
220192
self._exit_draft(False, delete=delete)
221193

222-
def _sync(self) -> str | None:
223-
if not self._repo.is_dirty(untracked_files=True):
224-
return None
194+
def _create_branch(self, sync: bool) -> _Branch:
195+
if self._repo.head.is_detached:
196+
raise RuntimeError("No currently active branch")
197+
origin_branch = self._repo.active_branch.name
198+
origin_sha = self._repo.commit().hexsha
199+
200+
self._repo.git.checkout("--detach")
201+
sync_sha = self._stage_changes(sync)
202+
suffix = _Branch.new_suffix()
203+
204+
with self._store.cursor() as cursor:
205+
cursor.execute(
206+
sql("add-branch"),
207+
{
208+
"suffix": suffix,
209+
"repo_path": self._repo.working_dir,
210+
"origin_branch": origin_branch,
211+
"origin_sha": origin_sha,
212+
"sync_sha": sync_sha,
213+
},
214+
)
215+
216+
branch = _Branch(suffix)
217+
branch_ref = self._repo.create_head(branch.name)
218+
self._repo.git.checkout(branch_ref)
219+
return branch
220+
221+
def _stage_changes(self, sync: bool) -> str | None:
225222
self._repo.git.add(all=True)
223+
if not sync or not self._repo.is_dirty(untracked_files=True):
224+
return None
226225
ref = self._repo.index.commit("draft! sync")
227226
return ref.hexsha
228227

tests/git_draft/drafter_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,21 @@ def _write(self, name: str, contents="") -> None:
7171
def _commits(self) -> Sequence[git.Commit]:
7272
return list(self._repo.iter_commits())
7373

74+
def _list_commit_files(self, ref: str) -> Sequence[str]:
75+
text = self._repo.git.diff_tree(
76+
ref, no_commit_id=True, name_only=True, relative=True
77+
)
78+
return text.splitlines()
79+
7480
def test_generate_draft(self) -> None:
7581
self._drafter.generate_draft("hello", FakeBot())
7682
assert len(self._commits()) == 2
7783

84+
def test_generate_stages_worktree(self) -> None:
85+
self._write("marker", "hi")
86+
self._drafter.generate_draft("hello", FakeBot())
87+
assert "marker" in self._list_commit_files("HEAD")
88+
7889
def test_generate_then_discard_draft(self) -> None:
7990
self._drafter.generate_draft("hello", FakeBot())
8091
self._drafter.discard_draft()

0 commit comments

Comments
 (0)