Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,5 @@
template name. Otherwise an inline prompt.
* Only include files that the bot has written in draft commits.
* Add `--generate` timeout option.
* Add URL and API key to `openai_bot`. Also add a compatibility version which
does not use threads, so that it can be used with tools only. Gemini only
supports the latter.
* Add a compatibility OpenAI bot version which does not use threads, so that it
can be used with tools only. Gemini only supports the latter.
26 changes: 5 additions & 21 deletions src/git_draft/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,14 @@ def load_bot(config: BotConfig) -> Bot:
sys.path.insert(0, config.pythonpath)

parts = config.factory.split(":", 1)
if len(parts) == 1:
module = sys.modules[__name__] # Default to this module
symbol = parts[0]
else:
module_name, symbol = parts
module = importlib.import_module(module_name)
if len(parts) != 2:
raise ValueError(f"Invalid bot factory: {config.factory}")
module_name, symbol = parts
module = importlib.import_module(module_name)

factory = getattr(module, symbol, None)
if not factory:
raise NotImplementedError(f"Unknown factory: {factory}")
raise NotImplementedError(f"Unknown bot factory: {factory}")

kwargs = config.config or {}
return factory(**kwargs)


def openai_bot(**kwargs) -> Bot:
"""Instantiate and return an OpenAIBot with provided keyword arguments.

This function imports the OpenAIBot class from the openai module and
returns an instance configured with the provided arguments.

Args:
**kwargs: Arbitrary keyword arguments used to configure the bot.
"""
from .openai import OpenAIBot

return OpenAIBot(**kwargs)
7 changes: 5 additions & 2 deletions src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,14 @@ class Action:

class Bot:
@classmethod
def state_folder_path(cls) -> Path:
def state_folder_path(cls, ensure_exists=False) -> Path:
name = cls.__qualname__
if cls.__module__:
name = f"{cls.__module__}.{name}"
return ensure_state_home() / "bots" / name
path = ensure_state_home() / "bots" / name
if ensure_exists:
path.mkdir(parents=True, exist_ok=True)
return path

def act(self, prompt: str, toolbox: Toolbox) -> Action:
raise NotImplementedError()
71 changes: 52 additions & 19 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import json
import logging
import openai
Expand All @@ -11,6 +12,24 @@
_logger = logging.getLogger(__name__)


def threads_bot(
api_key: str | None = None, base_url: str | None = None
) -> Bot:
client = openai.OpenAI(api_key=api_key, base_url=base_url)
return _ThreadsBot.create(client)


# https://aider.chat/docs/more-info.html
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
_INSTRUCTIONS = """\
You are an expert software engineer, who writes correct and concise code.
Use the provided functions to find the filesyou need to answer the query,
read the content of the relevant ones, and save the changes you suggest.
When writing a file, include a summary description of the changes you have
made.
"""


