diff --git a/README.md b/README.md index 9e0a7a5..3045b47 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,5 @@ template name. Otherwise an inline prompt. * Only include files that the bot has written in draft commits. * Add `--generate` timeout option. -* Add URL and API key to `openai_bot`. Also add a compatibility version which - does not use threads, so that it can be used with tools only. Gemini only - supports the latter. +* Add a compatibility OpenAI bot version which does not use threads, so that it + can be used with tools only. Gemini only supports the latter. diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index 3b2bc65..09bb66e 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -36,30 +36,14 @@ def load_bot(config: BotConfig) -> Bot: sys.path.insert(0, config.pythonpath) parts = config.factory.split(":", 1) - if len(parts) == 1: - module = sys.modules[__name__] # Default to this module - symbol = parts[0] - else: - module_name, symbol = parts - module = importlib.import_module(module_name) + if len(parts) != 2: + raise ValueError(f"Invalid bot factory: {config.factory}") + module_name, symbol = parts + module = importlib.import_module(module_name) factory = getattr(module, symbol, None) if not factory: - raise NotImplementedError(f"Unknown factory: {factory}") + raise NotImplementedError(f"Unknown bot factory: {factory}") kwargs = config.config or {} return factory(**kwargs) - - -def openai_bot(**kwargs) -> Bot: - """Instantiate and return an OpenAIBot with provided keyword arguments. - - This function imports the OpenAIBot class from the openai module and - returns an instance configured with the provided arguments. - - Args: - **kwargs: Arbitrary keyword arguments used to configure the bot. - """ - from .openai import OpenAIBot - - return OpenAIBot(**kwargs) diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index 884b731..eafd2f5 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -98,11 +98,14 @@ class Action: class Bot: @classmethod - def state_folder_path(cls) -> Path: + def state_folder_path(cls, ensure_exists=False) -> Path: name = cls.__qualname__ if cls.__module__: name = f"{cls.__module__}.{name}" - return ensure_state_home() / "bots" / name + path = ensure_state_home() / "bots" / name + if ensure_exists: + path.mkdir(parents=True, exist_ok=True) + return path def act(self, prompt: str, toolbox: Toolbox) -> Action: raise NotImplementedError() diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index b555e1b..c1eba33 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -1,3 +1,4 @@ +import dataclasses import json import logging import openai @@ -11,6 +12,24 @@ _logger = logging.getLogger(__name__) +def threads_bot( + api_key: str | None = None, base_url: str | None = None +) -> Bot: + client = openai.OpenAI(api_key=api_key, base_url=base_url) + return _ThreadsBot.create(client) + + +# https://aider.chat/docs/more-info.html +# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py +_INSTRUCTIONS = """\ + You are an expert software engineer, who writes correct and concise code. + Use the provided functions to find the filesyou need to answer the query, + read the content of the relevant ones, and save the changes you suggest. + When writing a file, include a summary description of the changes you have + made. +""" + + def _function_tool_param( name: str, description: str, @@ -75,18 +94,21 @@ def _function_tool_param( ] -# https://aider.chat/docs/more-info.html -# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py -_INSTRUCTIONS = """\ - You are an expert software engineer, who writes correct and concise code. - Use the provided functions to find the filesyou need to answer the query, - read the content of the relevant ones, and save the changes you suggest. - When writing a file, include a summary description of the changes you have - made. -""" +@dataclasses.dataclass(frozen=True) +class _AssistantConfig: + instructions: str + model: str + tools: Sequence[openai.types.beta.AssistantToolParam] + +_assistant_config = _AssistantConfig( + instructions=_INSTRUCTIONS, + model="gpt-4o", + tools=_tools, +) -class OpenAIBot(Bot): + +class _ThreadsBot(Bot): """An OpenAI-backed bot See the following links for resources: @@ -97,16 +119,27 @@ class OpenAIBot(Bot): * https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py """ - def __init__(self) -> None: - self._client = openai.OpenAI() + def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: + self._client = client + self._assistant_id = assistant_id + + @classmethod + def create(cls, client: openai.OpenAI) -> Self: + path = cls.state_folder_path(ensure_exists=True) / "ASSISTANT_ID" + config = dataclasses.asdict(_assistant_config) + try: + with open(path) as f: + assistant_id = f.read() + except FileNotFoundError: + assistant = client.beta.assistants.create(**config) + assistant_id = assistant.id + with open(path, "w") as f: + f.write(assistant_id) + else: + client.beta.assistants.update(assistant_id, **config) + return cls(client, assistant_id) def act(self, prompt: str, toolbox: Toolbox) -> Action: - # TODO: Reuse assistant. - assistant = self._client.beta.assistants.create( - instructions=_INSTRUCTIONS, - model="gpt-4o", - tools=_tools, - ) thread = self._client.beta.threads.create() self._client.beta.threads.messages.create( @@ -117,7 +150,7 @@ def act(self, prompt: str, toolbox: Toolbox) -> Action: with self._client.beta.threads.runs.stream( thread_id=thread.id, - assistant_id=assistant.id, + assistant_id=self._assistant_id, event_handler=_EventHandler(self._client, toolbox), ) as stream: stream.until_done() diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index d683fe5..29df65a 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -109,33 +109,6 @@ def create( operation_hook, ) - def _create_branch(self, sync: bool) -> _Branch: - if self._repo.head.is_detached: - raise RuntimeError("No currently active branch") - origin_branch = self._repo.active_branch.name - origin_sha = self._repo.commit().hexsha - - self._repo.git.checkout("--detach") - sync_sha = self._sync() if sync else None - suffix = _Branch.new_suffix() - - with self._store.cursor() as cursor: - cursor.execute( - sql("add-branch"), - { - "suffix": suffix, - "repo_path": self._repo.working_dir, - "origin_branch": origin_branch, - "origin_sha": origin_sha, - "sync_sha": sync_sha, - }, - ) - - branch = _Branch(suffix) - branch_ref = self._repo.create_head(branch.name) - self._repo.git.checkout(branch_ref) - return branch - def generate_draft( self, prompt: str | TemplatedPrompt, @@ -154,8 +127,7 @@ def generate_draft( branch = _Branch.active(self._repo) if branch: _logger.debug("Reusing active branch %s.", branch) - if sync: - self._sync() + self._stage_changes(sync) else: branch = self._create_branch(sync) _logger.debug("Created branch %s.", branch) @@ -219,10 +191,37 @@ def finalize_draft(self, delete=False) -> None: def discard_draft(self, delete=False) -> None: self._exit_draft(False, delete=delete) - def _sync(self) -> str | None: - if not self._repo.is_dirty(untracked_files=True): - return None + def _create_branch(self, sync: bool) -> _Branch: + if self._repo.head.is_detached: + raise RuntimeError("No currently active branch") + origin_branch = self._repo.active_branch.name + origin_sha = self._repo.commit().hexsha + + self._repo.git.checkout("--detach") + sync_sha = self._stage_changes(sync) + suffix = _Branch.new_suffix() + + with self._store.cursor() as cursor: + cursor.execute( + sql("add-branch"), + { + "suffix": suffix, + "repo_path": self._repo.working_dir, + "origin_branch": origin_branch, + "origin_sha": origin_sha, + "sync_sha": sync_sha, + }, + ) + + branch = _Branch(suffix) + branch_ref = self._repo.create_head(branch.name) + self._repo.git.checkout(branch_ref) + return branch + + def _stage_changes(self, sync: bool) -> str | None: self._repo.git.add(all=True) + if not sync or not self._repo.is_dirty(untracked_files=True): + return None ref = self._repo.index.commit("draft! sync") return ref.hexsha diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 74fc8b2..4db31d0 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -71,10 +71,21 @@ def _write(self, name: str, contents="") -> None: def _commits(self) -> Sequence[git.Commit]: return list(self._repo.iter_commits()) + def _list_commit_files(self, ref: str) -> Sequence[str]: + text = self._repo.git.diff_tree( + ref, no_commit_id=True, name_only=True, relative=True + ) + return text.splitlines() + def test_generate_draft(self) -> None: self._drafter.generate_draft("hello", FakeBot()) assert len(self._commits()) == 2 + def test_generate_stages_worktree(self) -> None: + self._write("marker", "hi") + self._drafter.generate_draft("hello", FakeBot()) + assert "marker" in self._list_commit_files("HEAD") + def test_generate_then_discard_draft(self) -> None: self._drafter.generate_draft("hello", FakeBot()) self._drafter.discard_draft()