def _function_tool_param(
name: str,
description: str,
Expand Down Expand Up @@ -75,18 +94,21 @@ def _function_tool_param(
]


# https://aider.chat/docs/more-info.html
# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
_INSTRUCTIONS = """\
You are an expert software engineer, who writes correct and concise code.
Use the provided functions to find the filesyou need to answer the query,
read the content of the relevant ones, and save the changes you suggest.
When writing a file, include a summary description of the changes you have
made.
"""
@dataclasses.dataclass(frozen=True)
class _AssistantConfig:
instructions: str
model: str
tools: Sequence[openai.types.beta.AssistantToolParam]


_assistant_config = _AssistantConfig(
instructions=_INSTRUCTIONS,
model="gpt-4o",
tools=_tools,
)

class OpenAIBot(Bot):

class _ThreadsBot(Bot):
"""An OpenAI-backed bot

See the following links for resources:
Expand All @@ -97,16 +119,27 @@ class OpenAIBot(Bot):
* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py
"""

def __init__(self) -> None:
self._client = openai.OpenAI()
def __init__(self, client: openai.OpenAI, assistant_id: str) -> None:
self._client = client
self._assistant_id = assistant_id

@classmethod
def create(cls, client: openai.OpenAI) -> Self:
path = cls.state_folder_path(ensure_exists=True) / "ASSISTANT_ID"
config = dataclasses.asdict(_assistant_config)
try:
with open(path) as f:
assistant_id = f.read()
except FileNotFoundError:
assistant = client.beta.assistants.create(**config)
assistant_id = assistant.id
with open(path, "w") as f:
f.write(assistant_id)
else:
client.beta.assistants.update(assistant_id, **config)
return cls(client, assistant_id)

def act(self, prompt: str, toolbox: Toolbox) -> Action:
# TODO: Reuse assistant.
assistant = self._client.beta.assistants.create(
instructions=_INSTRUCTIONS,
model="gpt-4o",
tools=_tools,
)
thread = self._client.beta.threads.create()

self._client.beta.threads.messages.create(
Expand All @@ -117,7 +150,7 @@ def act(self, prompt: str, toolbox: Toolbox) -> Action:

with self._client.beta.threads.runs.stream(
thread_id=thread.id,
assistant_id=assistant.id,
assistant_id=self._assistant_id,
event_handler=_EventHandler(self._client, toolbox),
) as stream:
stream.until_done()
Expand Down
63 changes: 31 additions & 32 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,33 +109,6 @@ def create(
operation_hook,
)

def _create_branch(self, sync: bool) -> _Branch:
if self._repo.head.is_detached:
raise RuntimeError("No currently active branch")
origin_branch = self._repo.active_branch.name
origin_sha = self._repo.commit().hexsha

self._repo.git.checkout("--detach")
sync_sha = self._sync() if sync else None
suffix = _Branch.new_suffix()

with self._store.cursor() as cursor:
cursor.execute(
sql("add-branch"),
{
"suffix": suffix,
"repo_path": self._repo.working_dir,
"origin_branch": origin_branch,
"origin_sha": origin_sha,
"sync_sha": sync_sha,
},
)

branch = _Branch(suffix)
branch_ref = self._repo.create_head(branch.name)
self._repo.git.checkout(branch_ref)
return branch

def generate_draft(
self,
prompt: str | TemplatedPrompt,
Expand All @@ -154,8 +127,7 @@ def generate_draft(
branch = _Branch.active(self._repo)
if branch:
_logger.debug("Reusing active branch %s.", branch)
if sync:
self._sync()
self._stage_changes(sync)
else:
branch = self._create_branch(sync)
_logger.debug("Created branch %s.", branch)
Expand Down Expand Up @@ -219,10 +191,37 @@ def finalize_draft(self, delete=False) -> None:
def discard_draft(self, delete=False) -> None:
self._exit_draft(False, delete=delete)

def _sync(self) -> str | None:
if not self._repo.is_dirty(untracked_files=True):
return None
def _create_branch(self, sync: bool) -> _Branch:
if self._repo.head.is_detached:
raise RuntimeError("No currently active branch")
origin_branch = self._repo.active_branch.name
origin_sha = self._repo.commit().hexsha

self._repo.git.checkout("--detach")
sync_sha = self._stage_changes(sync)
suffix = _Branch.new_suffix()

with self._store.cursor() as cursor:
cursor.execute(
sql("add-branch"),
{
"suffix": suffix,
"repo_path": self._repo.working_dir,
"origin_branch": origin_branch,
"origin_sha": origin_sha,
"sync_sha": sync_sha,
},
)

branch = _Branch(suffix)
branch_ref = self._repo.create_head(branch.name)
self._repo.git.checkout(branch_ref)
return branch

def _stage_changes(self, sync: bool) -> str | None:
self._repo.git.add(all=True)
if not sync or not self._repo.is_dirty(untracked_files=True):
return None
ref = self._repo.index.commit("draft! sync")
return ref.hexsha

Expand Down
11 changes: 11 additions & 0 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,21 @@ def _write(self, name: str, contents="") -> None:
def _commits(self) -> Sequence[git.Commit]:
return list(self._repo.iter_commits())

def _list_commit_files(self, ref: str) -> Sequence[str]:
text = self._repo.git.diff_tree(
ref, no_commit_id=True, name_only=True, relative=True
)
return text.splitlines()

def test_generate_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
assert len(self._commits()) == 2

def test_generate_stages_worktree(self) -> None:
self._write("marker", "hi")
self._drafter.generate_draft("hello", FakeBot())
assert "marker" in self._list_commit_files("HEAD")

def test_generate_then_discard_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
self._drafter.discard_draft()
Expand Down